From d4d67b01a7495ba88c8ff67cdcbbdb7da6c8b120 Mon Sep 17 00:00:00 2001 From: Stephan Hilb <stephan@ecshi.net> Date: Sun, 25 Jul 2021 08:27:22 +0200 Subject: [PATCH] implement primal dual algorithm (untested) --- src/operator.jl | 37 +++++++----- src/run.jl | 150 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 159 insertions(+), 28 deletions(-) diff --git a/src/operator.jl b/src/operator.jl index bc27855..b15cd60 100644 --- a/src/operator.jl +++ b/src/operator.jl @@ -116,9 +116,12 @@ function assemble(space::FeSpace, a, l; params...) return A, b end +project_img(space::FeSpace, img) = + (u = FeFunction(space); project_img!(u, img)) -function project_img(space::FeSpace, img) +function project_img!(u::FeFunction, img) d = 2 # domain dimension + space = u.space mesh = space.mesh f = ImageFunction(mesh, img) opparams = (; f) @@ -131,13 +134,21 @@ function project_img(space::FeSpace, img) # composite midpoint quadrature on lagrange point lattice function quadrature(p) - k = Iterators.filter(x -> sum(x) == p, - Iterators.product((0:p for _ in 1:d+1)...)) |> collect - - weights = [1 / length(k) for _ in axes(k, 1)] - points = [x[i] / p for i in 1:2, x in k] + d_ = 2 + n = binomial(p + 2, 2) + weights = Vector{Float64}(undef, n) + points = Matrix{Float64}(undef, 2, n) + + k = 0 + for I in Iterators.product(ntuple(_ -> 0:p, d_ + 1)...) + I[1] + I[2] + I[3] != p && continue + k += 1 + weights[k] = 1 / n + points[1, k] = I[1] / p + points[2, k] = I[2] / p + end - return weights::Vector{Float64}, points::Matrix{Float64} + return weights, points end I = Float64[] @@ -160,16 +171,16 @@ function project_img(space::FeSpace, img) qphi = zeros(nrdims, nrdims, nldofs, nqpts) dqphi = zeros(nrdims, d, nrdims, nldofs, nqpts) - for r in 1:nrdims - for k in axes(qx, 2) - qphi[r, r, :, k] .= evaluate_basis(space.element, qx[:, k]) - dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(qx[:, k]))) + for k in axes(qx, 2) + for r in 1:nrdims + qphi[r, r, :, k] .= evaluate_basis(space.element, SVector{d}(view(qx, :, k))) + dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(view(qx, :, k)))) end end # quadrature points for k in axes(qx, 2) - xhat = SVector{d}(qx[:, k]) + xhat = SVector{d}(view(qx, :, k)) x = elmap(mesh, cell)(xhat) opvalues = map(f -> evaluate(f, xhat), opparams) @@ -204,8 +215,6 @@ function project_img(space::FeSpace, img) ngdofs = ndofs(space) A = sparse(I, J, V, ngdofs, ngdofs) - u = FeFunction(space) u.data .= A \ b - return u end diff --git a/src/run.jl b/src/run.jl index a921111..dee7488 100644 --- a/src/run.jl +++ b/src/run.jl @@ -1,4 +1,4 @@ -export myrun, denoise, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg +export myrun, denoise, denoise_pd, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg using LinearAlgebra: norm @@ -169,6 +169,74 @@ function step!(ctx::L1L2TVContext) p2_project!(ctx.p2, ctx.lambda) end +function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.) + # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1! + if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0. + error("unsupported parameters") + end + beta = tau * ctx.alpha1 / (1 + 2 * tau * ctx.alpha2) + + # u is P1 + # p2 is essentially DP0 (technically may be DP1) + + # 1. + function p2_update(x_; p2, nablau) + return p2 + sigma * nablau + end + interpolate!(ctx.p2, p2_update; ctx.p2, ctx.nablau) + + function p2_project!(p2, lambda) + p2.space.element::DP1 + p2d = reshape(p2.data, prod(p2.space.size), :) # no copy + for i in axes(p2d, 2) + p2in = norm(p2d[:, i]) + if p2in > lambda + p2d[:, i] .*= lambda ./ p2in + end + end + end + p2_next = FeFunction(ctx.p2.space) + p2_project!(p2_next, ctx.lambda) + ctx.dp2.data .= p2_next.data .- ctx.p2.data + ctx.p2.data .= p2_next.data + + # 2. + u_a(x, z, nablaz, phi, nablaphi; g, u, p2) = + dot(z, phi) + + u_l(x, phi, nablaphi; u, g, p2) = + (dot(u + 2 * tau * ctx.alpha2 * g, phi) - tau * dot(p2, nablaphi)) / + (1 + 2 * tau * ctx.alpha2) + + # z = 1 / (1 + 2 * tau * alpha2) * + # (u + 2 * tau * alpha2 * g + tau * div(p)) + z = FeFunction(ctx.u.space) + A, b = assemble(z.space, u_a, u_l; ctx.g, ctx.u, ctx.p2) + z.data .= A \ b + + function u_update!(u, z, g, beta) + u.space.element::P1 + g.space.element::P1 + for i in eachindex(u.data) + if z.data[i] - beta >= g.data[i] + u.data[i] = z.data[i] - beta + elseif z.data[i] + beta <= g.data[i] + u.data[i] = z.data[i] + beta + else + u.data[i] = g.data[i] + end + end + end + u_next = FeFunction(ctx.u.space) + u_update!(u_next, z, ctx.g, beta) + + # 3. + ctx.du.data .= u_next.data .- ctx.u.data + ctx.u.data .= u_next.data .+ theta * ctx.du.data + + return ctx +end + function solve_primal!(u::FeFunction, ctx::L1L2TVContext) u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) = ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) + @@ -208,15 +276,20 @@ function estimate!(ctx::L1L2TVContext) project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata) end -function refine(ctx::L1L2TVContext, marked_cells) +function refine(ctx::L1L2TVContext, marked_cells; fs_...) + fs = NamedTuple(fs_) + hmesh = HMesh(ctx.mesh) refined_functions = refine!(hmesh, Set(marked_cells); - ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2) + ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2, + fs...) new_mesh = refined_functions.u.space.mesh new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S, ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2) + fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs)) + @assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap) @assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap) @assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap) @@ -235,7 +308,7 @@ function refine(ctx::L1L2TVContext, marked_cells) new_ctx.dp1.data .= refined_functions.dp1.data new_ctx.dp2.data .= refined_functions.dp2.data - return new_ctx + return new_ctx, fs_new end function mark(ctx::L1L2TVContext; theta=0.5) @@ -280,6 +353,9 @@ end norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f)) +norm_step(ctx::L1L2TVContext) = + sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh)) + function denoise(img; name, params...) m = 1 #mesh = init_grid(img; type=:vertex) @@ -290,8 +366,9 @@ function denoise(img; name, params...) ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) - interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) - m = (size(img) .- 1) ./ 2 .+ 1 + project_img!(ctx.g, img) + #interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) + #m = (size(img) .- 1) ./ 2 .+ 1 #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) save_denoise(ctx, i) = @@ -311,24 +388,69 @@ function denoise(img; name, params...) pvd[k] = save_denoise(ctx, k) println() - norm_step = sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(mesh)) + norm_step_ = norm_step(ctx) println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))") println("primal energy: $(primal_energy(ctx))") - println("norm_step: $(norm_step)") + println("norm_step: $(norm_step_)") - norm_step <= 1e-1 && break + norm_step_ <= 1e-1 && break end marked_cells = mark(ctx; theta = 0.5) - #println(marked_cells) println("refining ...") - ctx = refine(ctx, marked_cells) + ctx, _ = refine(ctx, marked_cells) test_mesh(ctx.mesh) - gnew = project_img(ctx.g.space, img) - ctx.g.data .= gnew.data - #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) + project_img!(ctx.g, img) + + k >= 100 && break + end + vtk_save(pvd) + return ctx +end + +function denoise_pd(img; name, params...) + m = 1 + mesh = init_grid(img; type=:vertex) + #mesh = init_grid(img, 5, 5) + + sigma = 1e-1 + tau = 1e-1 + theta = 1. + + T(tdata, u) = u + S(u, nablau) = u + + ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) + + project_img!(ctx.g, img) + #interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) + #m = (size(img) .- 1) ./ 2 .+ 1 + #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) + + save_denoise(ctx, i) = + output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu", + ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est) + + pvd = paraview_collection("output/$(ctx.name).pvd") + pvd[0] = save_denoise(ctx, 0) + + k = 0 + println("primal energy: $(primal_energy(ctx))") + while true + k += 1 + step_pd!(ctx; sigma, tau, theta) + #estimate!(ctx) + pvd[k] = save_denoise(ctx, k) + println() + + norm_step_ = norm_step(ctx) + + println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))") + println("primal energy: $(primal_energy(ctx))") + println("norm_step: $(norm_step_)") + norm_step_ <= 1e-1 && break k >= 100 && break end vtk_save(pvd) -- GitLab