diff --git a/scripts/Project.toml b/scripts/Project.toml index bd370fbcd0b515492a8dac03a93239aef58ecfdf..44c2e61b2739ba1bc10e9673c2746db3ecce6329 100644 --- a/scripts/Project.toml +++ b/scripts/Project.toml @@ -1,4 +1,5 @@ [deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" diff --git a/scripts/run.jl b/scripts/run.jl index 64f564b2261b44b4171d0b29c0d4dc8a3c894c53..f32f1538cbe81b1f7c2beac5fd47ce7dd498d29c 100644 --- a/scripts/run.jl +++ b/scripts/run.jl @@ -1,801 +1,7 @@ -using LinearAlgebra: norm, dot +include("run_experiments.jl") -using Colors: Gray -# avoid world-age-issues by preloading ColorTypes -import ColorTypes -import DataFrames: DataFrame -import FileIO -using OpticalFlowUtils -using WriteVTK: paraview_collection -using Plots +const datapath = joinpath(@__DIR__, "..", "data") -using SemiSmoothNewton -using SemiSmoothNewton: HMesh, ncells, refine -using SemiSmoothNewton: project_img!, project_img2!, project! -using SemiSmoothNewton: vtk_mesh, vtk_append!, vtk_save +ctx = Util.Context(datapath) -include("util.jl") -isdefined(Main, :Revise) && Revise.track(joinpath(@__DIR__, "util.jl")) - -_toimg(arr) = Gray.(clamp.(arr, 0., 1.)) -#loadimg(x) = reverse(transpose(Float64.(FileIO.load(x))); dims = 2) -#saveimg(io, x) = FileIO.save(io, transpose(reverse(_toimg(x); dims = 2))) -loadimg(x) = Float64.(FileIO.load(x)) -saveimg(io, x) = FileIO.save(io, _toimg(x)) - -from_img(arr) = permutedims(reverse(arr; dims = 1)) -to_img(arr) = permutedims(reverse(arr; dims = 2)) -function to_img(arr::AbstractArray{<:Any,3}) - out = permutedims(reverse(arr; dims = (1,3)), (1, 3, 2)) - out[1, :, :] .= .-out[1, :, :] - return out -end - - -""" - logfilter(dt; a=20) - -Reduce dataset such that the resulting dataset would have approximately -equidistant datapoints on a log-scale. `a` controls density. 
-""" -logfilter(dt; a=20) = filter(:k => k -> round(a*log(k+1)) - round(a*log(k)) > 0, dt) - -struct L1L2TVContext{M, Ttype, Stype} - name::String - mesh::M - d::Int # = ndims_domain(mesh) - m::Int - - T::Ttype - tdata - S::Stype - - alpha1::Float64 - alpha2::Float64 - beta::Float64 - lambda::Float64 - gamma1::Float64 - gamma2::Float64 - - est::FeFunction - g::FeFunction - u::FeFunction - p1::FeFunction - p2::FeFunction - du::FeFunction - dp1::FeFunction - dp2::FeFunction -end - -function L1L2TVContext(name, mesh, m; T, tdata, S, - alpha1, alpha2, beta, lambda, gamma1, gamma2) - d = ndims_domain(mesh) - - Vest = FeSpace(mesh, DP0(), (1,)) - Vg = FeSpace(mesh, P1(), (1,)) - Vu = FeSpace(mesh, P1(), (m,)) - Vp1 = FeSpace(mesh, DP0(), (1,)) - Vp2 = FeSpace(mesh, DP1(), (m, d)) - - est = FeFunction(Vest, name="est") - g = FeFunction(Vg, name="g") - u = FeFunction(Vu, name="u") - p1 = FeFunction(Vp1, name="p1") - p2 = FeFunction(Vp2, name="p2") - du = FeFunction(Vu; name = "du") - dp1 = FeFunction(Vp1; name = "dp1") - dp2 = FeFunction(Vp2; name = "dp2") - - est.data .= 0 - g.data .= 0 - u.data .= 0 - p1.data .= 0 - p2.data .= 0 - du.data .= 0 - dp1.data .= 0 - dp2.data .= 0 - - return L1L2TVContext(name, mesh, d, m, T, tdata, S, - alpha1, alpha2, beta, lambda, gamma1, gamma2, - est, g, u, p1, p2, du, dp1, dp2) -end - -function p1_project!(p1, alpha1) - p1.space.element == DP0() || p1.space.element == P1() || - p1.space.element == DP1() || - throw(ArgumentError("element unsupported")) - p1.data .= clamp.(p1.data, -alpha1, alpha1) -end - -function p2_project!(p2, lambda) - p2.space.element::DP1 - p2d = reshape(p2.data, prod(p2.space.size), :) # no copy - for i in axes(p2d, 2) - p2in = norm(p2d[:, i]) - if p2in > lambda - p2d[:, i] .*= lambda ./ p2in - end - end -end - -function step!(ctx::L1L2TVContext) - T = ctx.T - S = ctx.S - alpha1 = ctx.alpha1 - alpha2 = ctx.alpha2 - beta = ctx.beta - lambda = ctx.lambda - gamma1 = ctx.gamma1 - gamma2 = ctx.gamma2 - - function 
du_a(x_, du, nabladu, phi, nablaphi; g, u, nablau, p1, p2, tdata) - m1 = max(gamma1, norm(T(tdata, u) - g)) - cond1 = norm(T(tdata, u) - g) > gamma1 ? - dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : - zero(p1) - a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) - - dot(cond1, T(tdata, phi)) - - m2 = max(gamma2, norm(nablau)) - cond2 = norm(nablau) > gamma2 ? - dot(nablau, nabladu) / norm(nablau)^2 * p2 : - zero(p2) - a2 = lambda / m2 * dot(nabladu, nablaphi) - - dot(cond2, nablaphi) - - aB = alpha2 * dot(T(tdata, du), T(tdata, phi)) + - beta * dot(S(du, nabladu), S(phi, nablaphi)) - - return a1 + a2 + aB - end - - function du_l(x_, phi, nablaphi; g, u, nablau, p1, p2, tdata) - aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) + - beta * dot(S(u, nablau), S(phi, nablaphi)) - m1 = max(gamma1, norm(T(tdata, u) - g)) - p1part = alpha1 / m1 * dot(T(tdata, u) - g, T(tdata, phi)) - m2 = max(gamma2, norm(nablau)) - p2part = lambda / m2 * dot(nablau, nablaphi) - gpart = alpha2 * dot(g, T(tdata, phi)) - - return -aB - p1part - p2part + gpart - end - - # solve du - print("assemble ... ") - A, b = assemble(ctx.du.space, du_a, du_l; - ctx.g, ctx.u, nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata) - print("solve ... ") - ctx.du.data .= A \ b - - - # solve dp1 - function dp1_update(x_; g, u, p1, du, tdata) - m1 = max(gamma1, norm(T(tdata, u) - g)) - cond = norm(T(tdata, u) - g) > gamma1 ? - dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : - zero(p1) - return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond - end - interpolate!(ctx.dp1, dp1_update; - ctx.g, ctx.u, ctx.p1, ctx.du, ctx.tdata) - - # solve dp2 - function dp2_update(x_; u, nablau, p2, du, nabladu) - m2 = max(gamma2, norm(nablau)) - cond = norm(nablau) > gamma2 ? 
- dot(nablau, nabladu) / norm(nablau)^2 * p2 : - zero(p2) - return -p2 + lambda / m2 * (nablau + nabladu) - cond - end - interpolate!(ctx.dp2, dp2_update; - ctx.u, nablau = nabla(ctx.u), ctx.p2, ctx.du, nabladu = nabla(ctx.du)) - - # newton update - theta = 1. - ctx.u.data .+= theta * ctx.du.data - ctx.p1.data .+= theta * ctx.dp1.data - ctx.p2.data .+= theta * ctx.dp2.data - - # reproject p1, p2 - p1_project!(ctx.p1, ctx.alpha1) - p2_project!(ctx.p2, ctx.lambda) -end - -" -2010, Chambolle and Pock: primal-dual semi-implicit algorithm -2017, Alkämper and Langer: fem dualisation -" -function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.) - # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1! - # changed alpha2 -> alpha2 / 2 - # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad| - if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0. - error("unsupported parameters") - end - beta = tau * ctx.alpha1 / (1 + tau * ctx.alpha2) - - # u is P1 - # p2 is essentially DP0 (technically may be DP1) - - # 1. y update - u_new = FeFunction(ctx.u.space) # x_bar only used here - u_new.data .= ctx.u.data .+ theta .* ctx.du.data - - p2_next = FeFunction(ctx.p2.space) - - function p2_update(x_; p2, nablau) - return p2 + sigma * nablau - end - interpolate!(p2_next, p2_update; ctx.p2, nablau = nabla(u_new)) - - p2_project!(p2_next, ctx.lambda) - ctx.dp2.data .= p2_next.data .- ctx.p2.data - - ctx.p2.data .+= ctx.dp2.data - - # 2. 
- u_a(x, z, nablaz, phi, nablaphi; g, u, p2) = - dot(z, phi) - - u_l(x, phi, nablaphi; u, g, p2) = - (dot(u + tau * ctx.alpha2 * g, phi) + tau * dot(p2, -nablaphi)) / - (1 + tau * ctx.alpha2) - - # z = 1 / (1 + tau * alpha2) * - # (u + tau * alpha2 * g + tau * div(p)) - z = FeFunction(ctx.u.space) - A, b = assemble(z.space, u_a, u_l; ctx.g, ctx.u, ctx.p2) - z.data .= A \ b - - function u_update!(u, z, g, beta) - u.space.element::P1 - g.space.element::P1 - for i in eachindex(u.data) - if z.data[i] - beta >= g.data[i] - u.data[i] = z.data[i] - beta - elseif z.data[i] + beta <= g.data[i] - u.data[i] = z.data[i] + beta - else - u.data[i] = g.data[i] - end - end - end - u_update!(u_new, z, ctx.g, beta) - - # 3. - # note: step-size control not implemented, since \nabla G^* is not 1/gamma continuous - ctx.du.data .= u_new.data .- ctx.u.data - ctx.u.data .= u_new.data - - #ctx.du.data .= z.data - any(isnan, ctx.u.data) && error("encountered nan data") - any(isnan, ctx.p2.data) && error("encountered nan data") - - return ctx -end - -" -2010, Chambolle and Pock: accelerated primal-dual semi-implicit algorithm -" -function step_pd2!(ctx::L1L2TVContext; sigma, tau, theta = 1.) - # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad| - - # u is P1 - # p2 is essentially DP0 (technically may be DP1) - - # 1. 
y update - u_new = FeFunction(ctx.u.space) # x_bar only used here - u_new.data .= ctx.u.data .+ theta .* ctx.du.data - - p1_next = FeFunction(ctx.p1.space) - p2_next = FeFunction(ctx.p2.space) - - function p1_update(x_; p1, g, u, tdata) - ctx.alpha1 == 0 && return zero(p1) - return (p1 + sigma * (ctx.T(tdata, u) - g)) / - (1 + ctx.gamma1 * sigma / ctx.alpha1) - end - interpolate!(p1_next, p1_update; ctx.p1, ctx.g, ctx.u, ctx.tdata) - - function p2_update(x_; p2, nablau) - ctx.lambda == 0 && return zero(p2) - return (p2 + sigma * nablau) / - (1 + ctx.gamma2 * sigma / ctx.lambda) - end - interpolate!(p2_next, p2_update; ctx.p2, nablau = nabla(u_new)) - - # reproject p1, p2 - p1_project!(p1_next, ctx.alpha1) - p2_project!(p2_next, ctx.lambda) - - ctx.dp1.data .= p1_next.data .- ctx.p1.data - ctx.dp2.data .= p2_next.data .- ctx.p2.data - ctx.p1.data .+= ctx.dp1.data - ctx.p2.data .+= ctx.dp2.data - - # 2. x update - u_a(x, w, nablaw, phi, nablaphi; g, u, p1, p2, tdata) = - dot(w, phi) + - tau * ctx.alpha2 * dot(ctx.T(tdata, w), ctx.T(tdata, phi)) + - tau * ctx.beta * dot(ctx.S(w, nablaw), ctx.S(phi, nablaphi)) - - u_l(x, phi, nablaphi; g, u, p1, p2, tdata) = - dot(u, phi) - tau * ( - dot(p1, ctx.T(tdata, phi)) + - dot(p2, nablaphi) - - ctx.alpha2 * dot(g, ctx.T(tdata, phi))) - - A, b = assemble(u_new.space, u_a, u_l; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.tdata) - u_new.data .= A \ b - ctx.du.data .= u_new.data .- ctx.u.data - ctx.u.data .= u_new.data - - #ctx.du.data .= z.data - any(isnan, ctx.u.data) && error("encountered nan data") - any(isnan, ctx.p1.data) && error("encountered nan data") - any(isnan, ctx.p2.data) && error("encountered nan data") - - return ctx -end - -" -2004, Chambolle: dual semi-implicit algorithm -" -function step_d!(ctx::L1L2TVContext; tau) - # u is P1 - # p2 is essentially DP0 (technically may be DP1) - - # TODO: this might not be implementable without higher order elements - # need grad(div(p)) - return ctx -end - -function 
solve_primal!(u::FeFunction, ctx::L1L2TVContext) - u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) = - ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) + - ctx.beta * dot(ctx.S(u, nablau), ctx.S(phi, nablaphi)) - - u_l(x, phi, nablaphi; g, p1, p2, tdata) = - -dot(p1, ctx.T(tdata, phi)) - dot(p2, nablaphi) + - ctx.alpha2 * dot(g, ctx.T(tdata, phi)) - - # u = B^{-1} * (T^* p_1 - div p_2 - alpha2 * T^* g) - A, b = assemble(u.space, u_a, u_l; ctx.g, ctx.p1, ctx.p2, ctx.tdata) - u.data .= A \ b -end - -huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2 - -function estimate!(ctx::L1L2TVContext) - # FIXME: sign? - function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata) - alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * ( - huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) - - dot(ctx.T(tdata, u) - g, p1 / ctx.alpha1) + - ctx.gamma1 / 2 * norm(p1 / ctx.alpha1)^2) - lambdapart = iszero(ctx.lambda) ? 0. : ctx.lambda * ( - huber(norm(nablau), ctx.gamma2) - - dot(nablau, p2 / ctx.lambda) + - ctx.gamma2 / 2 * norm(p2 / ctx.lambda)^2) - bpart = 1 / 2 * ( - ctx.alpha2 * dot(ctx.T(tdata, w - u), ctx.T(tdata, w - u)) + - ctx.beta * dot(ctx.S(w, nablaw) - ctx.S(u, nablau), ctx.S(w, nablaw) - ctx.S(u, nablau))) - - return alpha1part + lambdapart + bpart - end - - w = FeFunction(ctx.u.space) - solve_primal!(w, ctx) - project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, - nablau = nabla(ctx.u), w, nablaw = nabla(w), ctx.tdata) -end - -# TODO: deprecate in favor of refine(mesh, marked_cells; fs...) -#function refine(ctx::L1L2TVContext, marked_cells; fs_...) -# fs = NamedTuple(fs_) -# -# hmesh = HMesh(ctx.mesh) -# refined_functions = refine!(hmesh, Set(marked_cells); -# ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2, -# fs...) 
-# new_mesh = refined_functions.u.space.mesh -# -# # TODO: tdata needs to be recreated for refinement -# new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S, -# ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2) -# -# fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs)) -# -# @assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap) -# @assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap) -# @assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap) -# @assert(new_ctx.p1.space.dofmap == refined_functions.p1.space.dofmap) -# @assert(new_ctx.p2.space.dofmap == refined_functions.p2.space.dofmap) -# @assert(new_ctx.du.space.dofmap == refined_functions.du.space.dofmap) -# @assert(new_ctx.dp1.space.dofmap == refined_functions.dp1.space.dofmap) -# @assert(new_ctx.dp2.space.dofmap == refined_functions.dp2.space.dofmap) -# -# new_ctx.est.data .= refined_functions.est.data -# new_ctx.g.data .= refined_functions.g.data -# new_ctx.u.data .= refined_functions.u.data -# new_ctx.p1.data .= refined_functions.p1.data -# new_ctx.p2.data .= refined_functions.p2.data -# new_ctx.du.data .= refined_functions.du.data -# new_ctx.dp1.data .= refined_functions.dp1.data -# new_ctx.dp2.data .= refined_functions.dp2.data -# -# return new_ctx, fs_new -#end - -# minimal Dörfler marking -function mark(ctx::L1L2TVContext; theta=0.5) - n = ncells(ctx.mesh) - esttotal = sum(ctx.est.data) - - cellerrors = collect(pairs(ctx.est.data)) - cellerrors_sorted = sort(cellerrors; lt = (x, y) -> x.second > y.second) - - marked_cells = Int[] - estacc = 0. - for (cell, error) in cellerrors_sorted - estacc >= theta * esttotal && break - push!(marked_cells, cell) - estacc += error - end - return marked_cells -end - - -function output(ctx::L1L2TVContext, filename, fs...) - print("save \"$filename\" ... ") - vtk = vtk_mesh(filename, ctx.mesh) - vtk_append!(vtk, fs...) 
- vtk_save(vtk) - return vtk -end - -function primal_energy(ctx::L1L2TVContext) - function integrand(x; g, u, nablau, tdata) - return ctx.alpha1 * huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) + - ctx.alpha2 / 2 * norm(ctx.T(tdata, u) - g)^2 + - ctx.beta / 2 * norm(ctx.S(u, nablau))^2 + - ctx.lambda * huber(norm(nablau), ctx.gamma2) - end - return integrate(ctx.mesh, integrand; ctx.g, ctx.u, - nablau = nabla(ctx.u), ctx.tdata) -end - -norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f)) - -norm_step(ctx::L1L2TVContext) = - sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh)) - -function norm_residual(ctx::L1L2TVContext) - w = FeFunction(ctx.u.space) - solve_primal!(w, ctx) - w.data .-= ctx.u.data - upart2 = norm_l2(w)^2 - - function integrand(x; g, u, nablau, p1, p2, tdata) - p1part = p1 * max(ctx.gamma1, norm(ctx.T(tdata, u) - g)) - - ctx.alpha1 * (ctx.T(tdata, u) - g) - p2part = p2 * max(ctx.gamma2, norm(nablau)) - - ctx.lambda * nablau - return norm(p1part)^2 + norm(p2part)^2 - end - ppart2 = integrate(ctx.mesh, integrand; ctx.g, ctx.u, - nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata) - - return sqrt(upart2 + ppart2) -end - -function denoise(img; name, params...) - m = 1 - img = from_img(img) # coord flip - #mesh = init_grid(img; type=:vertex) - mesh = init_grid(img, 5, 5) - - T(tdata, u) = u - S(u, nablau) = u - - ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) 
- - project_img!(ctx.g, img) - #interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) - #m = (size(img) .- 1) ./ 2 .+ 1 - #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) - - save_denoise(ctx, i) = - output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu", - ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est) - - pvd = paraview_collection("output/$(ctx.name).pvd") - pvd[0] = save_denoise(ctx, 0) - - k = 0 - println("primal energy: $(primal_energy(ctx))") - while true - while true - k += 1 - step!(ctx) - estimate!(ctx) - pvd[k] = save_denoise(ctx, k) - println() - - norm_step_ = norm_step(ctx) - norm_residual_ = norm_residual(ctx) - - println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))") - println("primal energy: $(primal_energy(ctx))") - println("norm_step: $(norm_step_)") - println("norm_residual: $(norm_residual_)") - - norm_step_ <= 1e-1 && break - end - marked_cells = mark(ctx; theta = 0.5) - println("refining ...") - ctx, _ = refine(ctx, marked_cells) - test_mesh(ctx.mesh) - - project_img!(ctx.g, img) - - k >= 100 && break - end - vtk_save(pvd) - return ctx -end - - -function denoise_pd(img; df=nothing, name, algorithm, params...) - m = 1 - img = from_img(img) # coord flip - mesh = init_grid(img; type=:vertex) - #mesh = init_grid(img, 5, 5) - - T(tdata, u) = u - S(u, nablau) = u - - ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) - - # semi-implicit primal dual parameters - gamma = ctx.alpha2 + ctx.beta # T = I, S = I - gamma /= 100 # kind of arbitrary? - - tau = 1e-1 - L = 100 - sigma = inv(tau * L^2) - theta = 1. - - #project_img!(ctx.g, img) - interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) - ctx.u.data .= ctx.g.data - - save_denoise(ctx, i) = - output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu", - ctx.g, ctx.u, ctx.p1, ctx.p2) - - log!(x::Nothing; kwargs...) 
= x - function log!(df::DataFrame; k, norm_step, norm_residual) - push!(df, (; - k, - primal_energy = primal_energy(ctx), - norm_step, - norm_residual)) - println(NamedTuple(last(df))) - end - - pvd = paraview_collection("output/$(ctx.name).pvd") - pvd[0] = save_denoise(ctx, 0) - - k = 0 - println("primal energy: $(primal_energy(ctx))") - - while true - k += 1 - if algorithm == :pd1 - # no step size control - step_pd2!(ctx; sigma, tau, theta) - elseif algorithm == :pd2 - theta = 1 / sqrt(1 + 2 * gamma * tau) - tau *= theta - sigma /= theta - step_pd2!(ctx; sigma, tau, theta) - elseif algorithm == :newton - step!(ctx) - end - #pvd[k] = save_denoise(ctx, k) - - domain_factor = 1 / sqrt(area(mesh)) - norm_step_ = norm_step(ctx) * domain_factor - norm_residual_ = norm_residual(ctx) * domain_factor - - log!(df; k, norm_step = norm_step_, norm_residual = norm_residual_) - - norm_residual_ < 1e-6 && norm_step_ < 1e-6 && break - norm_step_ < 1e-13 && break - end - pvd[1] = save_denoise(ctx, 1) - vtk_save(pvd) - return ctx -end - -function experiment1(img) - path = "data/pd-comparison" - - img = loadimg("data/denoising/input.png") - img = from_img(img) # coord flip - - algparams = (alpha1=0., alpha2=30., lambda=1., beta=0., gamma1=1e-3, gamma2=1e-3) - - df1 = DataFrame() - df2 = DataFrame() - df3 = DataFrame() - - denoise_pd(img; name="test", algorithm=:pd1, df = df1, algparams...); - denoise_pd(img; name="test", algorithm=:pd2, df = df2, algparams...); - denoise_pd(img; name="test", algorithm=:newton, df = df3, algparams...); - - energy_min = min(minimum(df1.primal_energy), minimum(df2.primal_energy), - minimum(df3.primal_energy)) - - df1.primal_energy .-= energy_min - df2.primal_energy .-= energy_min - df3.primal_energy .-= energy_min - - CSV.write(joinpath(path, "semiimplicit.csv"), logfilter(df1)) - CSV.write(joinpath(path, "semiimplicit-accelerated.csv"), logfilter(df2)) - CSV.write(joinpath(path, "newton.csv"), logfilter(df3)) -end - -function inpaint(img, imgmask; 
name, params...) - size(img) == size(imgmask) || - throw(ArgumentError("non-matching dimensions")) - - m = 1 - img = from_img(img) # coord flip - imgmask = from_img(imgmask) # coord flip - mesh = init_grid(img; type=:vertex) - - # inpaint specific stuff - Vg = FeSpace(mesh, P1(), (1,)) - mask = FeFunction(Vg, name="mask") - - T(tdata, u) = isone(tdata[begin]) ? u : zero(u) - S(u, nablau) = u - - ctx = L1L2TVContext(name, mesh, m; T, tdata = mask, S, params...) - - # FIXME: currently dual grid only - interpolate!(mask, x -> imgmask[round.(Int, x)...]) - #interpolate!(mask, x -> abs(x[2] - 0.5) > 0.1) - interpolate!(ctx.g, x -> imgmask[round.(Int, x)...] ? img[round.(Int, x)...] : 0.) - m = (size(img) .- 1) ./ 2 .+ 1 - interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3) - - save_inpaint(i) = - output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu", - ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask) - - pvd = paraview_collection("output/$(ctx.name).pvd") - pvd[0] = save_inpaint(0) - for i in 1:3 - step!(ctx) - estimate!(ctx) - pvd[i] = save_inpaint(i) - println() - end - return ctx -end - -function optflow(ctx) - # coord flip - imgf0 = from_img(ctx.params.imgf0) - imgf1 = from_img(ctx.params.imgf1) - maxflow = 4.6157 # Dimetrodon - size(imgf0) == size(imgf1) || - throw(ArgumentError("non-matching dimensions")) - - m = 2 - #mesh = init_grid(imgf0; type=:vertex) - mesh = init_grid(imgf0, (size(imgf0) .÷ 16)...) 
- #mesh = init_grid(imgf0) - - # optflow specific stuff - Vg = FeSpace(mesh, P1(), (1,)) - f0 = FeFunction(Vg, name="f0") - f1 = FeFunction(Vg, name="f1") - fw = FeFunction(Vg, name="fw") - - T(tdata, u) = tdata * u # tdata = nablafw - S(u, nablau) = nablau - #S(u, nablau) = u - - st = L1L2TVContext("run", mesh, m; T, tdata = nabla(fw), S, - ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta, - ctx.params.gamma1, ctx.params.gamma2) - - function warp!() - imgfw = warp_backwards(imgf1, sample(st.u)) - project_img!(fw, imgfw) - - # replace new tdata - st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fw), S, - st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2, - st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2) - - g_optflow(x; u, f0, fw, nablafw) = - nablafw * u - (fw - f0) - interpolate!(st.g, g_optflow; st.u, f0, fw, nablafw = st.tdata) - end - - function reproject!() - project_img2!(f0, imgf0) - project_img2!(f1, imgf1) - end - reproject!() - warp!() - - save_step(i) = - output(st, joinpath(ctx.outdir, "output_$(lpad(i, 5, '0')).vtu"), - st.g, st.u, st.p1, st.p2, st.est, f0, f1, fw) - - i = 0 - pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd - pvd[i] = save_step(i) - while true - for j in 1:4 - norm_g_old = norm_l2(st.g) - for k in 1:5 - i += 1 - step!(st) - estimate!(st) - pvd[i] = save_step(i) - println() - end - warp!() - norm_g = norm_l2(st.g) - i += 1 - pvd[i] = save_step(i) - display(plot(colorflow(to_img(sample(st.u)); maxflow))) - end - i >= 50 && break - #continue - - marked_cells = mark(st; theta = 0.5) - #marked_cells = Set(axes(mesh.cells, 2)) - - println("refining ...") - - mesh, fs = refine(mesh, marked_cells; - st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2, - f0, f1, fw) - st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fs.fw), S, - st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2, - fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2) - f0, f1, fw = 
(fs.f0, fs.f1, fs.fw) - i += 1 - pvd[i] = save_step(i) - - println("reprojecting ...") - reproject!() - i += 1 - pvd[i] = save_step(i) - end - end - display(plot(colorflow(to_img(sample(st.u)); maxflow))) - - #CSV.write(joinpath(ctx.outdir, "energies.csv"), df) - #saveimg(joinpath(ctx.outdir, "output_glob.png"), fetch_u(states.glob)) - #savedata(ctx, "data.tex"; lambda=λ, beta=β, tau=τ, maxiters, energymin, - # width=size(fo, 2), height=size(fo, 1)) - return st -end - -function experiment_optflow_middlebury(ctx) - imgf0 = loadimg(joinpath(ctx.indir, "frame10.png")) - imgf1 = loadimg(joinpath(ctx.indir, "frame11.png")) - - ctx = Util.Context(ctx; imgf0, imgf1) - return optflow(ctx) -end +ctx(experiment_pd_comparison, "pd-comparison") diff --git a/scripts/run_experiments.jl b/scripts/run_experiments.jl new file mode 100644 index 0000000000000000000000000000000000000000..dfc7dbd4e0dc63def33f71add614e446177974ee --- /dev/null +++ b/scripts/run_experiments.jl @@ -0,0 +1,812 @@ +using LinearAlgebra: norm, dot + +using Colors: Gray +# avoid world-age-issues by preloading ColorTypes +import ColorTypes +import CSV +import DataFrames: DataFrame +import FileIO +using OpticalFlowUtils +using WriteVTK: paraview_collection +using Plots + +using SemiSmoothNewton +using SemiSmoothNewton: HMesh, ncells, refine, area +using SemiSmoothNewton: project_img!, project_img2!, project! 
+using SemiSmoothNewton: vtk_mesh, vtk_append!, vtk_save + +include("util.jl") +isdefined(Main, :Revise) && Revise.track(joinpath(@__DIR__, "util.jl")) + +using .Util + +_toimg(arr) = Gray.(clamp.(arr, 0., 1.)) +#loadimg(x) = reverse(transpose(Float64.(FileIO.load(x))); dims = 2) +#saveimg(io, x) = FileIO.save(io, transpose(reverse(_toimg(x); dims = 2))) +loadimg(x) = Float64.(FileIO.load(x)) +saveimg(io, x) = FileIO.save(io, _toimg(x)) + +from_img(arr) = permutedims(reverse(arr; dims = 1)) +to_img(arr) = permutedims(reverse(arr; dims = 2)) +function to_img(arr::AbstractArray{<:Any,3}) + out = permutedims(reverse(arr; dims = (1,3)), (1, 3, 2)) + out[1, :, :] .= .-out[1, :, :] + return out +end + + +""" + logfilter(dt; a=20) + +Reduce dataset such that the resulting dataset would have approximately +equidistant datapoints on a log-scale. `a` controls density. +""" +logfilter(dt; a=20) = filter(:k => k -> round(a*log(k+1)) - round(a*log(k)) > 0, dt) + +struct L1L2TVContext{M, Ttype, Stype} + name::String + mesh::M + d::Int # = ndims_domain(mesh) + m::Int + + T::Ttype + tdata + S::Stype + + alpha1::Float64 + alpha2::Float64 + beta::Float64 + lambda::Float64 + gamma1::Float64 + gamma2::Float64 + + est::FeFunction + g::FeFunction + u::FeFunction + p1::FeFunction + p2::FeFunction + du::FeFunction + dp1::FeFunction + dp2::FeFunction +end + +function L1L2TVContext(name, mesh, m; T, tdata, S, + alpha1, alpha2, beta, lambda, gamma1, gamma2) + d = ndims_domain(mesh) + + Vest = FeSpace(mesh, DP0(), (1,)) + Vg = FeSpace(mesh, P1(), (1,)) + Vu = FeSpace(mesh, P1(), (m,)) + Vp1 = FeSpace(mesh, DP0(), (1,)) + Vp2 = FeSpace(mesh, DP1(), (m, d)) + + est = FeFunction(Vest, name="est") + g = FeFunction(Vg, name="g") + u = FeFunction(Vu, name="u") + p1 = FeFunction(Vp1, name="p1") + p2 = FeFunction(Vp2, name="p2") + du = FeFunction(Vu; name = "du") + dp1 = FeFunction(Vp1; name = "dp1") + dp2 = FeFunction(Vp2; name = "dp2") + + est.data .= 0 + g.data .= 0 + u.data .= 0 + p1.data .= 0 + 
p2.data .= 0 + du.data .= 0 + dp1.data .= 0 + dp2.data .= 0 + + return L1L2TVContext(name, mesh, d, m, T, tdata, S, + alpha1, alpha2, beta, lambda, gamma1, gamma2, + est, g, u, p1, p2, du, dp1, dp2) +end + +function p1_project!(p1, alpha1) + p1.space.element == DP0() || p1.space.element == P1() || + p1.space.element == DP1() || + throw(ArgumentError("element unsupported")) + p1.data .= clamp.(p1.data, -alpha1, alpha1) +end + +function p2_project!(p2, lambda) + p2.space.element::DP1 + p2d = reshape(p2.data, prod(p2.space.size), :) # no copy + for i in axes(p2d, 2) + p2in = norm(p2d[:, i]) + if p2in > lambda + p2d[:, i] .*= lambda ./ p2in + end + end +end + +function step!(ctx::L1L2TVContext) + T = ctx.T + S = ctx.S + alpha1 = ctx.alpha1 + alpha2 = ctx.alpha2 + beta = ctx.beta + lambda = ctx.lambda + gamma1 = ctx.gamma1 + gamma2 = ctx.gamma2 + + function du_a(x_, du, nabladu, phi, nablaphi; g, u, nablau, p1, p2, tdata) + m1 = max(gamma1, norm(T(tdata, u) - g)) + cond1 = norm(T(tdata, u) - g) > gamma1 ? + dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : + zero(p1) + a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) - + dot(cond1, T(tdata, phi)) + + m2 = max(gamma2, norm(nablau)) + cond2 = norm(nablau) > gamma2 ? + dot(nablau, nabladu) / norm(nablau)^2 * p2 : + zero(p2) + a2 = lambda / m2 * dot(nabladu, nablaphi) - + dot(cond2, nablaphi) + + aB = alpha2 * dot(T(tdata, du), T(tdata, phi)) + + beta * dot(S(du, nabladu), S(phi, nablaphi)) + + return a1 + a2 + aB + end + + function du_l(x_, phi, nablaphi; g, u, nablau, p1, p2, tdata) + aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) + + beta * dot(S(u, nablau), S(phi, nablaphi)) + m1 = max(gamma1, norm(T(tdata, u) - g)) + p1part = alpha1 / m1 * dot(T(tdata, u) - g, T(tdata, phi)) + m2 = max(gamma2, norm(nablau)) + p2part = lambda / m2 * dot(nablau, nablaphi) + gpart = alpha2 * dot(g, T(tdata, phi)) + + return -aB - p1part - p2part + gpart + end + + # solve du + print("assemble ... 
") + A, b = assemble(ctx.du.space, du_a, du_l; + ctx.g, ctx.u, nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata) + print("solve ... ") + ctx.du.data .= A \ b + + + # solve dp1 + function dp1_update(x_; g, u, p1, du, tdata) + m1 = max(gamma1, norm(T(tdata, u) - g)) + cond = norm(T(tdata, u) - g) > gamma1 ? + dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : + zero(p1) + return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond + end + interpolate!(ctx.dp1, dp1_update; + ctx.g, ctx.u, ctx.p1, ctx.du, ctx.tdata) + + # solve dp2 + function dp2_update(x_; u, nablau, p2, du, nabladu) + m2 = max(gamma2, norm(nablau)) + cond = norm(nablau) > gamma2 ? + dot(nablau, nabladu) / norm(nablau)^2 * p2 : + zero(p2) + return -p2 + lambda / m2 * (nablau + nabladu) - cond + end + interpolate!(ctx.dp2, dp2_update; + ctx.u, nablau = nabla(ctx.u), ctx.p2, ctx.du, nabladu = nabla(ctx.du)) + + # newton update + theta = 1. + ctx.u.data .+= theta * ctx.du.data + ctx.p1.data .+= theta * ctx.dp1.data + ctx.p2.data .+= theta * ctx.dp2.data + + # reproject p1, p2 + p1_project!(ctx.p1, ctx.alpha1) + p2_project!(ctx.p2, ctx.lambda) +end + +" +2010, Chambolle and Pock: primal-dual semi-implicit algorithm +2017, Alkämper and Langer: fem dualisation +" +function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.) + # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1! + # changed alpha2 -> alpha2 / 2 + # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad| + if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0. + error("unsupported parameters") + end + beta = tau * ctx.alpha1 / (1 + tau * ctx.alpha2) + + # u is P1 + # p2 is essentially DP0 (technically may be DP1) + + # 1. 
y update + u_new = FeFunction(ctx.u.space) # x_bar only used here + u_new.data .= ctx.u.data .+ theta .* ctx.du.data + + p2_next = FeFunction(ctx.p2.space) + + function p2_update(x_; p2, nablau) + return p2 + sigma * nablau + end + interpolate!(p2_next, p2_update; ctx.p2, nablau = nabla(u_new)) + + p2_project!(p2_next, ctx.lambda) + ctx.dp2.data .= p2_next.data .- ctx.p2.data + + ctx.p2.data .+= ctx.dp2.data + + # 2. + u_a(x, z, nablaz, phi, nablaphi; g, u, p2) = + dot(z, phi) + + u_l(x, phi, nablaphi; u, g, p2) = + (dot(u + tau * ctx.alpha2 * g, phi) + tau * dot(p2, -nablaphi)) / + (1 + tau * ctx.alpha2) + + # z = 1 / (1 + tau * alpha2) * + # (u + tau * alpha2 * g + tau * div(p)) + z = FeFunction(ctx.u.space) + A, b = assemble(z.space, u_a, u_l; ctx.g, ctx.u, ctx.p2) + z.data .= A \ b + + function u_update!(u, z, g, beta) + u.space.element::P1 + g.space.element::P1 + for i in eachindex(u.data) + if z.data[i] - beta >= g.data[i] + u.data[i] = z.data[i] - beta + elseif z.data[i] + beta <= g.data[i] + u.data[i] = z.data[i] + beta + else + u.data[i] = g.data[i] + end + end + end + u_update!(u_new, z, ctx.g, beta) + + # 3. + # note: step-size control not implemented, since \nabla G^* is not 1/gamma continuous + ctx.du.data .= u_new.data .- ctx.u.data + ctx.u.data .= u_new.data + + #ctx.du.data .= z.data + any(isnan, ctx.u.data) && error("encountered nan data") + any(isnan, ctx.p2.data) && error("encountered nan data") + + return ctx +end + +" +2010, Chambolle and Pock: accelerated primal-dual semi-implicit algorithm +" +function step_pd2!(ctx::L1L2TVContext; sigma, tau, theta = 1.) + # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad| + + # u is P1 + # p2 is essentially DP0 (technically may be DP1) + + # 1. 
y update + u_new = FeFunction(ctx.u.space) # x_bar only used here + u_new.data .= ctx.u.data .+ theta .* ctx.du.data + + p1_next = FeFunction(ctx.p1.space) + p2_next = FeFunction(ctx.p2.space) + + function p1_update(x_; p1, g, u, tdata) + ctx.alpha1 == 0 && return zero(p1) + return (p1 + sigma * (ctx.T(tdata, u) - g)) / + (1 + ctx.gamma1 * sigma / ctx.alpha1) + end + interpolate!(p1_next, p1_update; ctx.p1, ctx.g, ctx.u, ctx.tdata) + + function p2_update(x_; p2, nablau) + ctx.lambda == 0 && return zero(p2) + return (p2 + sigma * nablau) / + (1 + ctx.gamma2 * sigma / ctx.lambda) + end + interpolate!(p2_next, p2_update; ctx.p2, nablau = nabla(u_new)) + + # reproject p1, p2 + p1_project!(p1_next, ctx.alpha1) + p2_project!(p2_next, ctx.lambda) + + ctx.dp1.data .= p1_next.data .- ctx.p1.data + ctx.dp2.data .= p2_next.data .- ctx.p2.data + ctx.p1.data .+= ctx.dp1.data + ctx.p2.data .+= ctx.dp2.data + + # 2. x update + u_a(x, w, nablaw, phi, nablaphi; g, u, p1, p2, tdata) = + dot(w, phi) + + tau * ctx.alpha2 * dot(ctx.T(tdata, w), ctx.T(tdata, phi)) + + tau * ctx.beta * dot(ctx.S(w, nablaw), ctx.S(phi, nablaphi)) + + u_l(x, phi, nablaphi; g, u, p1, p2, tdata) = + dot(u, phi) - tau * ( + dot(p1, ctx.T(tdata, phi)) + + dot(p2, nablaphi) - + ctx.alpha2 * dot(g, ctx.T(tdata, phi))) + + A, b = assemble(u_new.space, u_a, u_l; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.tdata) + u_new.data .= A \ b + ctx.du.data .= u_new.data .- ctx.u.data + ctx.u.data .= u_new.data + + #ctx.du.data .= z.data + any(isnan, ctx.u.data) && error("encountered nan data") + any(isnan, ctx.p1.data) && error("encountered nan data") + any(isnan, ctx.p2.data) && error("encountered nan data") + + return ctx +end + +" +2004, Chambolle: dual semi-implicit algorithm +" +function step_d!(ctx::L1L2TVContext; tau) + # u is P1 + # p2 is essentially DP0 (technically may be DP1) + + # TODO: this might not be implementable without higher order elements + # need grad(div(p)) + return ctx +end + +function 
"""
    solve_primal!(u::FeFunction, ctx::L1L2TVContext)

Assemble the linear system built from `ctx.alpha2`, `ctx.beta`, `ctx.T`,
`ctx.S` and the current `ctx.g`, `ctx.p1`, `ctx.p2`, `ctx.tdata` on `u.space`,
solve it and store the solution in `u.data`.
"""
function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
    function lhs(x, w, nablaw, phi, nablaphi; g, p1, p2, tdata)
        return ctx.alpha2 * dot(ctx.T(tdata, w), ctx.T(tdata, phi)) +
            ctx.beta * dot(ctx.S(w, nablaw), ctx.S(phi, nablaphi))
    end

    function rhs(x, phi, nablaphi; g, p1, p2, tdata)
        Tphi = ctx.T(tdata, phi)
        return -dot(p1, Tphi) - dot(p2, nablaphi) +
            ctx.alpha2 * dot(g, Tphi)
    end

    # u = B^{-1} * (T^* p_1 - div p_2 - alpha2 * T^* g)
    K, load = assemble(u.space, lhs, rhs; ctx.g, ctx.p1, ctx.p2, ctx.tdata)
    u.data .= K \ load
end

"""
    huber(x, gamma)

Return `x^2 / (2 * gamma)` when `abs(x) < gamma` and `abs(x) - gamma / 2`
otherwise.
"""
function huber(x, gamma)
    magnitude = abs(x)
    if magnitude < gamma
        return x^2 / (2 * gamma)
    end
    return magnitude - gamma / 2
end

"""
    estimate!(ctx::L1L2TVContext)

Project the cell-wise estimator onto `ctx.est`.

The estimator sums an `alpha1` part, a `lambda` part (each skipped when its
weight is zero) and a quadratic part measuring the distance between `ctx.u`
and the solution `w` of [`solve_primal!`](@ref) for the current dual variables.
"""
function estimate!(ctx::L1L2TVContext)
    # FIXME: sign?
    function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
        alpha1part = if iszero(ctx.alpha1)
            0.
        else
            residual = ctx.T(tdata, u) - g
            q1 = p1 / ctx.alpha1
            ctx.alpha1 * (
                huber(norm(residual), ctx.gamma1) -
                dot(residual, q1) +
                ctx.gamma1 / 2 * norm(q1)^2)
        end

        lambdapart = if iszero(ctx.lambda)
            0.
        else
            q2 = p2 / ctx.lambda
            ctx.lambda * (
                huber(norm(nablau), ctx.gamma2) -
                dot(nablau, q2) +
                ctx.gamma2 / 2 * norm(q2)^2)
        end

        Tdiff = ctx.T(tdata, w - u)
        Sdiff = ctx.S(w, nablaw) - ctx.S(u, nablau)
        bpart = 1 / 2 * (
            ctx.alpha2 * dot(Tdiff, Tdiff) +
            ctx.beta * dot(Sdiff, Sdiff))

        return alpha1part + lambdapart + bpart
    end

    w = FeFunction(ctx.u.space)
    solve_primal!(w, ctx)
    project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2,
        nablau = nabla(ctx.u), w, nablaw = nabla(w), ctx.tdata)
end
"""
    mark(ctx::L1L2TVContext; theta=0.5)

Minimal Dörfler marking.

Visit the entries of `ctx.est.data` in descending order and collect their
indices until the accumulated value reaches `theta` times the total. Returns
the collected indices as a `Vector{Int}` (empty when the total is zero).
"""
function mark(ctx::L1L2TVContext; theta=0.5)
    esttotal = sum(ctx.est.data)

    # (cell => indicator) pairs, largest indicator first
    ranked = sort(collect(pairs(ctx.est.data)); by = last, lt = >)

    marked_cells = Int[]
    estacc = 0.
    for (cell, indicator) in ranked
        # stop as soon as the marked cells carry the requested share
        estacc >= theta * esttotal && break
        push!(marked_cells, cell)
        estacc += indicator
    end
    return marked_cells
end
"""
    output(ctx::L1L2TVContext, filename, fs...)

Write the functions `fs` on `ctx.mesh` to the VTK file `filename` and return
the saved VTK handle. Prints a progress note without a trailing newline.
"""
function output(ctx::L1L2TVContext, filename, fs...)
    print("save \"$filename\" ... ")
    handle = vtk_mesh(filename, ctx.mesh)
    vtk_append!(handle, fs...)
    vtk_save(handle)
    return handle
end

"""
    primal_energy(ctx::L1L2TVContext)

Integrate the primal energy density over `ctx.mesh`: Huber-smoothed `alpha1`
and `lambda` terms plus quadratic `alpha2` and `beta` terms, evaluated at
`ctx.u` with data `ctx.g`.
"""
function primal_energy(ctx::L1L2TVContext)
    function density(x; g, u, nablau, tdata)
        misfit = norm(ctx.T(tdata, u) - g)
        return ctx.alpha1 * huber(misfit, ctx.gamma1) +
            ctx.alpha2 / 2 * misfit^2 +
            ctx.beta / 2 * norm(ctx.S(u, nablau))^2 +
            ctx.lambda * huber(norm(nablau), ctx.gamma2)
    end
    return integrate(ctx.mesh, density; ctx.g, ctx.u,
        nablau = nabla(ctx.u), ctx.tdata)
end

"""
    norm_l2(f)

Square root of the integral of `dot(f, f)` over the mesh of `f`.
"""
function norm_l2(f)
    squared(x; f) = dot(f, f)
    return sqrt(integrate(f.space.mesh, squared; f))
end

"""
    norm_step(ctx::L1L2TVContext)

Combined L2 norm of the last increments `ctx.du`, `ctx.dp1`, `ctx.dp2`,
with the sum of squares divided by `area(ctx.mesh)` before the square root.
"""
function norm_step(ctx::L1L2TVContext)
    du2 = norm_l2(ctx.du)^2
    dp1_2 = norm_l2(ctx.dp1)^2
    dp2_2 = norm_l2(ctx.dp2)^2
    return sqrt((du2 + dp1_2 + dp2_2) / area(ctx.mesh))
end

"""
    norm_residual(ctx::L1L2TVContext)

Residual norm of the current iterate: the squared L2 distance between `ctx.u`
and the [`solve_primal!`](@ref) solution, plus the integrated squared
residuals of the `p1` and `p2` relations, under a square root.
"""
function norm_residual(ctx::L1L2TVContext)
    # primal part: distance of u to the primal solve for the current duals
    diff = FeFunction(ctx.u.space)
    solve_primal!(diff, ctx)
    diff.data .-= ctx.u.data
    upart2 = norm_l2(diff)^2

    function dual_residual(x; g, u, nablau, p1, p2, tdata)
        misfit = ctx.T(tdata, u) - g
        p1part = p1 * max(ctx.gamma1, norm(misfit)) - ctx.alpha1 * misfit
        p2part = p2 * max(ctx.gamma2, norm(nablau)) - ctx.lambda * nablau
        return norm(p1part)^2 + norm(p2part)^2
    end
    ppart2 = integrate(ctx.mesh, dual_residual; ctx.g, ctx.u,
        nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata)

    return sqrt(upart2 + ppart2)
end
"""
    denoise(img; name, params...)

Adaptive denoising of `img` with `step!` on a `5 x 5` initial grid.

Iterates until `norm_step(ctx) <= 1e-1`, then marks cells with
`mark(ctx; theta = 0.5)`, refines the mesh and re-projects the image; this is
repeated until at least 100 steps have been taken. Every step is written to
`output/<name>_<k>.vtu` and collected in `output/<name>.pvd`.

# Keywords
- `name`: run name, used for the output file names.
- `params...`: forwarded to the `L1L2TVContext` keyword constructor.

Returns the final `L1L2TVContext`.
"""
function denoise(img; name, params...)
    m = 1
    img = from_img(img) # coord flip
    mesh = init_grid(img, 5, 5)

    T(tdata, u) = u
    S(u, nablau) = u

    ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)

    project_img!(ctx.g, img)

    save_denoise(ctx, i) =
        output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
            ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)

    pvd = paraview_collection("output/$(ctx.name).pvd")
    try
        pvd[0] = save_denoise(ctx, 0)

        k = 0
        println("primal energy: $(primal_energy(ctx))")
        while true
            # iterate on the current mesh until the step becomes small
            while true
                k += 1
                step!(ctx)
                estimate!(ctx)
                pvd[k] = save_denoise(ctx, k)
                println()

                norm_step_ = norm_step(ctx)
                norm_residual_ = norm_residual(ctx)

                println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est))")
                println("primal energy: $(primal_energy(ctx))")
                println("norm_step: $(norm_step_)")
                println("norm_residual: $(norm_residual_)")

                norm_step_ <= 1e-1 && break
            end
            marked_cells = mark(ctx; theta = 0.5)
            println("refining ...")

            # Refine through the mesh-level API and rebuild the context around
            # the refined functions (same pattern as in `optflow`).
            mesh, fs = refine(ctx.mesh, marked_cells;
                ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2)
            ctx = L1L2TVContext(ctx.name, mesh, ctx.d, ctx.m, T, ctx.tdata, S,
                ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda,
                ctx.gamma1, ctx.gamma2,
                fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
            test_mesh(ctx.mesh)

            project_img!(ctx.g, img)

            k >= 100 && break
        end
    finally
        # write the collection even when an iteration fails
        vtk_save(pvd)
    end
    return ctx
end
"""
    denoise_pd(ctx, img; df=nothing, name, algorithm, params_...)

Denoise `img` on a vertex grid for `params_.max_iters` iterations with one of
three algorithms and log the convergence history.

# Arguments
- `ctx`: accepted for interface compatibility; it is not read.
- `img`: input image; flipped with `from_img` before use.

# Keywords
- `df`: `DataFrame` receiving one row `(k, primal_energy, norm_step,
  norm_residual)` per iteration, or `nothing` to disable logging.
- `name`: name given to the solver context.
- `algorithm`: `:pd1` (`step_pd2!` with fixed step sizes), `:pd2`
  (`step_pd2!` with the accelerated step-size update) or `:newton` (`step!`).
- `params_...`: must provide `alpha1`, `alpha2`, `lambda`, `beta`, `gamma1`,
  `gamma2` and `max_iters`.

Returns the solver `L1L2TVContext`.

# Throws
- `ArgumentError` if `algorithm` is not one of the three supported symbols.
"""
function denoise_pd(ctx, img; df=nothing, name, algorithm, params_...)
    # fail early: an unknown algorithm used to loop without taking any step
    algorithm in (:pd1, :pd2, :newton) ||
        throw(ArgumentError("unsupported algorithm: $(repr(algorithm))"))

    params = NamedTuple(params_)
    m = 1
    img = from_img(img) # coord flip
    mesh = init_grid(img; type=:vertex)

    T(tdata, u) = u
    S(u, nablau) = u

    # solver state; kept separate from the (unused) `ctx` argument
    st = L1L2TVContext(name, mesh, m;
        T, tdata = nothing, S,
        params.alpha1, params.alpha2, params.lambda, params.beta,
        params.gamma1, params.gamma2)

    # semi-implicit primal dual parameters
    gamma = st.alpha2 + st.beta # T = I, S = I
    gamma /= 100 # kind of arbitrary?

    tau = 1e-1
    L = 100
    sigma = inv(tau * L^2)
    theta = 1.

    interpolate!(st.g, x -> interpolate_bilinear(img, x))
    st.u.data .= st.g.data

    log!(x::Nothing; kwargs...) = x
    function log!(df::DataFrame; k, norm_step, norm_residual)
        push!(df, (;
            k,
            primal_energy = primal_energy(st),
            norm_step,
            norm_residual))
        println(NamedTuple(last(df)))
    end

    # the mesh is never refined here, so the scaling is loop-invariant
    domain_factor = 1 / sqrt(area(mesh))

    k = 0
    println("primal energy: $(primal_energy(st))")

    while true
        k += 1
        if algorithm == :pd1
            # no step size control
            step_pd2!(st; sigma, tau, theta)
        elseif algorithm == :pd2
            theta = 1 / sqrt(1 + 2 * gamma * tau)
            tau *= theta
            sigma /= theta
            step_pd2!(st; sigma, tau, theta)
        else # :newton
            step!(st)
        end

        norm_step_ = norm_step(st) * domain_factor
        norm_residual_ = norm_residual(st) * domain_factor

        log!(df; k, norm_step = norm_step_, norm_residual = norm_residual_)

        k >= params.max_iters && break
    end
    return st
end


"""
    experiment_pd_comparison(ctx)

Run `denoise_pd` on `input.png` from `ctx.indir` with the algorithms `:pd1`,
`:pd2` and `:newton` and write the log-filtered histories to
`semiimplicit.csv`, `semiimplicit-accelerated.csv` and `newton.csv` in
`ctx.outdir`.
"""
function experiment_pd_comparison(ctx)
    img = loadimg(joinpath(ctx.indir, "input.png"))
    # NOTE(review): `denoise_pd` applies `from_img` again, so the image is
    # flipped twice before it is used — confirm this is intended.
    img = from_img(img) # coord flip

    algparams = (
        alpha1=0., alpha2=30., lambda=1., beta=0.,
        gamma1=1e-3, gamma2=1e-3,
        tol = 1e-6, max_iters = 50,
    )

    df1 = DataFrame()
    df2 = DataFrame()
    df3 = DataFrame()

    denoise_pd(ctx, img; name="test", algorithm=:pd1, df = df1, algparams...)
    denoise_pd(ctx, img; name="test", algorithm=:pd2, df = df2, algparams...)
    denoise_pd(ctx, img; name="test", algorithm=:newton, df = df3, algparams...)

    CSV.write(joinpath(ctx.outdir, "semiimplicit.csv"), logfilter(df1))
    CSV.write(joinpath(ctx.outdir, "semiimplicit-accelerated.csv"), logfilter(df2))
    CSV.write(joinpath(ctx.outdir, "newton.csv"), logfilter(df3))
end
"""
    inpaint(img, imgmask; name, params...)

Run three `step!` / `estimate!` iterations of the inpainting problem on a
vertex grid, writing every state to `output/<name>_<i>.vtu` and the collection
`output/<name>.pvd`.

# Arguments
- `img`: input image.
- `imgmask`: mask of the same size as `img`; the operator `T` passes `u`
  through where the interpolated mask equals one and returns zero elsewhere.

# Keywords
- `name`: run name, used for the output file names.
- `params...`: forwarded to the `L1L2TVContext` keyword constructor.

Returns the `L1L2TVContext`.

# Throws
- `ArgumentError` if `img` and `imgmask` differ in size.
"""
function inpaint(img, imgmask; name, params...)
    size(img) == size(imgmask) ||
        throw(ArgumentError("non-matching dimensions"))

    m = 1
    img = from_img(img) # coord flip
    imgmask = from_img(imgmask) # coord flip
    mesh = init_grid(img; type=:vertex)

    # inpaint specific stuff
    Vg = FeSpace(mesh, P1(), (1,))
    mask = FeFunction(Vg, name="mask")

    T(tdata, u) = isone(tdata[begin]) ? u : zero(u)
    S(u, nablau) = u

    ctx = L1L2TVContext(name, mesh, m; T, tdata = mask, S, params...)

    # FIXME: currently dual grid only
    interpolate!(mask, x -> imgmask[round.(Int, x)...])
    interpolate!(ctx.g, x -> imgmask[round.(Int, x)...] ? img[round.(Int, x)...] : 0.)
    # NOTE(review): the masked image written to ctx.g above is overwritten
    # right away by a synthetic disc around the image centre — confirm whether
    # this is a leftover from debugging.
    center = (size(img) .- 1) ./ 2 .+ 1
    interpolate!(ctx.g, x -> norm(x .- center) < norm(center) / 3)

    save_inpaint(i) =
        output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
            ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask)

    # do-block form saves the collection on exit; it was never saved before
    paraview_collection("output/$(ctx.name).pvd") do pvd
        pvd[0] = save_inpaint(0)
        for i in 1:3
            step!(ctx)
            estimate!(ctx)
            pvd[i] = save_inpaint(i)
            println()
        end
    end
    return ctx
end
"""
    optflow(ctx)

Estimate the optical flow between the frames `ctx.params.imgf0` and
`ctx.params.imgf1` with adaptive mesh refinement.

Each outer pass performs four rounds of five `step!` / `estimate!` iterations
followed by a warp, then (until at least 50 outputs have been written) marks,
refines and re-projects. All states are written to `output_<i>.vtu` and
`output.pvd` in `ctx.outdir`, and the colour-coded flow is displayed after
every warp.

Solver parameters `alpha1`, `alpha2`, `lambda`, `beta`, `gamma1`, `gamma2` are
read from `ctx.params`. Returns the final `L1L2TVContext`.

# Throws
- `ArgumentError` if the two frames differ in size.
"""
function optflow(ctx)
    # coord flip
    imgf0 = from_img(ctx.params.imgf0)
    imgf1 = from_img(ctx.params.imgf1)
    maxflow = 4.6157 # Dimetrodon
    size(imgf0) == size(imgf1) ||
        throw(ArgumentError("non-matching dimensions"))

    m = 2
    mesh = init_grid(imgf0, (size(imgf0) .÷ 16)...)

    # optflow specific stuff
    Vg = FeSpace(mesh, P1(), (1,))
    f0 = FeFunction(Vg, name="f0")
    f1 = FeFunction(Vg, name="f1")
    fw = FeFunction(Vg, name="fw")

    T(tdata, u) = tdata * u # tdata = nablafw
    S(u, nablau) = nablau

    st = L1L2TVContext("run", mesh, m; T, tdata = nabla(fw), S,
        ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta,
        ctx.params.gamma1, ctx.params.gamma2)

    # warp frame 1 back along the current flow and relinearise the data term
    function warp!()
        imgfw = warp_backwards(imgf1, sample(st.u))
        project_img!(fw, imgfw)

        # replace new tdata
        st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fw), S,
            st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
            st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)

        g_optflow(x; u, f0, fw, nablafw) =
            nablafw * u - (fw - f0)
        interpolate!(st.g, g_optflow; st.u, f0, fw, nablafw = st.tdata)
    end

    function reproject!()
        project_img2!(f0, imgf0)
        project_img2!(f1, imgf1)
    end
    reproject!()
    warp!()

    save_step(i) =
        output(st, joinpath(ctx.outdir, "output_$(lpad(i, 5, '0')).vtu"),
            st.g, st.u, st.p1, st.p2, st.est, f0, f1, fw)

    i = 0
    paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd
        pvd[i] = save_step(i)
        while true
            for _ in 1:4
                for _ in 1:5
                    i += 1
                    step!(st)
                    estimate!(st)
                    pvd[i] = save_step(i)
                    println()
                end
                warp!()
                i += 1
                pvd[i] = save_step(i)
                display(plot(colorflow(to_img(sample(st.u)); maxflow)))
            end
            i >= 50 && break

            marked_cells = mark(st; theta = 0.5)

            println("refining ...")

            mesh, fs = refine(mesh, marked_cells;
                st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2,
                f0, f1, fw)
            # rebuild the context around the refined functions
            st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fs.fw), S,
                st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
                fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
            f0, f1, fw = (fs.f0, fs.f1, fs.fw)
            i += 1
            pvd[i] = save_step(i)

            println("reprojecting ...")
            reproject!()
            i += 1
            pvd[i] = save_step(i)
        end
    end
    display(plot(colorflow(to_img(sample(st.u)); maxflow)))

    return st
end

"""
    experiment_optflow_middlebury(ctx)

Load `frame10.png` and `frame11.png` from `ctx.indir`, attach them to a new
`Util.Context` derived from `ctx` and run [`optflow`](@ref) on it.
"""
function experiment_optflow_middlebury(ctx)
    imgf0 = loadimg(joinpath(ctx.indir, "frame10.png"))
    imgf1 = loadimg(joinpath(ctx.indir, "frame11.png"))

    ctx = Util.Context(ctx; imgf0, imgf1)
    return optflow(ctx)
end