From 9c8ec4af3abc302a6cc875413441e8e75d5bd9a8 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Mon, 26 Jul 2021 12:50:44 +0200
Subject: [PATCH] implement new chambolle-pock algorithm with different
 dualisation

---
 src/function.jl |  50 ++++++++++++++
 src/run.jl      | 178 ++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 208 insertions(+), 20 deletions(-)

diff --git a/src/function.jl b/src/function.jl
index ed4dddf..af194b7 100644
--- a/src/function.jl
+++ b/src/function.jl
@@ -267,6 +267,56 @@ evaluate(f::ImageFunction, xloc) =
     interpolate_bilinear(f.img, elmap(f.mesh, f.cell[])(xloc))
 
 
+struct FacetDivergence{F}
+    f::F
+    cell::Base.RefValue{Int}
+    edgemap::Dict{NTuple{2, Int}, Vector{Int}}
+end
+
+# TODO: incomplete!
+function FacetDivergence(f::FeFunction)
+    f.space.element == DP0() || throw(ArgumentError("unimplemented"))
+    mesh = f.space.mesh
+    edgedivergence = Dict{NTuple{2, Int}, Float64}()
+    edgemap = Dict{NTuple{2, Int}, Vector{Int}}()
+
+    for cell in cells(mesh)
+	vs = sort(SVector(vertices(mesh, cell)))
+
+	e1 = (vs[1], vs[2])
+	e2 = (vs[1], vs[3])
+	e3 = (vs[2], vs[3])
+
+	edgemap[e1] = push!(get!(edgemap, e1, []), cell)
+	edgemap[e2] = push!(get!(edgemap, e2, []), cell)
+	edgemap[e3] = push!(get!(edgemap, e3, []), cell)
+
+	bind!(f, cell)
+	p = evaluate(f, SA[0., 0.])
+
+	A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}(
+	    view(mesh.vertices, :, view(mesh.cells, :, cell)))
+
+	v1 = (A[:,2] - A[:,1])
+	v2 = (A[:,3] - A[:,2])
+	v3 = (A[:,1] - A[:,3])
+
+	p1 = norm(v1) * (p[1] * v1[2] - p[2] * v1[1])
+	p2 = norm(v2) * (p[1] * v2[2] - p[2] * v2[1])
+	p3 = norm(v3) * (p[1] * v3[2] - p[2] * v3[1])
+
+	edgedivergence[e1] = get!(edgedivergence, e1, 0.) + p1
+	edgedivergence[e3] = get!(edgedivergence, e3, 0.) + p2
+	edgedivergence[e2] = get!(edgedivergence, e2, 0.) + p3
+    end
+
+    return FacetDivergence(f, Ref(1), edgemap)
+end
+
+bind!(f::FacetDivergence, cell) = f.cell[] = cell
+
+
+
 function sample(f::FeFunction)
     mesh = f.mapper.mesh
     for cell in cells(mesh)
diff --git a/src/run.jl b/src/run.jl
index dee7488..82249d6 100644
--- a/src/run.jl
+++ b/src/run.jl
@@ -169,8 +169,13 @@ function step!(ctx::L1L2TVContext)
     p2_project!(ctx.p2, ctx.lambda)
 end
 
+"
+2010, Chambolle and Pock: primal-dual semi-implicit algorithm
+2017, Alkämper and Langer: fem dualisation
+"
 function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
     # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1!
+    # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
     if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0.
 	error("unsupported parameters")
     end
@@ -179,11 +184,16 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
     # u is P1
     # p2 is essentially DP0 (technically may be DP1)
 
-    # 1.
+    # 1. y update
+    u_new = FeFunction(ctx.u.space) # x_bar only used here
+    u_new.data .= ctx.u.data .+ theta .* ctx.du.data
+
+    p2_next = FeFunction(ctx.p2.space)
+
     function p2_update(x_; p2, nablau)
 	return p2 + sigma * nablau
     end
-    interpolate!(ctx.p2, p2_update; ctx.p2, ctx.nablau)
+    interpolate!(p2_next, p2_update; ctx.p2, nablau=nabla(u_new))
 
     function p2_project!(p2, lambda)
 	p2.space.element::DP1
@@ -195,17 +205,17 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
 	    end
 	end
     end
-    p2_next = FeFunction(ctx.p2.space)
     p2_project!(p2_next, ctx.lambda)
     ctx.dp2.data .= p2_next.data .- ctx.p2.data
-    ctx.p2.data .= p2_next.data
+
+    ctx.p2.data .+= ctx.dp2.data
 
     # 2.
     u_a(x, z, nablaz, phi, nablaphi; g, u, p2) =
 	dot(z, phi)
 
     u_l(x, phi, nablaphi; u, g, p2) =
-	(dot(u + 2 * tau * ctx.alpha2 * g, phi) - tau * dot(p2, nablaphi)) /
+	(dot(u + 2 * tau * ctx.alpha2 * g, phi) + tau * dot(p2, -nablaphi)) /
 	    (1 + 2 * tau * ctx.alpha2)
 
     # z = 1 / (1 + 2 * tau * alpha2) *
@@ -227,16 +237,111 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
 	    end
 	end
     end
-    u_next = FeFunction(ctx.u.space)
-    u_update!(u_next, z, ctx.g, beta)
+    u_update!(u_new, z, ctx.g, beta)
 
     # 3.
-    ctx.du.data .= u_next.data .- ctx.u.data
-    ctx.u.data .= u_next.data .+ theta * ctx.du.data
+    # note: step-size control not implemented, since \nabla G^* is not 1/gamma continuous
+    ctx.du.data .= u_new.data .- ctx.u.data
+    ctx.u.data .= u_new.data
+
+    #ctx.du.data .= z.data
+    any(isnan, ctx.u.data) && error("encountered nan data")
+    any(isnan, ctx.p2.data) && error("encountered nan data")
 
     return ctx
 end
 
+"
+2010, Chambolle and Pock: accelerated primal-dual semi-implicit algorithm
+"
+function step_pd2!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
+    # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
+
+    # u is P1
+    # p2 is essentially DP0 (technically may be DP1)
+
+    # 1. y update
+    u_new = FeFunction(ctx.u.space) # x_bar only used here
+    u_new.data .= ctx.u.data .+ theta .* ctx.du.data
+
+    p1_next = FeFunction(ctx.p1.space)
+    p2_next = FeFunction(ctx.p2.space)
+
+    function p1_update(x_; p1, g, u, tdata)
+	ctx.alpha1 == 0 && return zero(p1)
+	return (p1 + sigma * (ctx.T(tdata, u) - g)) /
+	    (1 + ctx.gamma1 * sigma / ctx.alpha1)
+    end
+    interpolate!(p1_next, p1_update; ctx.p1, ctx.g, ctx.u, ctx.tdata)
+
+    function p2_update(x_; p2, nablau)
+	ctx.lambda == 0 && return zero(p2)
+	return (p2 + sigma * nablau) /
+	    (1 + ctx.gamma2 * sigma / ctx.lambda)
+    end
+    interpolate!(p2_next, p2_update; ctx.p2, nablau=nabla(u_new))
+
+    # reproject p1
+    function p1_project!(p1, alpha1)
+	p1.space.element::DP0
+	p1.data .= clamp.(p1.data, -alpha1, alpha1)
+    end
+    p1_project!(p1_next, ctx.alpha1)
+    # reproject p2
+    function p2_project!(p2, lambda)
+	p2.space.element::DP1
+	p2d = reshape(p2.data, prod(p2.space.size), :) # no copy
+	for i in axes(p2d, 2)
+	    p2in = norm(p2d[:, i])
+	    if p2in > lambda
+		p2d[:, i] .*= lambda ./ p2in
+	    end
+	end
+    end
+    p2_project!(p2_next, ctx.lambda)
+
+    ctx.dp1.data .= p1_next.data .- ctx.p1.data
+    ctx.dp2.data .= p2_next.data .- ctx.p2.data
+    ctx.p1.data .+= ctx.dp1.data
+    ctx.p2.data .+= ctx.dp2.data
+
+    # 2. x update
+    u_a(x, w, nablaw, phi, nablaphi; g, u, p1, p2, tdata) =
+	dot(w, phi) +
+	tau * ctx.alpha2 * dot(ctx.T(tdata, w), ctx.T(tdata, phi)) +
+	tau * ctx.beta * dot(ctx.S(w, nablaw), ctx.S(phi, nablaphi))
+
+    u_l(x, phi, nablaphi; g, u, p1, p2, tdata) =
+	dot(u, phi) - tau * (
+	    dot(p1, ctx.T(tdata, phi)) +
+	    dot(p2, nablaphi) -
+	    ctx.alpha2 * dot(g, ctx.T(tdata, phi)))
+
+    A, b = assemble(u_new.space, u_a, u_l; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.tdata)
+    u_new.data .= A \ b
+    ctx.du.data .= u_new.data .- ctx.u.data
+    ctx.u.data .= u_new.data
+
+    #ctx.du.data .= z.data
+    any(isnan, ctx.u.data) && error("encountered nan data")
+    any(isnan, ctx.p1.data) && error("encountered nan data")
+    any(isnan, ctx.p2.data) && error("encountered nan data")
+
+    return ctx
+end
+
+"
+2004, Chambolle: dual semi-implicit algorithm
+"
+function step_d!(ctx::L1L2TVContext; tau)
+    # u is P1
+    # p2 is essentially DP0 (technically may be DP1)
+
+    # TODO: this might not be implementable without higher order elements
+    # need grad(div(p))
+    return ctx
+end
+
 function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
     u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) =
 	ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) +
