Skip to content
Snippets Groups Projects
Commit 9c8ec4af authored by Stephan Hilb's avatar Stephan Hilb
Browse files

implement new chambolle-pock algorithm with different dualisation

parent d4d67b01
No related branches found
No related tags found
No related merge requests found
...@@ -267,6 +267,56 @@ evaluate(f::ImageFunction, xloc) = ...@@ -267,6 +267,56 @@ evaluate(f::ImageFunction, xloc) =
interpolate_bilinear(f.img, elmap(f.mesh, f.cell[])(xloc)) interpolate_bilinear(f.img, elmap(f.mesh, f.cell[])(xloc))
struct FacetDivergence{F}
f::F
cell::Base.RefValue{Int}
edgemap::Dict{NTuple{2, Int}, Vector{Int}}
end
# TODO: incomplete!
function FacetDivergence(f::FeFunction)
f.space.element == DP0() || throw(ArgumentError("unimplemented"))
mesh = f.space.mesh
edgedivergence = Dict{NTuple{2, Int}, Float64}()
edgemap = Dict{NTuple{2, Int}, Vector{Int}}()
for cell in cells(mesh)
vs = sort(SVector(vertices(mesh, cell)))
e1 = (vs[1], vs[2])
e2 = (vs[1], vs[3])
e3 = (vs[2], vs[3])
edgemap[e1] = push!(get!(edgemap, e1, []), cell)
edgemap[e2] = push!(get!(edgemap, e2, []), cell)
edgemap[e3] = push!(get!(edgemap, e3, []), cell)
bind!(f, cell)
p = evaluate(f, SA[0., 0.])
A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}(
view(mesh.vertices, :, view(mesh.cells, :, cell)))
v1 = (A[:,2] - A[:,1])
v2 = (A[:,3] - A[:,2])
v3 = (A[:,1] - A[:,3])
p1 = norm(v1) * (p[1] * v1[2] - p[2] * v1[1])
p2 = norm(v2) * (p[1] * v2[2] - p[2] * v2[1])
p3 = norm(v3) * (p[1] * v3[2] - p[2] * v3[1])
edgedivergence[e1] = get!(edgedivergence, e1, 0.) + p1
edgedivergence[e3] = get!(edgedivergence, e3, 0.) + p2
edgedivergence[e2] = get!(edgedivergence, e2, 0.) + p3
end
return FacetDivergence(f, Ref(1), edgemap)
end
bind!(f::FacetDivergence, cell) = f.cell[] = cell
function sample(f::FeFunction) function sample(f::FeFunction)
mesh = f.mapper.mesh mesh = f.mapper.mesh
for cell in cells(mesh) for cell in cells(mesh)
......
...@@ -169,8 +169,13 @@ function step!(ctx::L1L2TVContext) ...@@ -169,8 +169,13 @@ function step!(ctx::L1L2TVContext)
p2_project!(ctx.p2, ctx.lambda) p2_project!(ctx.p2, ctx.lambda)
end end
"
2010, Chambolle and Pock: primal-dual semi-implicit algorithm
2017, Alkämper and Langer: fem dualisation
"
function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.) function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
# note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1! # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1!
# chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0. if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0.
error("unsupported parameters") error("unsupported parameters")
end end
...@@ -179,11 +184,16 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.) ...@@ -179,11 +184,16 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
# u is P1 # u is P1
# p2 is essentially DP0 (technically may be DP1) # p2 is essentially DP0 (technically may be DP1)
# 1. # 1. y update
u_new = FeFunction(ctx.u.space) # x_bar only used here
u_new.data .= ctx.u.data .+ theta .* ctx.du.data
p2_next = FeFunction(ctx.p2.space)
function p2_update(x_; p2, nablau) function p2_update(x_; p2, nablau)
return p2 + sigma * nablau return p2 + sigma * nablau
end end
interpolate!(ctx.p2, p2_update; ctx.p2, ctx.nablau) interpolate!(p2_next, p2_update; ctx.p2, nablau=nabla(u_new))
function p2_project!(p2, lambda) function p2_project!(p2, lambda)
p2.space.element::DP1 p2.space.element::DP1
...@@ -195,17 +205,17 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.) ...@@ -195,17 +205,17 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
end end
end end
end end
p2_next = FeFunction(ctx.p2.space)
p2_project!(p2_next, ctx.lambda) p2_project!(p2_next, ctx.lambda)
ctx.dp2.data .= p2_next.data .- ctx.p2.data ctx.dp2.data .= p2_next.data .- ctx.p2.data
ctx.p2.data .= p2_next.data
ctx.p2.data .+= ctx.dp2.data
# 2. # 2.
u_a(x, z, nablaz, phi, nablaphi; g, u, p2) = u_a(x, z, nablaz, phi, nablaphi; g, u, p2) =
dot(z, phi) dot(z, phi)
u_l(x, phi, nablaphi; u, g, p2) = u_l(x, phi, nablaphi; u, g, p2) =
(dot(u + 2 * tau * ctx.alpha2 * g, phi) - tau * dot(p2, nablaphi)) / (dot(u + 2 * tau * ctx.alpha2 * g, phi) + tau * dot(p2, -nablaphi)) /
(1 + 2 * tau * ctx.alpha2) (1 + 2 * tau * ctx.alpha2)
# z = 1 / (1 + 2 * tau * alpha2) * # z = 1 / (1 + 2 * tau * alpha2) *
...@@ -227,16 +237,111 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.) ...@@ -227,16 +237,111 @@ function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
end end
end end
end end
u_next = FeFunction(ctx.u.space) u_update!(u_new, z, ctx.g, beta)
u_update!(u_next, z, ctx.g, beta)
# 3. # 3.
ctx.du.data .= u_next.data .- ctx.u.data # note: step-size control not implemented, since \nabla G^* is not 1/gamma continuous
ctx.u.data .= u_next.data .+ theta * ctx.du.data ctx.du.data .= u_new.data .- ctx.u.data
ctx.u.data .= u_new.data
#ctx.du.data .= z.data
any(isnan, ctx.u.data) && error("encountered nan data")
any(isnan, ctx.p2.data) && error("encountered nan data")
return ctx return ctx
end end
"
2010, Chambolle and Pock: accelerated primal-dual semi-implicit algorithm
"
function step_pd2!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
# chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
# u is P1
# p2 is essentially DP0 (technically may be DP1)
# 1. y update
u_new = FeFunction(ctx.u.space) # x_bar only used here
u_new.data .= ctx.u.data .+ theta .* ctx.du.data
p1_next = FeFunction(ctx.p1.space)
p2_next = FeFunction(ctx.p2.space)
function p1_update(x_; p1, g, u, tdata)
ctx.alpha1 == 0 && return zero(p1)
return (p1 + sigma * (ctx.T(tdata, u) - g)) /
(1 + ctx.gamma1 * sigma / ctx.alpha1)
end
interpolate!(p1_next, p1_update; ctx.p1, ctx.g, ctx.u, ctx.tdata)
function p2_update(x_; p2, nablau)
ctx.lambda == 0 && return zero(p2)
return (p2 + sigma * nablau) /
(1 + ctx.gamma2 * sigma / ctx.lambda)
end
interpolate!(p2_next, p2_update; ctx.p2, nablau=nabla(u_new))
# reproject p1
function p1_project!(p1, alpha1)
p1.space.element::DP0
p1.data .= clamp.(p1.data, -alpha1, alpha1)
end
p1_project!(p1_next, ctx.alpha1)
# reproject p2
function p2_project!(p2, lambda)
p2.space.element::DP1
p2d = reshape(p2.data, prod(p2.space.size), :) # no copy
for i in axes(p2d, 2)
p2in = norm(p2d[:, i])
if p2in > lambda
p2d[:, i] .*= lambda ./ p2in
end
end
end
p2_project!(p2_next, ctx.lambda)
ctx.dp1.data .= p1_next.data .- ctx.p1.data
ctx.dp2.data .= p2_next.data .- ctx.p2.data
ctx.p1.data .+= ctx.dp1.data
ctx.p2.data .+= ctx.dp2.data
# 2. x update
u_a(x, w, nablaw, phi, nablaphi; g, u, p1, p2, tdata) =
dot(w, phi) +
tau * ctx.alpha2 * dot(ctx.T(tdata, w), ctx.T(tdata, phi)) +
tau * ctx.beta * dot(ctx.S(w, nablaw), ctx.S(phi, nablaphi))
u_l(x, phi, nablaphi; g, u, p1, p2, tdata) =
dot(u, phi) - tau * (
dot(p1, ctx.T(tdata, phi)) +
dot(p2, nablaphi) -
ctx.alpha2 * dot(g, ctx.T(tdata, phi)))
A, b = assemble(u_new.space, u_a, u_l; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.tdata)
u_new.data .= A \ b
ctx.du.data .= u_new.data .- ctx.u.data
ctx.u.data .= u_new.data
#ctx.du.data .= z.data
any(isnan, ctx.u.data) && error("encountered nan data")
any(isnan, ctx.p1.data) && error("encountered nan data")
any(isnan, ctx.p2.data) && error("encountered nan data")
return ctx
end
"
2004, Chambolle: dual semi-implicit algorithm
"
function step_d!(ctx::L1L2TVContext; tau)
# u is P1
# p2 is essentially DP0 (technically may be DP1)
# TODO: this might not be implementable without higher order elements
# need grad(div(p))
return ctx
end
function solve_primal!(u::FeFunction, ctx::L1L2TVContext) function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) = u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) =
ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) + ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) +
...@@ -254,6 +359,7 @@ end ...@@ -254,6 +359,7 @@ end
huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2 huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
function estimate!(ctx::L1L2TVContext) function estimate!(ctx::L1L2TVContext)
# FIXME: sign?
function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata) function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * ( alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * (
huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) + huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
...@@ -356,6 +462,13 @@ norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f)) ...@@ -356,6 +462,13 @@ norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
norm_step(ctx::L1L2TVContext) = norm_step(ctx::L1L2TVContext) =
sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh)) sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh))
function norm_residual(ctx::L1L2TVContext)
w = FeFunction(ctx.u.space)
solve_primal!(w, ctx)
w.data .-= ctx.u.data
return norm_l2(w)
end
function denoise(img; name, params...) function denoise(img; name, params...)
m = 1 m = 1
#mesh = init_grid(img; type=:vertex) #mesh = init_grid(img; type=:vertex)
...@@ -389,10 +502,12 @@ function denoise(img; name, params...) ...@@ -389,10 +502,12 @@ function denoise(img; name, params...)
println() println()
norm_step_ = norm_step(ctx) norm_step_ = norm_step(ctx)
norm_residual_ = norm_residual(ctx)
println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))") println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
println("primal energy: $(primal_energy(ctx))") println("primal energy: $(primal_energy(ctx))")
println("norm_step: $(norm_step_)") println("norm_step: $(norm_step_)")
println("norm_residual: $(norm_residual_)")
norm_step_ <= 1e-1 && break norm_step_ <= 1e-1 && break
end end
...@@ -414,44 +529,67 @@ function denoise_pd(img; name, params...) ...@@ -414,44 +529,67 @@ function denoise_pd(img; name, params...)
mesh = init_grid(img; type=:vertex) mesh = init_grid(img; type=:vertex)
#mesh = init_grid(img, 5, 5) #mesh = init_grid(img, 5, 5)
sigma = 1e-1
tau = 1e-1
theta = 1.
T(tdata, u) = u T(tdata, u) = u
S(u, nablau) = u S(u, nablau) = u
ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
# semi-implicit primal dual parameters
gamma = ctx.alpha2 + ctx.beta # T = I, S = I
sigma = 1e-2
tau = 1e-2
theta = 0.
project_img!(ctx.g, img) project_img!(ctx.g, img)
#interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) #interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
#m = (size(img) .- 1) ./ 2 .+ 1 #m = (size(img) .- 1) ./ 2 .+ 1
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
ctx.u.data .= ctx.g.data
save_denoise(ctx, i) = save_denoise(ctx, i) =
output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu", output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est) ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, ctx.du)
pvd = paraview_collection("output/$(ctx.name).pvd") pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_denoise(ctx, 0) pvd[0] = save_denoise(ctx, 0)
k = 0 k = 0
println("primal energy: $(primal_energy(ctx))") println("primal energy: $(primal_energy(ctx))")
algorithm = :newton
while true while true
k += 1 k += 1
step_pd!(ctx; sigma, tau, theta) #println("")
if algorithm == :pd2 #|| primal_energy(ctx) < 2409
println("step_pd: theta = $theta, tau = $tau, sigma = $sigma")
pd = true
theta = 1 / sqrt(1 + 2 * gamma * tau)
tau *= theta
sigma /= theta
step_pd2!(ctx; sigma, tau, theta)
elseif algorithm == :pd1
# no step size control
step_pd!(ctx; sigma, tau, theta)
elseif algorithm == :newton
step!(ctx)
end
#estimate!(ctx) #estimate!(ctx)
pvd[k] = save_denoise(ctx, k) #pvd[k] = save_denoise(ctx, k)
println() println()
norm_step_ = norm_step(ctx) domain_factor = 1 / sqrt(area(mesh))
norm_step_ = norm_step(ctx) * domain_factor
residual = -1.
residual = norm_residual(ctx) * domain_factor
println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))") #println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
println("primal energy: $(primal_energy(ctx))") println("primal energy: $(primal_energy(ctx))")
println("norm_step: $(norm_step_)") println("norm_step: $(norm_step_)")
println("norm_residual: $(residual)")
norm_step_ <= 1e-1 && break #return ctx
k >= 100 && break #residual <= 1e-3 && break
#k >= 5 && break
end end
vtk_save(pvd) vtk_save(pvd)
return ctx return ctx
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment