Skip to content
Snippets Groups Projects
Commit 36cfd8ad authored by Stephan Hilb's avatar Stephan Hilb
Browse files

implement error estimator

parent 286238f6
Branches
Tags
No related merge requests found
export myrun, denoise, inpaint, optflow export myrun, denoise, inpaint, optflow, solve_primal!, estimate!
using LinearAlgebra: norm using LinearAlgebra: norm
...@@ -20,6 +20,7 @@ struct L1L2TVContext{M,Ttype,Stype} ...@@ -20,6 +20,7 @@ struct L1L2TVContext{M,Ttype,Stype}
gamma1::Float64 gamma1::Float64
gamma2::Float64 gamma2::Float64
est::FeFunction
g::FeFunction g::FeFunction
u::FeFunction u::FeFunction
p1::FeFunction p1::FeFunction
...@@ -35,11 +36,13 @@ function L1L2TVContext(name, mesh, m; T, tdata, S, ...@@ -35,11 +36,13 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
alpha1, alpha2, beta, lambda, gamma1, gamma2) alpha1, alpha2, beta, lambda, gamma1, gamma2)
d = ndims_domain(mesh) d = ndims_domain(mesh)
Vest = FeSpace(mesh, DP0(), (1,))
Vg = FeSpace(mesh, P1(), (1,)) Vg = FeSpace(mesh, P1(), (1,))
Vu = FeSpace(mesh, P1(), (m,)) Vu = FeSpace(mesh, P1(), (m,))
Vp1 = FeSpace(mesh, DP0(), (1,)) Vp1 = FeSpace(mesh, DP0(), (1,))
Vp2 = FeSpace(mesh, DP1(), (m, d)) Vp2 = FeSpace(mesh, DP1(), (m, d))
est = FeFunction(Vest, name="est")
g = FeFunction(Vg, name="g") g = FeFunction(Vg, name="g")
u = FeFunction(Vu, name="u") u = FeFunction(Vu, name="u")
p1 = FeFunction(Vp1, name="p1") p1 = FeFunction(Vp1, name="p1")
...@@ -50,6 +53,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S, ...@@ -50,6 +53,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
nablau = nabla(u) nablau = nabla(u)
nabladu = nabla(du) nabladu = nabla(du)
est.data .= 0
g.data .= 0 g.data .= 0
u.data .= 0 u.data .= 0
p1.data .= 0 p1.data .= 0
...@@ -60,7 +64,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S, ...@@ -60,7 +64,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
return L1L2TVContext(name, mesh, d, m, T, tdata, S, return L1L2TVContext(name, mesh, d, m, T, tdata, S,
alpha1, alpha2, beta, lambda, gamma1, gamma2, alpha1, alpha2, beta, lambda, gamma1, gamma2,
g, u, p1, p2, du, dp1, dp2, nablau, nabladu) est, g, u, p1, p2, du, dp1, dp2, nablau, nabladu)
end end
function step!(ctx::L1L2TVContext) function step!(ctx::L1L2TVContext)
...@@ -157,6 +161,45 @@ function step!(ctx::L1L2TVContext) ...@@ -157,6 +161,45 @@ function step!(ctx::L1L2TVContext)
p2_project!(ctx.p2, ctx.lambda) p2_project!(ctx.p2, ctx.lambda)
end end
function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) =
ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) +
ctx.beta * dot(ctx.S(u, nablau), ctx.S(phi, nablaphi))
u_l(x, phi, nablaphi; g, p1, p2, tdata) =
-dot(p1, ctx.T(tdata, phi)) - dot(p2, nablaphi) +
ctx.alpha2 * dot(g, ctx.T(tdata, phi))
# u = B^{-1} * (T^* p_1 - div p_2 - alpha2 * T^* g)
A, b = assemble(u.space, u_a, u_l; ctx.g, ctx.p1, ctx.p2, ctx.tdata)
u.data .= A \ b
end
function estimate!(ctx::L1L2TVContext)
huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
function estf(x; g, u, p1, p2, nablau, w, nablaw, tdata)
alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * (
huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
dot(ctx.T(tdata, u) - g, p1 / ctx.alpha1) +
ctx.gamma1 / 2 * norm(p1 / ctx.alpha1)^2)
lambdapart = iszero(ctx.lambda) ? 0. : ctx.lambda * (
huber(norm(nablau), ctx.gamma2) +
dot(nablau, p2 / ctx.lambda) +
ctx.gamma2 / 2 * norm(p2 / ctx.lambda)^2)
bpart = 1 / 2 * (
ctx.alpha2 * dot(ctx.T(tdata, w - u), ctx.T(tdata, w - u)) +
ctx.beta * dot(ctx.S(w, nablaw) - ctx.S(u, nablau), ctx.S(w, nablaw) - ctx.S(u, nablau)))
return alpha1part + lambdapart + bpart
end
w = FeFunction(ctx.u.space)
nablaw = nabla(w)
solve_primal!(w, ctx)
interpolate!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata)
end
function save(ctx::L1L2TVContext, filename, fs...) function save(ctx::L1L2TVContext, filename, fs...)
print("save ... ") print("save ... ")
vtk = vtk_mesh(filename, ctx.mesh) vtk = vtk_mesh(filename, ctx.mesh)
...@@ -175,18 +218,22 @@ function denoise(img; name, params...) ...@@ -175,18 +218,22 @@ function denoise(img; name, params...)
ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
m = size(img) ./ 2
interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
save_denoise(i) = save_denoise(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2) ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
pvd = paraview_collection("$(ctx.name).pvd") pvd = paraview_collection("$(ctx.name).pvd")
pvd[0] = save_denoise(0) pvd[0] = save_denoise(0)
for i in 1:10 for i in 1:10
step!(ctx) step!(ctx)
estimate!(ctx)
pvd[i] = save_denoise(i) pvd[i] = save_denoise(i)
println() println()
end end
return ctx
end end
function inpaint(img, imgmask; name, params...) function inpaint(img, imgmask; name, params...)
...@@ -212,15 +259,17 @@ function inpaint(img, imgmask; name, params...) ...@@ -212,15 +259,17 @@ function inpaint(img, imgmask; name, params...)
save_inpaint(i) = save_inpaint(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, mask) ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask)
pvd = paraview_collection("$(ctx.name).pvd") pvd = paraview_collection("$(ctx.name).pvd")
pvd[0] = save_inpaint(0) pvd[0] = save_inpaint(0)
for i in 1:10 for i in 1:10
step!(ctx) step!(ctx)
estimate!(ctx)
pvd[i] = save_inpaint(i) pvd[i] = save_inpaint(i)
println() println()
end end
return ctx
end end
function optflow(imgf0, imgf1; name, params...) function optflow(imgf0, imgf1; name, params...)
...@@ -253,13 +302,15 @@ function optflow(imgf0, imgf1; name, params...) ...@@ -253,13 +302,15 @@ function optflow(imgf0, imgf1; name, params...)
save_optflow(i) = save_optflow(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, f0, f1, fw) ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, f0, f1, fw)
pvd = paraview_collection("$(ctx.name).pvd") pvd = paraview_collection("$(ctx.name).pvd")
pvd[0] = save_optflow(0) pvd[0] = save_optflow(0)
for i in 1:10 for i in 1:10
step!(ctx) step!(ctx)
estimate!(ctx)
pvd[i] = save_optflow(i) pvd[i] = save_optflow(i)
println() println()
end end
return ctx
end end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment