Skip to content
Snippets Groups Projects
Commit d4d67b01 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

implement primal dual algorithm (untested)

parent fbf2276a
No related branches found
No related tags found
No related merge requests found
...@@ -116,9 +116,12 @@ function assemble(space::FeSpace, a, l; params...) ...@@ -116,9 +116,12 @@ function assemble(space::FeSpace, a, l; params...)
return A, b return A, b
end end
project_img(space::FeSpace, img) =
(u = FeFunction(space); project_img!(u, img))
function project_img(space::FeSpace, img) function project_img!(u::FeFunction, img)
d = 2 # domain dimension d = 2 # domain dimension
space = u.space
mesh = space.mesh mesh = space.mesh
f = ImageFunction(mesh, img) f = ImageFunction(mesh, img)
opparams = (; f) opparams = (; f)
...@@ -131,13 +134,21 @@ function project_img(space::FeSpace, img) ...@@ -131,13 +134,21 @@ function project_img(space::FeSpace, img)
# composite midpoint quadrature on lagrange point lattice # composite midpoint quadrature on lagrange point lattice
function quadrature(p) function quadrature(p)
k = Iterators.filter(x -> sum(x) == p, d_ = 2
Iterators.product((0:p for _ in 1:d+1)...)) |> collect n = binomial(p + 2, 2)
weights = Vector{Float64}(undef, n)
weights = [1 / length(k) for _ in axes(k, 1)] points = Matrix{Float64}(undef, 2, n)
points = [x[i] / p for i in 1:2, x in k]
k = 0
for I in Iterators.product(ntuple(_ -> 0:p, d_ + 1)...)
I[1] + I[2] + I[3] != p && continue
k += 1
weights[k] = 1 / n
points[1, k] = I[1] / p
points[2, k] = I[2] / p
end
return weights::Vector{Float64}, points::Matrix{Float64} return weights, points
end end
I = Float64[] I = Float64[]
...@@ -160,16 +171,16 @@ function project_img(space::FeSpace, img) ...@@ -160,16 +171,16 @@ function project_img(space::FeSpace, img)
qphi = zeros(nrdims, nrdims, nldofs, nqpts) qphi = zeros(nrdims, nrdims, nldofs, nqpts)
dqphi = zeros(nrdims, d, nrdims, nldofs, nqpts) dqphi = zeros(nrdims, d, nrdims, nldofs, nqpts)
for r in 1:nrdims for k in axes(qx, 2)
for k in axes(qx, 2) for r in 1:nrdims
qphi[r, r, :, k] .= evaluate_basis(space.element, qx[:, k]) qphi[r, r, :, k] .= evaluate_basis(space.element, SVector{d}(view(qx, :, k)))
dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(qx[:, k]))) dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(view(qx, :, k))))
end end
end end
# quadrature points # quadrature points
for k in axes(qx, 2) for k in axes(qx, 2)
xhat = SVector{d}(qx[:, k]) xhat = SVector{d}(view(qx, :, k))
x = elmap(mesh, cell)(xhat) x = elmap(mesh, cell)(xhat)
opvalues = map(f -> evaluate(f, xhat), opparams) opvalues = map(f -> evaluate(f, xhat), opparams)
...@@ -204,8 +215,6 @@ function project_img(space::FeSpace, img) ...@@ -204,8 +215,6 @@ function project_img(space::FeSpace, img)
ngdofs = ndofs(space) ngdofs = ndofs(space)
A = sparse(I, J, V, ngdofs, ngdofs) A = sparse(I, J, V, ngdofs, ngdofs)
u = FeFunction(space)
u.data .= A \ b u.data .= A \ b
return u return u
end end
export myrun, denoise, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg export myrun, denoise, denoise_pd, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg
using LinearAlgebra: norm using LinearAlgebra: norm
...@@ -169,6 +169,74 @@ function step!(ctx::L1L2TVContext) ...@@ -169,6 +169,74 @@ function step!(ctx::L1L2TVContext)
p2_project!(ctx.p2, ctx.lambda) p2_project!(ctx.p2, ctx.lambda)
end end
function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
# note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1!
if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0.
error("unsupported parameters")
end
beta = tau * ctx.alpha1 / (1 + 2 * tau * ctx.alpha2)
# u is P1
# p2 is essentially DP0 (technically may be DP1)
# 1.
function p2_update(x_; p2, nablau)
return p2 + sigma * nablau
end
interpolate!(ctx.p2, p2_update; ctx.p2, ctx.nablau)
function p2_project!(p2, lambda)
p2.space.element::DP1
p2d = reshape(p2.data, prod(p2.space.size), :) # no copy
for i in axes(p2d, 2)
p2in = norm(p2d[:, i])
if p2in > lambda
p2d[:, i] .*= lambda ./ p2in
end
end
end
p2_next = FeFunction(ctx.p2.space)
p2_project!(p2_next, ctx.lambda)
ctx.dp2.data .= p2_next.data .- ctx.p2.data
ctx.p2.data .= p2_next.data
# 2.
u_a(x, z, nablaz, phi, nablaphi; g, u, p2) =
dot(z, phi)
u_l(x, phi, nablaphi; u, g, p2) =
(dot(u + 2 * tau * ctx.alpha2 * g, phi) - tau * dot(p2, nablaphi)) /
(1 + 2 * tau * ctx.alpha2)
# z = 1 / (1 + 2 * tau * alpha2) *
# (u + 2 * tau * alpha2 * g + tau * div(p))
z = FeFunction(ctx.u.space)
A, b = assemble(z.space, u_a, u_l; ctx.g, ctx.u, ctx.p2)
z.data .= A \ b
function u_update!(u, z, g, beta)
u.space.element::P1
g.space.element::P1
for i in eachindex(u.data)
if z.data[i] - beta >= g.data[i]
u.data[i] = z.data[i] - beta
elseif z.data[i] + beta <= g.data[i]
u.data[i] = z.data[i] + beta
else
u.data[i] = g.data[i]
end
end
end
u_next = FeFunction(ctx.u.space)
u_update!(u_next, z, ctx.g, beta)
# 3.
ctx.du.data .= u_next.data .- ctx.u.data
ctx.u.data .= u_next.data .+ theta * ctx.du.data
return ctx
end
function solve_primal!(u::FeFunction, ctx::L1L2TVContext) function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) = u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) =
ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) + ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) +
...@@ -208,15 +276,20 @@ function estimate!(ctx::L1L2TVContext) ...@@ -208,15 +276,20 @@ function estimate!(ctx::L1L2TVContext)
project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata) project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata)
end end
function refine(ctx::L1L2TVContext, marked_cells) function refine(ctx::L1L2TVContext, marked_cells; fs_...)
fs = NamedTuple(fs_)
hmesh = HMesh(ctx.mesh) hmesh = HMesh(ctx.mesh)
refined_functions = refine!(hmesh, Set(marked_cells); refined_functions = refine!(hmesh, Set(marked_cells);
ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2) ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2,
fs...)
new_mesh = refined_functions.u.space.mesh new_mesh = refined_functions.u.space.mesh
new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S, new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S,
ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2) ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2)
fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs))
@assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap) @assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap)
@assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap) @assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap)
@assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap) @assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap)
...@@ -235,7 +308,7 @@ function refine(ctx::L1L2TVContext, marked_cells) ...@@ -235,7 +308,7 @@ function refine(ctx::L1L2TVContext, marked_cells)
new_ctx.dp1.data .= refined_functions.dp1.data new_ctx.dp1.data .= refined_functions.dp1.data
new_ctx.dp2.data .= refined_functions.dp2.data new_ctx.dp2.data .= refined_functions.dp2.data
return new_ctx return new_ctx, fs_new
end end
function mark(ctx::L1L2TVContext; theta=0.5) function mark(ctx::L1L2TVContext; theta=0.5)
...@@ -280,6 +353,9 @@ end ...@@ -280,6 +353,9 @@ end
norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f)) norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
norm_step(ctx::L1L2TVContext) =
sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh))
function denoise(img; name, params...) function denoise(img; name, params...)
m = 1 m = 1
#mesh = init_grid(img; type=:vertex) #mesh = init_grid(img; type=:vertex)
...@@ -290,8 +366,9 @@ function denoise(img; name, params...) ...@@ -290,8 +366,9 @@ function denoise(img; name, params...)
ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) project_img!(ctx.g, img)
m = (size(img) .- 1) ./ 2 .+ 1 #interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
#m = (size(img) .- 1) ./ 2 .+ 1
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
save_denoise(ctx, i) = save_denoise(ctx, i) =
...@@ -311,24 +388,69 @@ function denoise(img; name, params...) ...@@ -311,24 +388,69 @@ function denoise(img; name, params...)
pvd[k] = save_denoise(ctx, k) pvd[k] = save_denoise(ctx, k)
println() println()
norm_step = sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(mesh)) norm_step_ = norm_step(ctx)
println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))") println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
println("primal energy: $(primal_energy(ctx))") println("primal energy: $(primal_energy(ctx))")
println("norm_step: $(norm_step)") println("norm_step: $(norm_step_)")
norm_step <= 1e-1 && break norm_step_ <= 1e-1 && break
end end
marked_cells = mark(ctx; theta = 0.5) marked_cells = mark(ctx; theta = 0.5)
#println(marked_cells)
println("refining ...") println("refining ...")
ctx = refine(ctx, marked_cells) ctx, _ = refine(ctx, marked_cells)
test_mesh(ctx.mesh) test_mesh(ctx.mesh)
gnew = project_img(ctx.g.space, img) project_img!(ctx.g, img)
ctx.g.data .= gnew.data
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) k >= 100 && break
end
vtk_save(pvd)
return ctx
end
function denoise_pd(img; name, params...)
m = 1
mesh = init_grid(img; type=:vertex)
#mesh = init_grid(img, 5, 5)
sigma = 1e-1
tau = 1e-1
theta = 1.
T(tdata, u) = u
S(u, nablau) = u
ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
project_img!(ctx.g, img)
#interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
#m = (size(img) .- 1) ./ 2 .+ 1
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
save_denoise(ctx, i) =
output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_denoise(ctx, 0)
k = 0
println("primal energy: $(primal_energy(ctx))")
while true
k += 1
step_pd!(ctx; sigma, tau, theta)
#estimate!(ctx)
pvd[k] = save_denoise(ctx, k)
println()
norm_step_ = norm_step(ctx)
println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
println("primal energy: $(primal_energy(ctx))")
println("norm_step: $(norm_step_)")
norm_step_ <= 1e-1 && break
k >= 100 && break k >= 100 && break
end end
vtk_save(pvd) vtk_save(pvd)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment