From d4d67b01a7495ba88c8ff67cdcbbdb7da6c8b120 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Sun, 25 Jul 2021 08:27:22 +0200
Subject: [PATCH] implement primal dual algorithm (untested)

---
 src/operator.jl |  37 +++++++-----
 src/run.jl      | 150 +++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 159 insertions(+), 28 deletions(-)

diff --git a/src/operator.jl b/src/operator.jl
index bc27855..b15cd60 100644
--- a/src/operator.jl
+++ b/src/operator.jl
@@ -116,9 +116,12 @@ function assemble(space::FeSpace, a, l; params...)
     return A, b
 end
 
+project_img(space::FeSpace, img) =
+    (u = FeFunction(space); project_img!(u, img))
 
-function project_img(space::FeSpace, img)
+function project_img!(u::FeFunction, img)
     d = 2 # domain dimension
+    space = u.space
     mesh = space.mesh
     f = ImageFunction(mesh, img)
     opparams = (; f)
@@ -131,13 +134,21 @@ function project_img(space::FeSpace, img)
 
     # composite midpoint quadrature on lagrange point lattice
     function quadrature(p)
-	k = Iterators.filter(x -> sum(x) == p,
-	    Iterators.product((0:p for _ in 1:d+1)...)) |> collect
-
-	weights = [1 / length(k) for _ in axes(k, 1)]
-	points = [x[i] / p for i in 1:2, x in k]
+	d_ = 2
+	n = binomial(p + 2, 2)
+	weights = Vector{Float64}(undef, n)
+	points = Matrix{Float64}(undef, 2, n)
+
+	k = 0
+	for I in Iterators.product(ntuple(_ -> 0:p, d_ + 1)...)
+	    I[1] + I[2] + I[3] != p && continue
+	    k += 1
+	    weights[k] = 1 / n
+	    points[1, k] = I[1] / p
+	    points[2, k] = I[2] / p
+	end
 
-	return weights::Vector{Float64}, points::Matrix{Float64}
+	return weights, points
     end
 
     I = Float64[]
@@ -160,16 +171,16 @@ function project_img(space::FeSpace, img)
 
 	qphi = zeros(nrdims, nrdims, nldofs, nqpts)
 	dqphi = zeros(nrdims, d, nrdims, nldofs, nqpts)
-	for r in 1:nrdims
-	    for k in axes(qx, 2)
-		qphi[r, r, :, k] .= evaluate_basis(space.element, qx[:, k])
-		dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(qx[:, k])))
+	for k in axes(qx, 2)
+	    for r in 1:nrdims
+		qphi[r, r, :, k] .= evaluate_basis(space.element, SVector{d}(view(qx, :, k)))
+		dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(view(qx, :, k))))
 	    end
 	end
 
 	# quadrature points
 	for k in axes(qx, 2)
-	    xhat = SVector{d}(qx[:, k])
+	    xhat = SVector{d}(view(qx, :, k))
 	    x = elmap(mesh, cell)(xhat)
 	    opvalues = map(f -> evaluate(f, xhat), opparams)
 
@@ -204,8 +215,6 @@ function project_img(space::FeSpace, img)
     ngdofs = ndofs(space)
     A = sparse(I, J, V, ngdofs, ngdofs)
 
-    u = FeFunction(space)
     u.data .= A \ b
-
     return u
 end
diff --git a/src/run.jl b/src/run.jl
index a921111..dee7488 100644
--- a/src/run.jl
+++ b/src/run.jl
@@ -1,4 +1,4 @@
-export myrun, denoise, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg
+export myrun, denoise, denoise_pd, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg
 
 using LinearAlgebra: norm
 
@@ -169,6 +169,74 @@ function step!(ctx::L1L2TVContext)
     p2_project!(ctx.p2, ctx.lambda)
 end
 
+function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
+    # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1!
+    if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0.
+	error("unsupported parameters")
+    end
+    beta = tau * ctx.alpha1 / (1 + 2 * tau * ctx.alpha2)
+
+    # u is P1
+    # p2 is essentially DP0 (technically may be DP1)
+
+    # 1.
+    function p2_update(x_; p2, nablau)
+	return p2 + sigma * nablau
+    end
+    interpolate!(ctx.p2, p2_update; ctx.p2, ctx.nablau)
+
+    function p2_project!(p2, lambda)
+	p2.space.element::DP1
+	p2d = reshape(p2.data, prod(p2.space.size), :) # no copy
+	for i in axes(p2d, 2)
+	    p2in = norm(p2d[:, i])
+	    if p2in > lambda
+		p2d[:, i] .*= lambda ./ p2in
+	    end
+	end
+    end
+    p2_next = FeFunction(ctx.p2.space)
+    p2_project!(p2_next, ctx.lambda)
+    ctx.dp2.data .= p2_next.data .- ctx.p2.data
+    ctx.p2.data .= p2_next.data
+
+    # 2.
+    u_a(x, z, nablaz, phi, nablaphi; g, u, p2) =
+	dot(z, phi)
+
+    u_l(x, phi, nablaphi; u, g, p2) =
+	(dot(u + 2 * tau * ctx.alpha2 * g, phi) - tau * dot(p2, nablaphi)) /
+	    (1 + 2 * tau * ctx.alpha2)
+
+    # z = 1 / (1 + 2 * tau * alpha2) *
+    #   (u + 2 * tau * alpha2 * g + tau * div(p))
+    z = FeFunction(ctx.u.space)
+    A, b = assemble(z.space, u_a, u_l; ctx.g, ctx.u, ctx.p2)
+    z.data .= A \ b
+
+    function u_update!(u, z, g, beta)
+	u.space.element::P1
+	g.space.element::P1
+	for i in eachindex(u.data)
+	    if z.data[i] - beta >= g.data[i]
+		u.data[i] = z.data[i] - beta
+	    elseif z.data[i] + beta <= g.data[i]
+		u.data[i] = z.data[i] + beta
+	    else
+		u.data[i] = g.data[i]
+	    end
+	end
+    end
+    u_next = FeFunction(ctx.u.space)
+    u_update!(u_next, z, ctx.g, beta)
+
+    # 3.
+    ctx.du.data .= u_next.data .- ctx.u.data
+    ctx.u.data .= u_next.data .+ theta * ctx.du.data
+
+    return ctx
+end
+
 function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
     u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) =
 	ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) +
