From 36cfd8ad528c3a0a4bbc8be3fb6be036b868dc8c Mon Sep 17 00:00:00 2001 From: Stephan Hilb <stephan@ecshi.net> Date: Sat, 10 Jul 2021 19:15:46 +0200 Subject: [PATCH] implement error estimator --- src/run.jl | 61 +++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/src/run.jl b/src/run.jl index 2f6dc05..9fc102f 100644 --- a/src/run.jl +++ b/src/run.jl @@ -1,4 +1,4 @@ -export myrun, denoise, inpaint, optflow +export myrun, denoise, inpaint, optflow, solve_primal!, estimate! using LinearAlgebra: norm @@ -20,6 +20,7 @@ struct L1L2TVContext{M,Ttype,Stype} gamma1::Float64 gamma2::Float64 + est::FeFunction g::FeFunction u::FeFunction p1::FeFunction @@ -35,11 +36,13 @@ function L1L2TVContext(name, mesh, m; T, tdata, S, alpha1, alpha2, beta, lambda, gamma1, gamma2) d = ndims_domain(mesh) + Vest = FeSpace(mesh, DP0(), (1,)) Vg = FeSpace(mesh, P1(), (1,)) Vu = FeSpace(mesh, P1(), (m,)) Vp1 = FeSpace(mesh, DP0(), (1,)) Vp2 = FeSpace(mesh, DP1(), (m, d)) + est = FeFunction(Vest, name="est") g = FeFunction(Vg, name="g") u = FeFunction(Vu, name="u") p1 = FeFunction(Vp1, name="p1") @@ -50,6 +53,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S, nablau = nabla(u) nabladu = nabla(du) + est.data .= 0 g.data .= 0 u.data .= 0 p1.data .= 0 @@ -60,7 +64,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S, return L1L2TVContext(name, mesh, d, m, T, tdata, S, alpha1, alpha2, beta, lambda, gamma1, gamma2, - g, u, p1, p2, du, dp1, dp2, nablau, nabladu) + est, g, u, p1, p2, du, dp1, dp2, nablau, nabladu) end function step!(ctx::L1L2TVContext) @@ -157,6 +161,45 @@ function step!(ctx::L1L2TVContext) p2_project!(ctx.p2, ctx.lambda) end +function solve_primal!(u::FeFunction, ctx::L1L2TVContext) + u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) = + ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) + + ctx.beta * dot(ctx.S(u, nablau), ctx.S(phi, nablaphi)) + + u_l(x, phi, nablaphi; g, p1, p2, tdata) = + -dot(p1, ctx.T(tdata, phi)) - dot(p2, nablaphi) + + ctx.alpha2 * dot(g, ctx.T(tdata, phi)) + + # u = B^{-1} * (T^* p_1 - div p_2 - alpha2 * T^* g) + A, b = assemble(u.space, u_a, u_l; ctx.g, ctx.p1, ctx.p2, ctx.tdata) + u.data .= A \ b +end + +function estimate!(ctx::L1L2TVContext) + huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2 + + function estf(x; g, u, p1, p2, nablau, w, nablaw, tdata) + alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * ( + huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) + + dot(ctx.T(tdata, u) - g, p1 / ctx.alpha1) + + ctx.gamma1 / 2 * norm(p1 / ctx.alpha1)^2) + lambdapart = iszero(ctx.lambda) ? 0. : ctx.lambda * ( + huber(norm(nablau), ctx.gamma2) + + dot(nablau, p2 / ctx.lambda) + + ctx.gamma2 / 2 * norm(p2 / ctx.lambda)^2) + bpart = 1 / 2 * ( + ctx.alpha2 * dot(ctx.T(tdata, w - u), ctx.T(tdata, w - u)) + + ctx.beta * dot(ctx.S(w, nablaw) - ctx.S(u, nablau), ctx.S(w, nablaw) - ctx.S(u, nablau))) + + return alpha1part + lambdapart + bpart + end + + w = FeFunction(ctx.u.space) + nablaw = nabla(w) + solve_primal!(w, ctx) + interpolate!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata) +end + function save(ctx::L1L2TVContext, filename, fs...) print("save ... ") vtk = vtk_mesh(filename, ctx.mesh) @@ -175,18 +218,22 @@ function denoise(img; name, params...) ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) + m = size(img) ./ 2 + interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3) save_denoise(i) = save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", - ctx.g, ctx.u, ctx.p1, ctx.p2) + ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est) pvd = paraview_collection("$(ctx.name).pvd") pvd[0] = save_denoise(0) for i in 1:10 step!(ctx) + estimate!(ctx) pvd[i] = save_denoise(i) println() end + return ctx end function inpaint(img, imgmask; name, params...) @@ -212,15 +259,17 @@ function inpaint(img, imgmask; name, params...) save_inpaint(i) = save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", - ctx.g, ctx.u, ctx.p1, ctx.p2, mask) + ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask) pvd = paraview_collection("$(ctx.name).pvd") pvd[0] = save_inpaint(0) for i in 1:10 step!(ctx) + estimate!(ctx) pvd[i] = save_inpaint(i) println() end + return ctx end function optflow(imgf0, imgf1; name, params...) @@ -253,13 +302,15 @@ function optflow(imgf0, imgf1; name, params...) save_optflow(i) = save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", - ctx.g, ctx.u, ctx.p1, ctx.p2, f0, f1, fw) + ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, f0, f1, fw) pvd = paraview_collection("$(ctx.name).pvd") pvd[0] = save_optflow(0) for i in 1:10 step!(ctx) + estimate!(ctx) pvd[i] = save_optflow(i) println() end + return ctx end -- GitLab