From 36cfd8ad528c3a0a4bbc8be3fb6be036b868dc8c Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Sat, 10 Jul 2021 19:15:46 +0200
Subject: [PATCH] implement error estimator

---
 src/run.jl | 61 +++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 56 insertions(+), 5 deletions(-)

diff --git a/src/run.jl b/src/run.jl
index 2f6dc05..9fc102f 100644
--- a/src/run.jl
+++ b/src/run.jl
@@ -1,4 +1,4 @@
-export myrun, denoise, inpaint, optflow
+export myrun, denoise, inpaint, optflow, solve_primal!, estimate!
 
 using LinearAlgebra: norm
 
@@ -20,6 +20,7 @@ struct L1L2TVContext{M,Ttype,Stype}
     gamma1::Float64
     gamma2::Float64
 
+    est::FeFunction
     g::FeFunction
     u::FeFunction
     p1::FeFunction
@@ -35,11 +36,13 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
 	alpha1, alpha2, beta, lambda, gamma1, gamma2)
     d = ndims_domain(mesh)
 
+    Vest = FeSpace(mesh, DP0(), (1,))
     Vg = FeSpace(mesh, P1(), (1,))
     Vu = FeSpace(mesh, P1(), (m,))
     Vp1 = FeSpace(mesh, DP0(), (1,))
     Vp2 = FeSpace(mesh, DP1(), (m, d))
 
+    est = FeFunction(Vest, name="est")
     g = FeFunction(Vg, name="g")
     u = FeFunction(Vu, name="u")
     p1 = FeFunction(Vp1, name="p1")
@@ -50,6 +53,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
     nablau = nabla(u)
     nabladu = nabla(du)
 
+    est.data .= 0
     g.data .= 0
     u.data .= 0
     p1.data .= 0
@@ -60,7 +64,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
 
     return L1L2TVContext(name, mesh, d, m, T, tdata, S,
 	alpha1, alpha2, beta, lambda, gamma1, gamma2,
-	g, u, p1, p2, du, dp1, dp2, nablau, nabladu)
+	est, g, u, p1, p2, du, dp1, dp2, nablau, nabladu)
 end
 
 function step!(ctx::L1L2TVContext)
@@ -157,6 +161,45 @@ function step!(ctx::L1L2TVContext)
     p2_project!(ctx.p2, ctx.lambda)
 end
 
+function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
+    u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) =
+	ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) +
+	    ctx.beta * dot(ctx.S(u, nablau), ctx.S(phi, nablaphi))
+
+    u_l(x, phi, nablaphi; g, p1, p2, tdata) =
+	-dot(p1, ctx.T(tdata, phi)) - dot(p2, nablaphi) +
+	    ctx.alpha2 * dot(g, ctx.T(tdata, phi))
+
+    # u = B^{-1} * (T^* p_1 - div p_2 - alpha2 * T^* g)
+    A, b = assemble(u.space, u_a, u_l; ctx.g, ctx.p1, ctx.p2, ctx.tdata)
+    u.data .= A \ b
+end
+
+function estimate!(ctx::L1L2TVContext)
+    huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
+
+    function estf(x; g, u, p1, p2, nablau, w, nablaw, tdata)
+	alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * (
+	    huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
+	    dot(ctx.T(tdata, u) - g, p1 / ctx.alpha1) +
+	    ctx.gamma1 / 2 * norm(p1 / ctx.alpha1)^2)
+	lambdapart = iszero(ctx.lambda) ? 0. : ctx.lambda * (
+	    huber(norm(nablau), ctx.gamma2) +
+	    dot(nablau, p2 / ctx.lambda) +
+	    ctx.gamma2 / 2 * norm(p2 / ctx.lambda)^2)
+	bpart = 1 / 2 * (
+	    ctx.alpha2 * dot(ctx.T(tdata, w - u), ctx.T(tdata, w - u)) +
+	    ctx.beta * dot(ctx.S(w, nablaw) - ctx.S(u, nablau), ctx.S(w, nablaw) - ctx.S(u, nablau)))
+
+	return alpha1part + lambdapart + bpart
+    end
+
+    w = FeFunction(ctx.u.space)
+    nablaw = nabla(w)
+    solve_primal!(w, ctx)
+    interpolate!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata)
+end
+
 function save(ctx::L1L2TVContext, filename, fs...)
     print("save ... ")
     vtk = vtk_mesh(filename, ctx.mesh)
@@ -175,18 +218,22 @@ function denoise(img; name, params...)
     ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
 
     interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
+    m = size(img) ./ 2
+    interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
 
     save_denoise(i) =
 	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
-	    ctx.g, ctx.u, ctx.p1, ctx.p2)
+	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
 
     pvd = paraview_collection("$(ctx.name).pvd")
     pvd[0] = save_denoise(0)
     for i in 1:10
 	step!(ctx)
+	estimate!(ctx)
 	pvd[i] = save_denoise(i)
 	println()
     end
+    return ctx
 end
 
 function inpaint(img, imgmask; name, params...)
@@ -212,15 +259,17 @@ function inpaint(img, imgmask; name, params...)
 
     save_inpaint(i) =
 	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
-	    ctx.g, ctx.u, ctx.p1, ctx.p2, mask)
+	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask)
 
     pvd = paraview_collection("$(ctx.name).pvd")
     pvd[0] = save_inpaint(0)
     for i in 1:10
 	step!(ctx)
+	estimate!(ctx)
 	pvd[i] = save_inpaint(i)
 	println()
     end
+    return ctx
 end
 
 function optflow(imgf0, imgf1; name, params...)
@@ -253,13 +302,15 @@ function optflow(imgf0, imgf1; name, params...)
 
     save_optflow(i) =
 	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
-	    ctx.g, ctx.u, ctx.p1, ctx.p2, f0, f1, fw)
+	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, f0, f1, fw)
 
     pvd = paraview_collection("$(ctx.name).pvd")
     pvd[0] = save_optflow(0)
     for i in 1:10
 	step!(ctx)
+	estimate!(ctx)
 	pvd[i] = save_optflow(i)
 	println()
     end
+    return ctx
 end
-- 
GitLab