@@ -208,15 +276,20 @@ function estimate!(ctx::L1L2TVContext)
     project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata)
 end
 
-function refine(ctx::L1L2TVContext, marked_cells)
+function refine(ctx::L1L2TVContext, marked_cells; fs_...)
+    fs = NamedTuple(fs_)
+
     hmesh = HMesh(ctx.mesh)
     refined_functions = refine!(hmesh, Set(marked_cells);
-	ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2)
+	ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2,
+	fs...)
     new_mesh = refined_functions.u.space.mesh
 
     new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S,
 	ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2)
 
+    fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs))
+
     @assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap)
     @assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap)
     @assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap)
@@ -235,7 +308,7 @@ function refine(ctx::L1L2TVContext, marked_cells)
     new_ctx.dp1.data .= refined_functions.dp1.data
     new_ctx.dp2.data .= refined_functions.dp2.data
 
-    return new_ctx
+    return new_ctx, fs_new
 end
 
 function mark(ctx::L1L2TVContext; theta=0.5)
@@ -280,6 +353,9 @@ end
 
 norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
 
+norm_step(ctx::L1L2TVContext) =
+    sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh))
+
 function denoise(img; name, params...)
     m = 1
     #mesh = init_grid(img; type=:vertex)
@@ -290,8 +366,9 @@ function denoise(img; name, params...)
 
     ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
 
-    interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
-    m = (size(img) .- 1) ./ 2 .+ 1
+    project_img!(ctx.g, img)
+    #interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
+    #m = (size(img) .- 1) ./ 2 .+ 1
     #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
 
     save_denoise(ctx, i) =
@@ -311,24 +388,69 @@ function denoise(img; name, params...)
 	    pvd[k] = save_denoise(ctx, k)
 	    println()
 
-	    norm_step = sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(mesh))
+	    norm_step_ = norm_step(ctx)
 
 	    println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
 	    println("primal energy: $(primal_energy(ctx))")
-	    println("norm_step: $(norm_step)")
+	    println("norm_step: $(norm_step_)")
 
-            norm_step <= 1e-1 && break
+            norm_step_ <= 1e-1 && break
 	end
 	marked_cells = mark(ctx; theta = 0.5)
-	#println(marked_cells)
 	println("refining ...")
-	ctx = refine(ctx, marked_cells)
+	ctx, _ = refine(ctx, marked_cells)
 	test_mesh(ctx.mesh)
 
-	gnew = project_img(ctx.g.space, img)
-	ctx.g.data .= gnew.data
-	#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
+	project_img!(ctx.g, img)
+
+	k >= 100 && break
+    end
+    vtk_save(pvd)
+    return ctx
+end
+
+function denoise_pd(img; name, params...)
+    m = 1
+    mesh = init_grid(img; type=:vertex)
+    #mesh = init_grid(img, 5, 5)
+
+    sigma = 1e-1
+    tau = 1e-1
+    theta = 1.
+
+    T(tdata, u) = u
+    S(u, nablau) = u
+
+    ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
+
+    project_img!(ctx.g, img)
+    #interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
+    #m = (size(img) .- 1) ./ 2 .+ 1
+    #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
+
+    save_denoise(ctx, i) =
+	output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
+	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
+
+    pvd = paraview_collection("output/$(ctx.name).pvd")
+    pvd[0] = save_denoise(ctx, 0)
+
+    k = 0
+    println("primal energy: $(primal_energy(ctx))")
+    while true
+	k += 1
+	step_pd!(ctx; sigma, tau, theta)
+	#estimate!(ctx)
+	pvd[k] = save_denoise(ctx, k)
+	println()
+
+	norm_step_ = norm_step(ctx)
+
+	println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
+	println("primal energy: $(primal_energy(ctx))")
+	println("norm_step: $(norm_step_)")
 
+	norm_step_ <= 1e-1 && break
 	k >= 100 && break
     end
     vtk_save(pvd)
-- 
GitLab