Nikola Janjušević
Published 2021-06-05

Introduction to Julia by TV denoising

In this post, we implement color-image TV denoising in the Julia programming language.

  1. Total Variation Color Image Denoising
  2. Julia Implementation
    1. Julia Workflow
    2. The Sparse Way
    3. The Fast (Fourier Transform) Way
    4. Comparison

Total Variation Color Image Denoising

The TV-denoising model can be motivated by wanting to approximate a clean image as piecewise constant. As such, the desired image will have a sparse gradient, $\nabla \mathbf{x}$. We can approximate the gradient to first order by a finite difference scheme, $\nabla \mathbf{x} \approx \mathbf{D}\mathbf{x}$, with $(\mathbf{D}\mathbf{x})[i] = \mathbf{x}[i]-\mathbf{x}[i-1]$.

We then formulate the TV-denoising problem as

$$ \underset{\mathbf{x}}{\mathrm{minimize}}\;\, \tfrac{1}{2}\lVert \mathbf{x}-\mathbf{y} \rVert_2^2 + \lambda \lVert \mathbf{D}\mathbf{x} \rVert_1, $$

which we can reformulate with an equality constraint to be amenable to ADMM,

$$ \begin{array}{rl} \underset{\mathbf{x},\,\mathbf{z}}{\mathrm{minimize}} & \tfrac{1}{2}\lVert \mathbf{x}-\mathbf{y} \rVert_2^2 + \lambda \lVert \mathbf{z} \rVert_1 \\ \mathrm{s.t.} & \mathbf{D}\mathbf{x} - \mathbf{z} = \mathbf{0} \end{array} $$

We now derive our iterates in scaled form

$$ \begin{array}{rcl} \mathbf{x}^{k+1} &\coloneqq& \underset{\mathbf{x}}{\mathrm{argmin}}\; \tfrac{1}{2}\lVert \mathbf{x}-\mathbf{y} \rVert_2^2 + \tfrac{\rho}{2}\lVert \mathbf{D}\mathbf{x} - \mathbf{z}^k + \mathbf{u}^k \rVert_2^2 \\ \mathbf{0} &=& (\mathbf{x}^{k+1}-\mathbf{y}) + \rho\mathbf{D}^\top (\mathbf{D}\mathbf{x}^{k+1} - \mathbf{z}^{k} + \mathbf{u}^k) \\ \mathbf{x}^{k+1} &=& (I + \rho \mathbf{D}^\top\mathbf{D})^{-1}(\mathbf{y} + \rho\mathbf{D}^\top(\mathbf{z}^k - \mathbf{u}^k)) \end{array} $$

The above $\mathbf{x}$ iterate involves solving a tridiagonal system, which is fast. As the left-hand side of the system is static, we can store its Cholesky factor for faster subsequent solves. However, if we instead define the TV operator via circular convolution, the system is diagonalized in the Fourier domain as $\mathbf{D} = \mathcal{F}^\mathrm{H} \Lambda \mathcal{F}$. Our $\mathbf{x}$-update is then

$$ \mathbf{x}^{k+1} = \mathcal{F}^\mathrm{H} (I + \rho \lvert \Lambda \rvert^2)^{-1}\mathcal{F}(\mathbf{y} + \rho\mathbf{D}^\top(\mathbf{z}^k - \mathbf{u}^k)). $$

In this case, we can precompute and store the diagonal Fourier coefficients, $C = 1/(1 + \rho \lvert \Lambda \rvert^2)$. Our $\mathbf{z}$-update is an element-wise soft-thresholding[1],

$$ \begin{array}{rcl} \mathbf{z}^{k+1} &\coloneqq& \underset{\mathbf{z}}{\mathrm{argmin}}\; \lambda\lVert \mathbf{z} \rVert_1 + \tfrac{\rho}{2}\lVert \mathbf{D}\mathbf{x}^{k+1} - \mathbf{z} + \mathbf{u}^k \rVert_2^2 \\ &=& \mathrm{ST}\!\left(\mathbf{D}\mathbf{x}^{k+1} + \mathbf{u}^k,\; \lambda/\rho \right), \end{array} $$
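where $\mathrm{ST}$ acts element-wise as

$$ \mathrm{ST}(v, \tau) = \mathrm{sign}(v)\,\max(\lvert v \rvert - \tau,\, 0), $$

matching the ST function implemented below.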

Lastly, we have the scaled dual ascent, $\mathbf{u}^{k+1} = \mathbf{u}^k + \mathbf{D}\mathbf{x}^{k+1} - \mathbf{z}^{k+1}$.

In the case of 2D signals (i.e. images), a simple approximation of the gradient is given by finite differences in both the $x$ and $y$ directions[2],

$$ \mathbf{D}\mathbf{x} = \begin{bmatrix} \mathbf{D}_x\mathbf{x} \\ \mathbf{D}_y\mathbf{x} \end{bmatrix} = \begin{bmatrix} \mathbf{d}_x \ast \mathbf{x} \\ \mathbf{d}_y \ast \mathbf{x} \end{bmatrix}, \quad \begin{matrix} (\mathbf{d}_x \ast \mathbf{x})[i,j] = \mathbf{x}[i,j] - \mathbf{x}[i,j-1] \\ (\mathbf{d}_y \ast \mathbf{x})[i,j] = \mathbf{x}[i,j] - \mathbf{x}[i-1,j] \end{matrix} $$

Our $\mathbf{x}$-update can still be computed in the Fourier domain by diagonalizing $\mathbf{D}^\top\mathbf{D}$,

$$ \begin{array}{rcl} \mathbf{D}^\top\mathbf{D} &=& \mathbf{D}_x^\top\mathbf{D}_x + \mathbf{D}_y^\top \mathbf{D}_y \\ &=& \mathcal{F}^\mathrm{H} (\lvert \Lambda_x \rvert^2 + \lvert \Lambda_y \rvert^2 ) \mathcal{F}, \end{array} $$

to yield the update,

$$ \mathbf{x}^{k+1} = \mathcal{F}^\mathrm{H} \left(I + \rho (\lvert \Lambda_x \rvert^2 + \lvert \Lambda_y \rvert^2)\right)^{-1}\mathcal{F}(\mathbf{y} + \rho\mathbf{D}^\top(\mathbf{z}^k - \mathbf{u}^k)). $$

We can further extend this to RGB images (or any other multi-channel image) by taking $x$ and $y$ derivatives separately in each channel. For $\mathbf{x} = [\mathbf{x}_R;\, \mathbf{x}_G;\, \mathbf{x}_B]$,

$$ \mathbf{D}_3\mathbf{x} = \begin{bmatrix} \mathbf{D}_x \mathbf{x}_R \\ \mathbf{D}_y \mathbf{x}_R \\ \mathbf{D}_x \mathbf{x}_G \\ \mathbf{D}_y \mathbf{x}_G \\ \mathbf{D}_x \mathbf{x}_B \\ \mathbf{D}_y \mathbf{x}_B \end{bmatrix} = \Big( I_3 \otimes \underbrace{\begin{bmatrix} \mathbf{D}_x \\ \mathbf{D}_y \end{bmatrix}}_{\mathbf{D}} \Big) \mathbf{x}, $$

where $\otimes$ is the Kronecker product. Transposition distributes over the Kronecker product, and $(A \otimes B)(C \otimes D) = AC \otimes BD$. Therefore,

$$ \mathbf{D}_3^\top\mathbf{D}_3 = I_3 \otimes \mathbf{D}^\top\mathbf{D}, $$

and our $\mathbf{x}$-update involves dividing by the Fourier-domain factor $C$ defined above, channel-wise.
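As a quick sanity check of this block structure, the sketch below (using a random sparse matrix as a stand-in for $\mathbf{D}$, plus the SparseArrays and LinearAlgebra standard libraries) verifies that the Gram matrix of the 3-channel operator is block-diagonal:

using LinearAlgebra, SparseArrays

D  = sprandn(18, 16, 0.2)        # stand-in for the single-channel TV operator
D₃ = kron(I(3), D)               # 3-channel version
norm(D₃'*D₃ - kron(I(3), D'*D))  # ≈ 0, i.e. D₃'D₃ = I₃ ⊗ D'D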

Julia Implementation

I've implemented a TV image denoising package, TVDenoise.jl. Below I walk through the source code in more detail and describe how I've learned to program in Julia.

Julia Workflow

Julia is most easily interfaced with via the REPL (read-eval-print loop). There are five modes: Julian (default), Shell (;), Help (?), Package (]), and Search (^r/^s). This is pretty similar to MATLAB, except that you can't clear your workspace variables.

Packages

Package pkg can be installed via ] add pkg in the REPL. You can then start using pkg in two different ways, using or import, as shown below.
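For example, with the Plots package (used later in this post):

using Plots    # brings Plots' exported names (plot, savefig, ...) into scope
import Plots   # loads the package, but names must be qualified, e.g. Plots.plot(x, y)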

Workflow

The Sparse Way

We can write a scalar soft-thresholding function in one line:

ST(x,τ) = sign(x)*max(abs(x)-τ, 0)
ST(1,0.5)
0.5

To apply ST (or any scalar function) element-wise we use dot-syntax (and ' to turn the resulting column-vector into a row-vector for display purposes),

x = -1:0.2:1
y = ST.(x,0.2)'
1×11 adjoint(::Vector{Float64}) with eltype Float64:
 -0.8  -0.6  -0.4  -0.2  -0.0  0.0  0.0  0.2  0.4  0.6  0.8
⚠Warning⚠:
Arrays (and other mutable objects) are passed by reference, not by value: if you mutate an argument inside a function, it changes outside of it too. Functions that mutate their inputs are post-fixed by ! by convention, e.g. push!(v).
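For example:

v = [1, 2, 3]
f!(w) = push!(w, 4)     # mutates its argument in place, hence the ! suffix
f!(v)
v                       # [1, 2, 3, 4]: the caller's array changed
g(w) = (w = [0]; w)     # rebinding the local name w does not affect the caller
g(v)
v                       # still [1, 2, 3, 4]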

We can always implement the ADMM described above directly with sparse linear solves. First, we need to generate the sparse first-order derivative matrix.

using SparseArrays # provides spdiagm and SparseMatrixCSC

function FDmat(N::Int)::SparseMatrixCSC
	# sparse (N-1)×N first-order difference matrix (no boundary conditions)
	spdiagm(0 => -1*ones(N), 1 => ones(N))[1:N-1,1:N];
end

Here we specify the input as an integer and the output as a SparseMatrixCSC. By default, Julia functions return the value of their last statement. Therefore, the return keyword is not needed in this function definition (though it could be used). FDmat(N) returns $\mathbf{D} \in \mathbb{R}^{(N-1)\times N}$ and imposes no boundary conditions.
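For example, converting the sparse result to a dense matrix makes its structure obvious:

Matrix(FDmat(5))
4×5 Matrix{Float64}:
 -1.0   1.0   0.0   0.0   0.0
  0.0  -1.0   1.0   0.0   0.0
  0.0   0.0  -1.0   1.0   0.0
  0.0   0.0   0.0  -1.0   1.0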

The 2D TV operator is formed by stacking the horizontal and vertical first-order derivative matrices on top of each other. These operators act on the vectorized image. Using the fact that Julia stores matrices in column-major order, you should be able to see why the 2D and 3D methods are implemented as follows:

function FDmat(M::Int, N::Int)::SparseMatrixCSC
	# vertical derivative
	S = spdiagm(N-1, N, ones(N-1));
	T = FDmat(M);
	Dy = kron(S,T);
	# horizontal derivative
	S = FDmat(N);
	T = spdiagm(M-1, M, ones(M-1));
	Dx = kron(S,T);
	return [Dx; Dy];
end

function FDmat(M::Int, N::Int, C::Int)::SparseMatrixCSC
	# C-channel TV operator: I(C) ⊗ D, with I from LinearAlgebra
	kron(I(C),FDmat(M,N))
end

This definition of FDmat() shows off Julia's "multiple dispatch", which generalizes function overloading: the method to call is selected from the runtime types (and number) of all arguments.

Note:
Type annotations on function arguments are really only needed for multiple dispatch; they do not affect performance.
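For example, each call below dispatches to a different FDmat method based solely on its arguments:

size(FDmat(8))       # (7, 8):   1D difference operator
size(FDmat(4, 4))    # (18, 16): 2D TV operator acting on vec(image)
size(FDmat(4, 4, 3)) # (54, 48): 3-channel (RGB) version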

With this, we can implement the ADMM TVD solver tvd(),

using LinearAlgebra, Printf # provides cholesky, norm, and @printf

function tvd(y::AbstractVector, D::AbstractMatrix, λ, ρ, maxit, tol, verbose)
	M, N = size(D);
	objfun(x,Dx) = 0.5*sum((x-y).^2) + λ*norm(Dx, 1);
	x = y;
	z = zeros(M);
	u = zeros(M);
	F = zeros(maxit); # objective fun,
	r = zeros(maxit); # primal residual
	s = zeros(maxit); # dual residual

	C = cholesky(I + ρ*D'*D);

	k = 0; 
	while k == 0 || k < maxit && r[k] > tol
		x = C\(y + ρ*D'*(z - u)); # x-update
		Dx = D*x;
		zᵏ = z;
		z = ST.(Dx + u, λ/ρ);     # z-update
		u = u + Dx - z;           # dual ascent
		r[k+1] = norm(Dx - z);
		s[k+1] = ρ*norm(D'*(z - zᵏ));
		F[k+1] = objfun(x,Dx);
		k += 1;
		if verbose
			@printf "k: %3d | F= %.3e | r= %.3e | s= %.3e \n" k F[k] r[k] s[k] ;
		end
	end
	return x, (k=k, obj=F[1:k], pres=r[1:k], dres=s[1:k])
end

In tvd() we precompute our Cholesky factor to have fast $\mathbf{x}$-updates. C is a special factorization object ready for solving linear systems via backslash (\). As the above tvd() is defined only for vector inputs, we overload it with a tvd() method for 2D and 3D inputs.

function tvd(y::AbstractArray, λ, ρ=1; maxit=100, tol=1e-6, verbose=true)
	sz = size(y);
	D = FDmat(sz...);
	x, hist = tvd(vec(y), D, λ, ρ, maxit, tol, verbose);
	return reshape(x, sz...), hist
end

This turns out so clean because we've already overloaded FDmat() to return the correct TV operators. Note that the "splat" operator (...) unpacks tuples and more (see ? ...).
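As a small sanity check (not in the repository), we can denoise a noisy piecewise-constant signal with the 1D method; the noise level and λ below are just illustrative:

x  = [ones(50); 3*ones(50); 2*ones(50)]   # piecewise-constant ground truth
y  = x + 0.3*randn(150)                   # noisy observation
x̂, hist = tvd(y, 1.0; maxit=200, verbose=false)
println("iterations: ", hist.k, ", final objective: ", hist.obj[end])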

The Fast (Fourier Transform) Way

The assumption of periodicity allows us to diagonalize our linear system in the Fourier domain and solve it much faster. We first need to write our own circular padding (and with it, circular convolution), as circular padding isn't implemented in NNlib.jl.

using NNlib # provides pad_constant and conv

function pad_circular(A::Array{<:Number,4}, pad::NTuple{4,Int})
	M, N = size(A)[1:2]
	if any(pad .> M) || any(pad .> N)
		error("padding larger than original matrix!")
	end
	# allocate array
	B = pad_constant(A, pad, dims=(1,2))
	t, b, l, r = pad
	f(p,L) = L-p+1:L
	# top-left corner
	B[1:t, 1:l, :, :]               = A[f(t,M), f(l,N), :, :]
	# top-middle
	B[1:t, l+1:l+N, :, :]           = A[f(t,M), :, :, :]
	# top-right
	B[1:t, f(r,N+l+r), :, :]        = A[f(t,M), 1:r, :, :]
	# left-middle
	B[t+1:t+M, 1:l, :, :]           = A[:, f(l,N), :, :]
	# right-middle
	B[t+1:t+M, f(r,N+l+r), :, :]    = A[:, 1:r, :, :]
	# bottom-left
	B[f(b,M+b+t), 1:l, :, :]        = A[1:b, f(l,N), :, :]
	# bottom-middle
	B[f(b,M+t+b), l+1:l+N, :, :]    = A[1:b, :, :, :]
	# bottom-right
	B[f(b,M+t+b), f(r,N+l+r), :, :] = A[1:b, 1:r, :, :]
	return B
end
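A quick check on a tiny 2×2 image (with singleton channel and batch dimensions) shows the wrap-around behavior used for the circular convolutions below:

A = reshape(Float64.(1:4), 2, 2, 1, 1)  # spatial slice: [1 3; 2 4]
B = pad_circular(A, (1,0,1,0))          # one row of padding on top, one column on the left
B[:,:,1,1]
3×3 Matrix{Float64}:
 4.0  2.0  4.0
 3.0  1.0  3.0
 4.0  2.0  4.0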

The definition of pad_circular() gives an example of how to dispatch on Array inputs based on their number of dimensions. The crucial part is defining the element type to allow any subtype of Number via the subtype operator <:. Also note that NTuple{4,Int} is short for Tuple{Int,Int,Int,Int}. A good way to check whether your arguments match a type annotation is isa(),

A = reshape(collect(1:9), 3,3);
println(isa(A, Array{Number, 2}))
isa(A, Array{<:Number, 2})
false
true

Use typeof() and eltype() to get the type of an array and its elements respectively.
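For the A defined above:

typeof(A), eltype(A)
(Matrix{Int64}, Int64)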

We can now write the Fourier-domain ADMM solver.

using FFTW, NNlib, LinearAlgebra, Printf

function tvd_fft(y::Array{<:Real,4}, λ, ρ=1; maxit=100, tol=1e-6, verbose=true)
	M, N, P = size(y)[1:3]
	# move channels to batch dimension
	y = permutedims(y, (1,2,4,3))
	τ = λ/ρ;

	# precompute C for x-update
	Λx = rfft([1 -1 zeros(N-2)'; zeros(M-1,N)]);
	Λy = rfft([[1; -1; zeros(M-2)] zeros(M,N-1)])
	C = 1 ./ ( 1 .+ ρ.*(abs2.(Λx) .+ abs2.(Λy)) );

	# real Fourier xfrm in image dimension.
	# Must specify length of first dimension for inverse.
	Q  = plan_rfft(y,(1,2));
	Qᴴ = plan_irfft(rfft(y),M,(1,2));

	# conv kernel
	W = Float64.(zeros(2,2,1,2));
	W[:,:,1,1] = [1 -1; 0 0];          # dx
	W[:,:,1,2] = [1  0;-1 0];          # dy
	Wᵀ = reverse(permutedims(W, (2,1,4,3)), dims=:);

	# Circular convolution
	D(x) = conv(pad_circular(x, (1,0,1,0)), W);
	Dᵀ(x)= conv(pad_circular(x, (0,1,0,1)), Wᵀ);

	objfun(x,Dx) = 0.5*sum(abs2.(x-y)) + λ*norm(Dx, 1);

	# initialization
	z = zeros(M,N,2,P);
	u = zeros(M,N,2,P);
	F = zeros(maxit); # store objective fun
	r = zeros(maxit); # store primal residual
	s = zeros(maxit); # store dual   residual

	k = 0;
	while k == 0 || k < maxit && r[k] > tol && s[k] > tol
		x = Qᴴ*(C.*(Q*( y + ρ*Dᵀ(z-u) ))); # x update
		Dxᵏ = D(x);
		zᵏ  = z;
		z = ST.(Dxᵏ+u, τ);                 # z update
		u = u + Dxᵏ - z;                   # dual ascent
		r[k+1] = norm(Dxᵏ-z);
		s[k+1] = norm(z-zᵏ);
		F[k+1] = objfun(x,Dxᵏ);
		k += 1;
		if verbose
			@printf "k: %3d | F= %.3e | r= %.3e | s= %.3e \n" k F[k] r[k] s[k];
		end
	end
	x = permutedims(x, (1,2,4,3));
	return x, (k=k, obj=F[1:k], pres=r[1:k], dres=s[1:k])
end

Circular convolution is implemented by our circular padding followed by NNlib's conv.

Note: NNlib's convolution function
conv implements valid convolution, not correlation. Correlation can be used instead via the kwarg flipped=true. The supplied kernel W is specified by height, width, in-channels, and out-channels, i.e. $[H, W, C_{\mathrm{in}}, C_{\mathrm{out}}]$; 1D, 2D, and 3D signals all use the same conv function. The expression for Wᵀ is especially insightful: filters must be spatially reversed and have their ordering reversed.
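A tiny example (not from the repository) makes the flipping convention concrete, using 2×2 spatial slices with singleton channel and batch dimensions:

x = reshape([1.0 2; 3 4], 2, 2, 1, 1)
w = reshape([1.0 0; 0 -1], 2, 2, 1, 1)
conv(x, w)[:]                # [3.0]:  kernel is flipped (true convolution)
conv(x, w, flipped=true)[:]  # [-3.0]: kernel used as-is (cross-correlation)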

To perform the $x$ and $y$ gradient computations in the same conv call, we have to pay attention to padding: our horizontal filter shouldn't touch the top padding and our vertical filter shouldn't touch the left padding. Therefore, also taking into account that convolution is a flip-and-slide, we write the kernel $W$'s two filters (along the output dimension) as,

$$ W[:,:,1,1] = \begin{bmatrix} 1 & -1 \\ 0 & 0 \end{bmatrix}, \quad W[:,:,1,2] = \begin{bmatrix} 1 & 0 \\ -1 & 0 \end{bmatrix}. $$

To precompute our Fourier factor $C$, we zero-pad our filters to the full image size to yield their transform-domain representations. We can use rfft to exploit the conjugate symmetry of our real-valued filters and only compute half the coefficients.

Our last point of interest is the use of planned FFTs. These return operators that use matrix-multiplication syntax but do not perform dense matrix-vector products; the benefit is that the butterfly network used to compute the FFT is precomputed and kept in memory for faster repeated application. The first argument of each plan function above gives the size and data type of the inputs we'll use in the loop.
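For instance (with FFTW.jl loaded), applying a plan via * gives the same result as calling rfft directly, and plans can be inverted:

x = randn(512, 512)
P = plan_rfft(x)   # precompute the transform plan for this size and element type
P*x ≈ rfft(x)      # true
inv(P)*(P*x) ≈ x   # true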

Comparison

Full demo code can be found in the repository; the code shown below is not included there.

We demonstrate our working TV denoising on the example Fabio image with AWGN, $\sigma_n = 25$ (on a 0-255 scale).

The two methods achieve similar PSNR results.

The similarity of results is to be expected because they're solving almost exactly the same problem. Where we'll find gains is in our run times.

I ran some rudimentary timing tests on different versions of the Fabio image and hard-coded results in the visualization script below. The FFT version is seriously fast in comparison.

using Plots, StatsPlots
t1 = [3.3, 14.2, 24.5, 101.0] # tvd times
t2 = [0.5, 1.6,  2.8,  10.8]  # tvd_fft times
name = repeat(["Gray 256", "Color 256", "Gray 512", "Color 512"], outer=2)
group= repeat(["Sparse", "FFT"], inner=4)
groupedbar(name, [t1;t2], group=group)
ylabel!("time (s)")
savefig("timebar.svg")
FFT vs. sparse-array timing comparison.
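For reference, timings like these could be collected with BenchmarkTools; the sketch below uses a synthetic 512×512 grayscale image and assumes, as in the repository, that tvd and tvd_fft also accept 2D Float64 arrays:

using TVDenoise, BenchmarkTools

y = randn(512, 512)                      # synthetic grayscale "image"
@btime tvd($y, 0.1; verbose=false);      # sparse / Cholesky version
@btime tvd_fft($y, 0.1; verbose=false);  # FFT version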

As a last example, I'll demonstrate that TV denoising is not just a trick for synthetically generated noisy images. Below is a real noisy photo of Masa.

using TVDenoise, FileIO
img = load("masa.png")
x, _ = tvd_fft(img, 0.02, 5)
save("tvdmasa.png", x)

Our TV denoising is able to clean it up nicely.

One question that persists throughout these demos is: "Is there a different set of filters better tailored to regularizing the denoising problem?" We can formulate our problem as also being a minimization over $\mathbf{D}$, known as the Analysis Convolutional Dictionary Learning problem. Special care has to be taken to ensure that $\mathbf{D}$ is not degenerate. A Julia implementation for another post!

[1] See Understanding ISTA as a Fixed-Point Iteration for a derivation of soft-thresholding.