@@ -254,6 +359,7 @@ end
 huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
 
 function estimate!(ctx::L1L2TVContext)
+    # FIXME: sign?
     function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
 	alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * (
 	    huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
@@ -356,6 +462,13 @@ norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
 norm_step(ctx::L1L2TVContext) =
     sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh))
 
+function norm_residual(ctx::L1L2TVContext)
+    w = FeFunction(ctx.u.space)
+    solve_primal!(w, ctx)
+    w.data .-= ctx.u.data
+    return norm_l2(w)
+end
+
 function denoise(img; name, params...)
     m = 1
     #mesh = init_grid(img; type=:vertex)
@@ -389,10 +502,12 @@ function denoise(img; name, params...)
 	    println()
 
 	    norm_step_ = norm_step(ctx)
+	    norm_residual_ = norm_residual(ctx)
 
 	    println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
 	    println("primal energy: $(primal_energy(ctx))")
 	    println("norm_step: $(norm_step_)")
+	    println("norm_residual: $(norm_residual_)")
 
             norm_step_ <= 1e-1 && break
 	end
@@ -414,44 +529,67 @@ function denoise_pd(img; name, params...)
     mesh = init_grid(img; type=:vertex)
     #mesh = init_grid(img, 5, 5)
 
-    sigma = 1e-1
-    tau = 1e-1
-    theta = 1.
-
     T(tdata, u) = u
     S(u, nablau) = u
 
     ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
 
+    # semi-implicit primal dual parameters
+    gamma = ctx.alpha2 + ctx.beta # T = I, S = I
+    sigma = 1e-2
+    tau = 1e-2
+    theta = 0.
+
     project_img!(ctx.g, img)
     #interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
     #m = (size(img) .- 1) ./ 2 .+ 1
     #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
+    ctx.u.data .= ctx.g.data
 
     save_denoise(ctx, i) =
 	output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
-	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
+	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, ctx.du)
 
     pvd = paraview_collection("output/$(ctx.name).pvd")
     pvd[0] = save_denoise(ctx, 0)
 
     k = 0
     println("primal energy: $(primal_energy(ctx))")
+
+    algorithm = :newton
     while true
 	k += 1
-	step_pd!(ctx; sigma, tau, theta)
+	#println("")
+	if algorithm == :pd2 #|| primal_energy(ctx) < 2409
+	    println("step_pd: theta = $theta, tau = $tau, sigma = $sigma")
+	    pd = true
+	    theta = 1 / sqrt(1 + 2 * gamma * tau)
+	    tau *= theta
+	    sigma /= theta
+	    step_pd2!(ctx; sigma, tau, theta)
+	elseif algorithm == :pd1
+	    # no step size control
+	    step_pd!(ctx; sigma, tau, theta)
+	elseif algorithm == :newton
+	    step!(ctx)
+	end
 	#estimate!(ctx)
-	pvd[k] = save_denoise(ctx, k)
+	#pvd[k] = save_denoise(ctx, k)
 	println()
 
-	norm_step_ = norm_step(ctx)
+	domain_factor = 1 / sqrt(area(mesh))
+	norm_step_ = norm_step(ctx) * domain_factor
+	residual = -1.
+	residual = norm_residual(ctx) * domain_factor
 
-	println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
+	#println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
 	println("primal energy: $(primal_energy(ctx))")
 	println("norm_step: $(norm_step_)")
+	println("norm_residual: $(residual)")
 
-	norm_step_ <= 1e-1 && break
-	k >= 100 && break
+	#return ctx
+	#residual <= 1e-3 && break
+	#k >= 5 && break
     end
     vtk_save(pvd)
     return ctx
-- 
GitLab