Skip to content
Snippets Groups Projects
Commit d4d67b01 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

implement primal dual algorithm (untested)

parent fbf2276a
Branches
Tags
No related merge requests found
......@@ -116,9 +116,12 @@ function assemble(space::FeSpace, a, l; params...)
return A, b
end
project_img(space::FeSpace, img) =
(u = FeFunction(space); project_img!(u, img))
function project_img(space::FeSpace, img)
function project_img!(u::FeFunction, img)
d = 2 # domain dimension
space = u.space
mesh = space.mesh
f = ImageFunction(mesh, img)
opparams = (; f)
......@@ -131,13 +134,21 @@ function project_img(space::FeSpace, img)
# composite midpoint quadrature on lagrange point lattice
function quadrature(p)
k = Iterators.filter(x -> sum(x) == p,
Iterators.product((0:p for _ in 1:d+1)...)) |> collect
weights = [1 / length(k) for _ in axes(k, 1)]
points = [x[i] / p for i in 1:2, x in k]
d_ = 2
n = binomial(p + 2, 2)
weights = Vector{Float64}(undef, n)
points = Matrix{Float64}(undef, 2, n)
k = 0
for I in Iterators.product(ntuple(_ -> 0:p, d_ + 1)...)
I[1] + I[2] + I[3] != p && continue
k += 1
weights[k] = 1 / n
points[1, k] = I[1] / p
points[2, k] = I[2] / p
end
return weights::Vector{Float64}, points::Matrix{Float64}
return weights, points
end
I = Float64[]
......@@ -160,16 +171,16 @@ function project_img(space::FeSpace, img)
qphi = zeros(nrdims, nrdims, nldofs, nqpts)
dqphi = zeros(nrdims, d, nrdims, nldofs, nqpts)
for r in 1:nrdims
for k in axes(qx, 2)
qphi[r, r, :, k] .= evaluate_basis(space.element, qx[:, k])
dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(qx[:, k])))
for k in axes(qx, 2)
for r in 1:nrdims
qphi[r, r, :, k] .= evaluate_basis(space.element, SVector{d}(view(qx, :, k)))
dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(view(qx, :, k))))
end
end
# quadrature points
for k in axes(qx, 2)
xhat = SVector{d}(qx[:, k])
xhat = SVector{d}(view(qx, :, k))
x = elmap(mesh, cell)(xhat)
opvalues = map(f -> evaluate(f, xhat), opparams)
......@@ -204,8 +215,6 @@ function project_img(space::FeSpace, img)
ngdofs = ndofs(space)
A = sparse(I, J, V, ngdofs, ngdofs)
u = FeFunction(space)
u.data .= A \ b
return u
end
export myrun, denoise, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg
export myrun, denoise, denoise_pd, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg
using LinearAlgebra: norm
......@@ -169,6 +169,74 @@ function step!(ctx::L1L2TVContext)
p2_project!(ctx.p2, ctx.lambda)
end
function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
# note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1!
if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0.
error("unsupported parameters")
end
beta = tau * ctx.alpha1 / (1 + 2 * tau * ctx.alpha2)
# u is P1
# p2 is essentially DP0 (technically may be DP1)
# 1.
function p2_update(x_; p2, nablau)
return p2 + sigma * nablau
end
interpolate!(ctx.p2, p2_update; ctx.p2, ctx.nablau)
function p2_project!(p2, lambda)
p2.space.element::DP1
p2d = reshape(p2.data, prod(p2.space.size), :) # no copy
for i in axes(p2d, 2)
p2in = norm(p2d[:, i])
if p2in > lambda
p2d[:, i] .*= lambda ./ p2in
end
end
end
p2_next = FeFunction(ctx.p2.space)
p2_project!(p2_next, ctx.lambda)
ctx.dp2.data .= p2_next.data .- ctx.p2.data
ctx.p2.data .= p2_next.data
# 2.
u_a(x, z, nablaz, phi, nablaphi; g, u, p2) =
dot(z, phi)
u_l(x, phi, nablaphi; u, g, p2) =
(dot(u + 2 * tau * ctx.alpha2 * g, phi) - tau * dot(p2, nablaphi)) /
(1 + 2 * tau * ctx.alpha2)
# z = 1 / (1 + 2 * tau * alpha2) *
# (u + 2 * tau * alpha2 * g + tau * div(p))
z = FeFunction(ctx.u.space)
A, b = assemble(z.space, u_a, u_l; ctx.g, ctx.u, ctx.p2)
z.data .= A \ b
function u_update!(u, z, g, beta)
u.space.element::P1
g.space.element::P1
for i in eachindex(u.data)
if z.data[i] - beta >= g.data[i]
u.data[i] = z.data[i] - beta
elseif z.data[i] + beta <= g.data[i]
u.data[i] = z.data[i] + beta
else
u.data[i] = g.data[i]
end
end
end
u_next = FeFunction(ctx.u.space)
u_update!(u_next, z, ctx.g, beta)
# 3.
ctx.du.data .= u_next.data .- ctx.u.data
ctx.u.data .= u_next.data .+ theta * ctx.du.data
return ctx
end
function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) =
ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) +
......@@ -208,15 +276,20 @@ function estimate!(ctx::L1L2TVContext)
project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata)
end
function refine(ctx::L1L2TVContext, marked_cells)
function refine(ctx::L1L2TVContext, marked_cells; fs_...)
fs = NamedTuple(fs_)
hmesh = HMesh(ctx.mesh)
refined_functions = refine!(hmesh, Set(marked_cells);
ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2)
ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2,
fs...)
new_mesh = refined_functions.u.space.mesh
new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S,
ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2)
fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs))
@assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap)
@assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap)
@assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap)
......@@ -235,7 +308,7 @@ function refine(ctx::L1L2TVContext, marked_cells)
new_ctx.dp1.data .= refined_functions.dp1.data
new_ctx.dp2.data .= refined_functions.dp2.data
return new_ctx
return new_ctx, fs_new
end
function mark(ctx::L1L2TVContext; theta=0.5)
......@@ -280,6 +353,9 @@ end
norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
norm_step(ctx::L1L2TVContext) =
sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh))
function denoise(img; name, params...)
m = 1
#mesh = init_grid(img; type=:vertex)
......@@ -290,8 +366,9 @@ function denoise(img; name, params...)
ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
m = (size(img) .- 1) ./ 2 .+ 1
project_img!(ctx.g, img)
#interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
#m = (size(img) .- 1) ./ 2 .+ 1
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
save_denoise(ctx, i) =
......@@ -311,24 +388,69 @@ function denoise(img; name, params...)
pvd[k] = save_denoise(ctx, k)
println()
norm_step = sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(mesh))
norm_step_ = norm_step(ctx)
println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
println("primal energy: $(primal_energy(ctx))")
println("norm_step: $(norm_step)")
println("norm_step: $(norm_step_)")
norm_step <= 1e-1 && break
norm_step_ <= 1e-1 && break
end
marked_cells = mark(ctx; theta = 0.5)
#println(marked_cells)
println("refining ...")
ctx = refine(ctx, marked_cells)
ctx, _ = refine(ctx, marked_cells)
test_mesh(ctx.mesh)
gnew = project_img(ctx.g.space, img)
ctx.g.data .= gnew.data
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
project_img!(ctx.g, img)
k >= 100 && break
end
vtk_save(pvd)
return ctx
end
function denoise_pd(img; name, params...)
m = 1
mesh = init_grid(img; type=:vertex)
#mesh = init_grid(img, 5, 5)
sigma = 1e-1
tau = 1e-1
theta = 1.
T(tdata, u) = u
S(u, nablau) = u
ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
project_img!(ctx.g, img)
#interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
#m = (size(img) .- 1) ./ 2 .+ 1
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
save_denoise(ctx, i) =
output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_denoise(ctx, 0)
k = 0
println("primal energy: $(primal_energy(ctx))")
while true
k += 1
step_pd!(ctx; sigma, tau, theta)
#estimate!(ctx)
pvd[k] = save_denoise(ctx, k)
println()
norm_step_ = norm_step(ctx)
println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
println("primal energy: $(primal_energy(ctx))")
println("norm_step: $(norm_step_)")
norm_step_ <= 1e-1 && break
k >= 100 && break
end
vtk_save(pvd)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment