Skip to content
Snippets Groups Projects
Select Git revision
  • 21d0ce4b294c6c73691d1c00888b3d1f5f5dbb5d
  • master default protected
  • andreas/paper2
  • v1.0
  • v0.1
5 results

run_experiments.jl

Blame
  • run_experiments.jl 25.95 KiB
    using LinearAlgebra: norm, dot
    
    using Colors: Gray
    # avoid world-age-issues by preloading ColorTypes
    import ColorTypes
    import CSV
    import DataFrames: DataFrame
    import FileIO
    using OpticalFlowUtils
    using WriteVTK: paraview_collection
    using Plots
    
    using SemiSmoothNewton
    using SemiSmoothNewton: HMesh, ncells, refine, area, mesh_size, ndofs
    using SemiSmoothNewton: project_img!, project_img2!, project!
    using SemiSmoothNewton: vtk_mesh, vtk_append!, vtk_save
    
    include("util.jl")
    isdefined(Main, :Revise) && Revise.track(joinpath(@__DIR__, "util.jl"))
    
    using .Util
    
    # Clamp intensities to [0, 1] and wrap them as grayscale pixels.
    _toimg(arr) = @. Gray(clamp(arr, 0., 1.))

    # Load the image file at path `x` as an array of `Float64` intensities.
    # Coordinate flips are handled separately by `from_img`/`to_img`.
    function loadimg(x)
        pixels = FileIO.load(x)
        return Float64.(pixels)
    end

    # Save the intensity array `x` (clamped to [0, 1]) as a grayscale image to `io`.
    function saveimg(io, x)
        return FileIO.save(io, _toimg(x))
    end
    
    # Image layout -> mesh layout: reverse the first axis, then swap the axes.
    # For matrices this is the inverse of `to_img`.
    function from_img(arr)
        flipped = reverse(arr; dims = 1)
        return permutedims(flipped)
    end

    # Mesh layout -> image layout: reverse the second axis, then swap the axes.
    function to_img(arr)
        flipped = reverse(arr; dims = 2)
        return permutedims(flipped)
    end

    # Variant for 3-d arrays (presumably vector fields with the components along
    # the first axis — TODO confirm): the spatial axes are flipped like the
    # matrix method, the component axis is reversed, and the new first component
    # is negated.
    function to_img(arr::AbstractArray{<:Any,3})
        flipped = reverse(arr; dims = (1, 3))
        out = permutedims(flipped, (1, 3, 2))
        @views out[1, :, :] .= .-out[1, :, :]
        return out
    end
    
    
    """
        logfilter(dt; a=20)

    Thin out the rows of the table `dt` (filtered on its `:k` column) so that
    the kept values of `k` are approximately equidistant on a log scale.
    A row is kept when `round(a * log(k + 1))` exceeds `round(a * log(k))`,
    so a larger `a` keeps more rows.
    """
    function logfilter(dt; a=20)
        keeprow(k) = round(a * log(k + 1)) > round(a * log(k))
        return filter(:k => keeprow, dt)
    end
    
    """
        L1L2TVContext{M, Ttype, Stype}

    Solver state for the functional evaluated by `primal_energy`:

        alpha1 * huber(|T(tdata, u) - g|, gamma1) + alpha2 / 2 * |T(tdata, u) - g|^2
          + beta / 2 * |S(u, ∇u)|^2 + lambda * huber(|∇u|, gamma2)

    integrated over `mesh`. Holds the model operators and weights together
    with the finite element functions of the current iterate.

    NOTE: the default positional constructor is called with all fields in
    declaration order elsewhere in this file, so do not reorder the fields.
    """
    struct L1L2TVContext{M, Ttype, Stype}
        name::String
        mesh::M
        d::Int # = ndims_domain(mesh)
        m::Int # number of components of u
    
        T::Ttype # data operator, called as T(tdata, u)
        tdata # extra data passed to T (e.g. `nothing`, an inpainting mask, or nabla(fw) for optical flow)
        S::Stype # operator of the beta term, called as S(u, nablau)
    
        alpha1::Float64 # weight of the Huber (gamma1) data term
        alpha2::Float64 # weight of the quadratic data term
        beta::Float64 # weight of the quadratic S term
        lambda::Float64 # weight of the Huber (gamma2) gradient term
        gamma1::Float64 # Huber parameter of the alpha1 term
        gamma2::Float64 # Huber parameter of the lambda term
    
        est::FeFunction # cell-wise error estimator (filled by estimate!)
        g::FeFunction # data
        u::FeFunction # primal iterate
        p1::FeFunction # dual variable of the alpha1 term (projected onto [-alpha1, alpha1])
        p2::FeFunction # dual variable of the lambda term (projected onto |p2| <= lambda)
        du::FeFunction # last update of u
        dp1::FeFunction # last update of p1
        dp2::FeFunction # last update of p2
    end
    
    # Keyword constructor: builds the finite element spaces on `mesh`
    # (u: P1 with `m` components, g: P1 scalar, p1: DP1 scalar,
    # p2: DP0 with m x d components, est: DP0 scalar) and zero-initialises all
    # functions. The updates du, dp1, dp2 live in the spaces of u, p1, p2.
    function L1L2TVContext(name, mesh, m; T, tdata, S,
            alpha1, alpha2, beta, lambda, gamma1, gamma2)
        d = ndims_domain(mesh)

        space_est = FeSpace(mesh, DP0(), (1,))
        space_g = FeSpace(mesh, P1(), (1,))
        space_u = FeSpace(mesh, P1(), (m,))
        space_p1 = FeSpace(mesh, DP1(), (1,))
        space_p2 = FeSpace(mesh, DP0(), (m, d))

        fields = (
            est = FeFunction(space_est; name = "est"),
            g = FeFunction(space_g; name = "g"),
            u = FeFunction(space_u; name = "u"),
            p1 = FeFunction(space_p1; name = "p1"),
            p2 = FeFunction(space_p2; name = "p2"),
            du = FeFunction(space_u; name = "du"),
            dp1 = FeFunction(space_p1; name = "dp1"),
            dp2 = FeFunction(space_p2; name = "dp2"),
        )
        for f in fields
            fill!(f.data, 0)
        end

        return L1L2TVContext(name, mesh, d, m, T, tdata, S,
            alpha1, alpha2, beta, lambda, gamma1, gamma2,
            fields.est, fields.g, fields.u, fields.p1, fields.p2,
            fields.du, fields.dp1, fields.dp2)
    end
    
    # Project the coefficients of `p1` entrywise onto [-alpha1, alpha1], in place.
    # Only DP0, P1 and DP1 elements are accepted (ArgumentError otherwise).
    # Returns `p1.data`.
    function p1_project!(p1, alpha1)
        elem = p1.space.element
        if !(elem == DP0() || elem == P1() || elem == DP1())
            throw(ArgumentError("element unsupported"))
        end
        return clamp!(p1.data, -alpha1, alpha1)
    end
    
    """
        p2_project!(p2, lambda)

    Project `p2` in place, degree of freedom by degree of freedom: each block of
    `prod(p2.space.size)` consecutive coefficients is treated as one vector and
    scaled back onto the Euclidean ball of radius `lambda` if it lies outside.
    Only DP0, DP1 and P1 elements are accepted (`ArgumentError` otherwise).
    Returns `p2`.
    """
    function p2_project!(p2, lambda)
        p2.space.element == DP0() || p2.space.element == DP1() ||
            p2.space.element == P1() ||
            throw(ArgumentError("element unsupported"))
        # (components, dofs) view of the coefficient vector, no copy
        p2d = reshape(p2.data, prod(p2.space.size), :)
        # `eachcol` yields views, so neither the norm nor the scaling copies a column
        for col in eachcol(p2d)
            colnorm = norm(col)
            if colnorm > lambda
                col .*= lambda / colnorm
            end
        end
        return p2
    end
    
    """
        step!(ctx::L1L2TVContext)

    Perform one semi-smooth Newton step on `(u, p1, p2)` in place:

    1. assemble and solve the linearised system for `du` (bilinear form
       `du_a`, right-hand side `du_l`);
    2. compute `dp1`, `dp2` pointwise from `du` such that
       `p1 + dp1 = alpha1 / m1 * (T u + T du - g) - cond1` and
       `p2 + dp2 = lambda / m2 * (∇u + ∇du) - cond2`, where
       `m1 = max(gamma1, |T u - g|)` and `m2 = max(gamma2, |∇u|)`;
    3. apply the full update (`theta = 1`), clamp `p1` entrywise to
       `[-alpha1, alpha1]` and project `p2` onto the ball of radius `lambda`.
    """
    function step!(ctx::L1L2TVContext)
        T = ctx.T
        S = ctx.S
        alpha1 = ctx.alpha1
        alpha2 = ctx.alpha2
        beta = ctx.beta
        lambda = ctx.lambda
        gamma1 = ctx.gamma1
        gamma2 = ctx.gamma2
    
        # Bilinear form of the Newton system for du. cond1/cond2 are the
        # derivative contributions of the normalisations 1/m1, 1/m2. They are
        # active only where |T u - g| > gamma1 resp. |∇u| > gamma2 and are
        # expressed through the current duals p1, p2.
        function du_a(x_, du, nabladu, phi, nablaphi; g, u, nablau, p1, p2, tdata)
    	m1 = max(gamma1, norm(T(tdata, u) - g))
    	cond1 = norm(T(tdata, u) - g) > gamma1 ?
    	    dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 :
    	    zero(p1)
    	a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) -
    	    dot(cond1, T(tdata, phi))
    
    	m2 = max(gamma2, norm(nablau))
    	cond2 = norm(nablau) > gamma2 ?
    	    dot(nablau, nabladu) / norm(nablau)^2 * p2 :
    	    zero(p2)
    	a2 = lambda / m2 * dot(nabladu, nablaphi) -
    	    dot(cond2, nablaphi)
    
    	aB = alpha2 * dot(T(tdata, du), T(tdata, phi)) +
    	    beta * dot(S(du, nabladu), S(phi, nablaphi))
    
    	return a1 + a2 + aB
        end
    
        # Right-hand side: negative residual of the primal equation at the current u
        # (B u + T* p1(u) + ∇* p2(u) - alpha2 T* g, with p1, p2 replaced by their
        # normalised expressions in u).
        function du_l(x_, phi, nablaphi; g, u, nablau, p1, p2, tdata)
    	aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) +
    	    beta * dot(S(u, nablau), S(phi, nablaphi))
    	m1 = max(gamma1, norm(T(tdata, u) - g))
    	p1part = alpha1 / m1 * dot(T(tdata, u) - g, T(tdata, phi))
    	m2 = max(gamma2, norm(nablau))
    	p2part = lambda / m2 * dot(nablau, nablaphi)
    	gpart = alpha2 * dot(g, T(tdata, phi))
    
    	return -aB - p1part - p2part + gpart
        end
    
        # solve du
        print("assemble ... ")
        A, b = assemble(ctx.du.space, du_a, du_l;
            ctx.g, ctx.u, nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata)
        print("solve ... ")
        ctx.du.data .= A \ b
    
    
        # solve dp1 (pointwise, must match the p1 linearisation in du_a/du_l)
        function dp1_update(x_; g, u, p1, du, tdata)
    	m1 = max(gamma1, norm(T(tdata, u) - g))
    	cond = norm(T(tdata, u) - g) > gamma1 ?
    	    dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 :
    	    zero(p1)
    	return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond
        end
        interpolate!(ctx.dp1, dp1_update;
    	ctx.g, ctx.u, ctx.p1, ctx.du, ctx.tdata)
    
        # solve dp2 (pointwise, must match the p2 linearisation in du_a/du_l)
        function dp2_update(x_; u, nablau, p2, du, nabladu)
    	m2 = max(gamma2, norm(nablau))
    	cond = norm(nablau) > gamma2 ?
    	    dot(nablau, nabladu) / norm(nablau)^2 * p2 :
    	    zero(p2)
    	return -p2 + lambda / m2 * (nablau + nabladu) - cond
        end
        interpolate!(ctx.dp2, dp2_update;
            ctx.u, nablau = nabla(ctx.u), ctx.p2, ctx.du, nabladu = nabla(ctx.du))
    
        # newton update (full step, no damping)
        theta = 1.
        ctx.u.data .+= theta * ctx.du.data
        ctx.p1.data .+= theta * ctx.dp1.data
        ctx.p2.data .+= theta * ctx.dp2.data
    
        # reproject p1, p2 onto their feasible sets
        p1_project!(ctx.p1, ctx.alpha1)
        p2_project!(ctx.p2, ctx.lambda)
    end
    
    "
    2010, Chambolle and Pock: primal-dual semi-implicit algorithm
    2017, Alkämper and Langer: fem dualisation

    One iteration for the restricted model m = 1, lambda = 1, beta = 0, T = I
    (errors otherwise; gamma1, gamma2 are ignored). The dual p2 is updated at the
    extrapolated point u + theta * du and projected. The primal step solves for
    z and then applies the pointwise shrinkage of z towards g by
    tau * alpha1 / (1 + tau * alpha2). Updates u, du, p2, dp2 in place and
    returns ctx.
    "
    function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
        # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1!
        # changed alpha2 -> alpha2 / 2
        # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
        if ctx.m != 1 || ctx.lambda != 1. || ctx.beta != 0.
    	error("unsupported parameters")
        end
        # shrinkage threshold of the primal prox; local name, unrelated to ctx.beta
        beta = tau * ctx.alpha1 / (1 + tau * ctx.alpha2)
    
        # u is P1
        # p2 is essentially DP0 (technically may be DP1)
    
        # 1. y update
        u_new = FeFunction(ctx.u.space) # x_bar only used here
        u_new.data .= ctx.u.data .+ theta .* ctx.du.data
    
        p2_next = FeFunction(ctx.p2.space)
    
        function p2_update(x_; p2, nablau)
    	return p2 + sigma * nablau
        end
        interpolate!(p2_next, p2_update; ctx.p2, nablau = nabla(u_new))
    
        p2_project!(p2_next, ctx.lambda)
        ctx.dp2.data .= p2_next.data .- ctx.p2.data
    
        ctx.p2.data .+= ctx.dp2.data
    
        # 2. (weak form of z: mass matrix on the left, right-hand side below)
        u_a(x, z, nablaz, phi, nablaphi; g, u, p2) =
    	dot(z, phi)
    
        u_l(x, phi, nablaphi; u, g, p2) =
    	(dot(u + tau * ctx.alpha2 * g, phi) + tau * dot(p2, -nablaphi)) /
    	    (1 + tau * ctx.alpha2)
    
        # z = 1 / (1 + tau * alpha2) *
        #   (u + tau * alpha2 * g + tau * div(p))
        z = FeFunction(ctx.u.space)
        A, b = assemble(z.space, u_a, u_l; ctx.g, ctx.u, ctx.p2)
        z.data .= A \ b
    
        # Shrink z towards g by beta, node by node: z - beta if z - beta >= g,
        # z + beta if z + beta <= g, otherwise g. Operating on coefficients is
        # only meaningful for nodal P1 functions, hence the type assertions.
        function u_update!(u, z, g, beta)
    	u.space.element::P1
    	g.space.element::P1
    	for i in eachindex(u.data)
    	    if z.data[i] - beta >= g.data[i]
    		u.data[i] = z.data[i] - beta
    	    elseif z.data[i] + beta <= g.data[i]
    		u.data[i] = z.data[i] + beta
    	    else
    		u.data[i] = g.data[i]
    	    end
    	end
        end
        # u_new is reused as the output buffer of the primal step
        u_update!(u_new, z, ctx.g, beta)
    
        # 3.
        # note: step-size control not implemented, since \nabla G^* is not 1/gamma continuous
        ctx.du.data .= u_new.data .- ctx.u.data
        ctx.u.data .= u_new.data
    
        #ctx.du.data .= z.data
        any(isnan, ctx.u.data) && error("encountered nan data")
        any(isnan, ctx.p2.data) && error("encountered nan data")
    
        return ctx
    end
    
    """
    2010, Chambolle and Pock: accelerated primal-dual semi-implicit algorithm

    One primal-dual iteration on `ctx` with dual step `sigma`, primal step `tau`
    and extrapolation parameter `theta`. Chambolle–Pock requires
    `sigma * tau * L^2 <= 1`, where `L` bounds the operator `(T, grad)`.

    1. Dual update at the extrapolated point `u_bar = u + theta * du` (`du` holds
       the previous primal step):
       `p1 <- (p1 + sigma * (T u_bar - g)) / (1 + sigma * gamma1 / alpha1)` and
       `p2 <- (p2 + sigma * ∇u_bar) / (1 + sigma * gamma2 / lambda)`, followed by
       projection onto `[-alpha1, alpha1]` resp. `|p2| <= lambda`. A weight of
       zero forces the corresponding dual to zero.
    2. Primal update: solve, in weak form,
       `(I + tau * (alpha2 T*T + beta S*S)) u_new = u - tau * (T* p1 + ∇* p2 - alpha2 T* g)`.

    Updates `u, du, p1, dp1, p2, dp2` in place and returns `ctx`. Raises an
    error if NaNs appear.
    """
    function step_pd2!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
        # u is P1
        # p2 is essentially DP0 (technically may be DP1)

        # 1. y update; u_new first holds the extrapolated point x_bar
        u_new = FeFunction(ctx.u.space)
        u_new.data .= ctx.u.data .+ theta .* ctx.du.data

        p1_next = FeFunction(ctx.p1.space)
        p2_next = FeFunction(ctx.p2.space)

        function p1_update(x_; p1, g, u, tdata)
            ctx.alpha1 == 0 && return zero(p1)
            return (p1 + sigma * (ctx.T(tdata, u) - g)) /
                (1 + ctx.gamma1 * sigma / ctx.alpha1)
        end
        # Chambolle–Pock evaluates the whole operator K = (T, grad) at x_bar,
        # so p1 must see u_new just like p2 below.
        interpolate!(p1_next, p1_update; ctx.p1, ctx.g, u = u_new, ctx.tdata)

        function p2_update(x_; p2, nablau)
            ctx.lambda == 0 && return zero(p2)
            return (p2 + sigma * nablau) /
                (1 + ctx.gamma2 * sigma / ctx.lambda)
        end
        interpolate!(p2_next, p2_update; ctx.p2, nablau = nabla(u_new))

        # reproject p1, p2
        p1_project!(p1_next, ctx.alpha1)
        p2_project!(p2_next, ctx.lambda)

        ctx.dp1.data .= p1_next.data .- ctx.p1.data
        ctx.dp2.data .= p2_next.data .- ctx.p2.data
        ctx.p1.data .+= ctx.dp1.data
        ctx.p2.data .+= ctx.dp2.data

        # 2. x update; u_new is reused as the solution buffer
        u_a(x, w, nablaw, phi, nablaphi; g, u, p1, p2, tdata) =
            dot(w, phi) +
            tau * ctx.alpha2 * dot(ctx.T(tdata, w), ctx.T(tdata, phi)) +
            tau * ctx.beta * dot(ctx.S(w, nablaw), ctx.S(phi, nablaphi))

        u_l(x, phi, nablaphi; g, u, p1, p2, tdata) =
            dot(u, phi) - tau * (
                dot(p1, ctx.T(tdata, phi)) +
                dot(p2, nablaphi) -
                ctx.alpha2 * dot(g, ctx.T(tdata, phi)))

        A, b = assemble(u_new.space, u_a, u_l; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.tdata)
        u_new.data .= A \ b
        ctx.du.data .= u_new.data .- ctx.u.data
        ctx.u.data .= u_new.data

        any(isnan, ctx.u.data) && error("encountered nan data")
        any(isnan, ctx.p1.data) && error("encountered nan data")
        any(isnan, ctx.p2.data) && error("encountered nan data")

        return ctx
    end
    
    """
        step_d!(ctx::L1L2TVContext; tau)

    2004, Chambolle: dual semi-implicit algorithm.

    Placeholder, not implemented: the dual step needs `grad(div(p))`, which
    might not be implementable without higher order elements. Leaves `ctx`
    untouched (`tau` is ignored) and returns it.
    """
    step_d!(ctx::L1L2TVContext; tau) = ctx
    
    # Recover the primal variable from the current duals of `ctx`:
    #   u = -B^{-1} (T* p1 + ∇* p2 - alpha2 T* g),  B = alpha2 T*T + beta S*S,
    # i.e. solve <B u, phi> = -<p1, T phi> - <p2, ∇phi> + alpha2 <g, T phi>.
    # Writes the solution into `u` and returns `u.data`.
    function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
        T, S = ctx.T, ctx.S
        alpha2, beta = ctx.alpha2, ctx.beta

        bilinear(x, w, nablaw, phi, nablaphi; g, p1, p2, tdata) =
            alpha2 * dot(T(tdata, w), T(tdata, phi)) +
                beta * dot(S(w, nablaw), S(phi, nablaphi))

        linear(x, phi, nablaphi; g, p1, p2, tdata) =
            -dot(p1, T(tdata, phi)) - dot(p2, nablaphi) +
                alpha2 * dot(g, T(tdata, phi))

        A, b = assemble(u.space, bilinear, linear; ctx.g, ctx.p1, ctx.p2, ctx.tdata)
        u.data .= A \ b
        return u.data
    end
    
    # Huber function: quadratic x^2 / (2 gamma) for |x| < gamma, linear |x| - gamma/2 beyond.
    function huber(x, gamma)
        ax = abs(x)
        if ax < gamma
            return x^2 / (2 * gamma)
        end
        return ax - gamma / 2
    end
    
    """
        estimate!(ctx::L1L2TVContext)

    Fill `ctx.est` with the cell-wise error estimator (via `project!` of `estf`).
    The integrand has three parts:
    - for the alpha1 term, `alpha1 * huber(|T u - g|, gamma1) - <T u - g, p1>
      + gamma1 / (2 alpha1) * |p1|^2`, clipped at 0;
    - the analogous expression for the lambda term with `∇u`, `p2`, `gamma2`,
      also clipped at 0;
    - `1/2 * (alpha2 |T(w - u)|^2 + beta |S(w) - S(u)|^2)`, where `w` is the
      primal variable recovered from the current duals by `solve_primal!`.
    The gamma/(2 * weight) terms are dropped when the weight is zero.
    """
    function estimate!(ctx::L1L2TVContext)
        function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
    	alpha1part =
                ctx.alpha1 * huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) -
    	    dot(ctx.T(tdata, u) - g, p1) +
                (iszero(ctx.alpha1) ? 0. :
                    ctx.gamma1 / (2 * ctx.alpha1) * norm(p1)^2)
    	lambdapart =
                ctx.lambda * huber(norm(nablau), ctx.gamma2) -
    	    dot(nablau, p2) +
                (iszero(ctx.lambda) ? 0. :
                    ctx.gamma2 / (2 * ctx.lambda) * norm(p2)^2)
            # clip negative values caused by rounding errors (the parts are
            # presumably >= 0 in exact arithmetic)
            alpha1part = max(0, alpha1part)
            lambdapart = max(0, lambdapart)
    	bpart = 1 / 2 * (
    	    ctx.alpha2 * dot(ctx.T(tdata, w - u), ctx.T(tdata, w - u)) +
    	    ctx.beta * dot(ctx.S(w, nablaw) - ctx.S(u, nablau),
                    ctx.S(w, nablaw) - ctx.S(u, nablau)))
    
    	return alpha1part + lambdapart + bpart
        end
    
        # primal variable induced by the current duals p1, p2
        w = FeFunction(ctx.u.space)
        solve_primal!(w, ctx)
        # TODO: find better name: is actually a cell-wise integration
        project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2,
            nablau = nabla(ctx.u), w, nablaw = nabla(w), ctx.tdata)
    end
    
    # Global error estimate: sqrt of the summed cell estimators divided by the domain area.
    function estimate_error(st::L1L2TVContext)
        total = sum(st.est.data)
        return sqrt(total / area(st.mesh))
    end
    
    """
        mark(ctx::L1L2TVContext; theta=0.5)

    Minimal Dörfler marking. Take cells in order of decreasing estimator
    contribution `ctx.est.data` and mark them until the marked contributions
    sum to at least `theta` times the total estimate. Returns the indices of the
    marked cells as a `Vector{Int}`; it is empty if the total estimate is zero.
    """
    function mark(ctx::L1L2TVContext; theta=0.5)
        estdata = ctx.est.data
        threshold = theta * sum(estdata)

        # (cell => contribution) pairs, largest contribution first
        ranked = sort(collect(pairs(estdata)); by = last, rev = true)

        marked_cells = Int[]
        estacc = 0.
        for (cell, cellest) in ranked
            estacc >= threshold && break
            push!(marked_cells, cell)
            estacc += cellest
        end
        return marked_cells
    end
    
    
    # Write the mesh of `ctx` plus the finite element functions `fs` to the VTK
    # file `filename` and return the saved file object (callers store it as a
    # paraview collection entry).
    function output(ctx::L1L2TVContext, filename, fs...)
        print("save \"$filename\" ... ")
        return let file = vtk_mesh(filename, ctx.mesh)
            vtk_append!(file, fs...)
            vtk_save(file)
            file
        end
    end
    
    # Primal energy of the current iterate u:
    #   ∫ alpha1 huber(|T u - g|, gamma1) + alpha2/2 |T u - g|^2
    #     + beta/2 |S(u, ∇u)|^2 + lambda huber(|∇u|, gamma2).
    function primal_energy(ctx::L1L2TVContext)
        T, S = ctx.T, ctx.S
        function energy_density(x; g, u, nablau, tdata)
            residual = T(tdata, u) - g
            l1_fidelity = ctx.alpha1 * huber(norm(residual), ctx.gamma1)
            l2_fidelity = ctx.alpha2 / 2 * norm(residual)^2
            regularity = ctx.beta / 2 * norm(S(u, nablau))^2
            total_variation = ctx.lambda * huber(norm(nablau), ctx.gamma2)
            return l1_fidelity + l2_fidelity + regularity + total_variation
        end
        return integrate(ctx.mesh, energy_density; ctx.g, ctx.u,
            nablau = nabla(ctx.u), ctx.tdata)
    end
    
    # L2 norm of the finite element function `f` over its mesh.
    function norm_l2(f)
        squared(x; f) = dot(f, f)
        return sqrt(integrate(f.space.mesh, squared; f))
    end
    
    # Size of the last update (du, dp1, dp2): combined L2 norm divided by
    # sqrt(domain area).
    function norm_step(ctx::L1L2TVContext)
        sqsum = norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2
        return sqrt(sqsum / area(ctx.mesh))
    end
    
    """
        norm_residual(ctx::L1L2TVContext)

    Norm of the residual of the optimality system at the current iterate:
    `sqrt(‖w - u‖^2 + ∫ |p1 m1 - alpha1 (T u - g)|^2 + |p2 m2 - lambda ∇u|^2)`,
    where `w` is the primal variable recovered from the duals by
    `solve_primal!`, `m1 = max(gamma1, |T u - g|)` and `m2 = max(gamma2, |∇u|)`.
    Not divided by the domain area; some callers scale it themselves.
    """
    function norm_residual(ctx::L1L2TVContext)
        # primal part: distance between u and the primal variable of the duals
        w = FeFunction(ctx.u.space)
        solve_primal!(w, ctx)
        w.data .-= ctx.u.data
        upart2 = norm_l2(w)^2
    
        # dual part: pointwise defect of the (multiplied-out) dual relations
        function integrand(x; g, u, nablau, p1, p2, tdata)
    	p1part = p1 * max(ctx.gamma1, norm(ctx.T(tdata, u) - g)) -
    	    ctx.alpha1 * (ctx.T(tdata, u) - g)
    	p2part = p2 * max(ctx.gamma2, norm(nablau)) -
    	    ctx.lambda * nablau
    	return norm(p1part)^2 + norm(p2part)^2
        end
        ppart2 = integrate(ctx.mesh, integrand; ctx.g, ctx.u,
            nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata)
    
        return sqrt(upart2 + ppart2)
    end
    
    """
        denoise(img; name, params...)

    Adaptive denoising (T = I, S = I) starting on a coarse `init_grid(img, 5, 5)`
    mesh. Runs Newton steps with error estimation until the step norm is at most
    `1e-1`. Then it refines the Dörfler-marked cells (theta = 0.5), transfers all
    solver fields to the new mesh and re-projects the image. It stops after the
    first refinement at which at least 100 Newton steps have been performed in
    total. Writes one VTK snapshot per step plus `output/<name>.pvd`.
    `params` are forwarded to the `L1L2TVContext` keyword constructor.
    Returns the final context.
    """
    function denoise(img; name, params...)
        m = 1
        img = from_img(img) # coord flip
        #mesh = init_grid(img; type=:vertex)
        mesh = init_grid(img, 5, 5)

        T(tdata, u) = u
        S(u, nablau) = u

        ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)

        project_img!(ctx.g, img)

        save_denoise(ctx, i) =
            output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
                ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)

        pvd = paraview_collection("output/$(ctx.name).pvd")
        pvd[0] = save_denoise(ctx, 0)

        k = 0
        println("primal energy: $(primal_energy(ctx))")
        while true
            while true
                k += 1
                step!(ctx)
                estimate!(ctx)
                pvd[k] = save_denoise(ctx, k)
                println()

                norm_step_ = norm_step(ctx)
                norm_residual_ = norm_residual(ctx)

                println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est))")
                println("primal energy: $(primal_energy(ctx))")
                println("norm_step: $(norm_step_)")
                println("norm_residual: $(norm_residual_)")

                norm_step_ <= 1e-1 && break
            end
            marked_cells = mark(ctx; theta = 0.5)
            println("refining ...")
            # `refine` works on meshes and transfers the given fields; rebuild the
            # context from the refined mesh and fields (same pattern as the other
            # adaptive drivers in this file).
            mesh, fs = refine(ctx.mesh, marked_cells;
                ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2)
            ctx = L1L2TVContext(ctx.name, mesh, ctx.d, ctx.m, ctx.T, ctx.tdata, ctx.S,
                ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2,
                fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
            test_mesh(ctx.mesh)

            project_img!(ctx.g, img)

            k >= 100 && break
        end
        vtk_save(pvd)
        return ctx
    end
    
    
    """
        denoise_pd(ctx, img; df=nothing, name, algorithm, params_...)

    Denoise `img` (T = I, S = I) on the mesh `init_grid(img)`, starting from
    `u = g`, with one of the solvers
    - `:pd1`: `step_pd2!` with fixed step sizes,
    - `:pd2`: `step_pd2!` with the accelerated step-size update
      `theta = 1 / sqrt(1 + 2 * mu * tau)`, `mu = alpha2 + beta`,
    - `:newton`: semi-smooth Newton `step!`.

    `params_` must contain `alpha1, alpha2, lambda, beta, gamma1, gamma2` and
    may contain `tol` (stop once the scaled step norm is below it) and
    `max_iters`. Without either, the loop does not terminate. If `df` is a
    `DataFrame`, one row `(k, norm_step, primal_energy)` is appended per
    iteration. The first positional argument (the experiment context) is
    currently unused. Returns the final `L1L2TVContext`.

    Throws `ArgumentError` for an unknown `algorithm`.
    """
    function denoise_pd(_ctx, img; df=nothing, name, algorithm, params_...)
        algorithm in (:pd1, :pd2, :newton) ||
            throw(ArgumentError("unknown algorithm: $(repr(algorithm))"))
        params = NamedTuple(params_)
        m = 1
        img = from_img(img) # coord flip
        mesh = init_grid(img)

        T(tdata, u) = u
        S(u, nablau) = u

        st = L1L2TVContext(name, mesh, m;
            T, tdata = nothing, S,
            params.alpha1, params.alpha2, params.lambda, params.beta,
            params.gamma1, params.gamma2)

        # semi-implicit primal dual parameters
        mu = st.alpha2 + st.beta # T = I, S = I
        tau = 1e-1
        L = 100
        sigma = inv(tau * L^2)
        theta = 1.

        interpolate!(st.g, x -> interpolate_bilinear(img, x))
        st.u.data .= st.g.data

        log!(x::Nothing; kwargs...) = x
        function log!(df::DataFrame; kwargs...)
            row = NamedTuple(kwargs)
            push!(df, row)
            println(row)
        end

        # NOTE(review): `norm_step` already divides by `area(mesh)`, so this
        # factor normalises a second time. It is kept so that logged values and
        # the meaning of `tol` stay unchanged.
        domain_factor = 1 / sqrt(area(mesh))

        k = 0
        println("primal energy: $(primal_energy(st))")

        while true
            k += 1
            if algorithm == :pd1
                # no step size control
                step_pd2!(st; sigma, tau, theta)
            elseif algorithm == :pd2
                theta = 1 / sqrt(1 + 2 * mu * tau)
                tau *= theta
                sigma /= theta
                step_pd2!(st; sigma, tau, theta)
            else # algorithm == :newton (validated above)
                step!(st)
            end

            norm_step_ = norm_step(st) * domain_factor

            log!(df; k,
                norm_step = norm_step_,
                primal_energy = primal_energy(st))

            haskey(params, :tol) && norm_step_ < params.tol && break
            haskey(params, :max_iters) && k >= params.max_iters && break
        end
        return st
    end
    
    
    """
        experiment_pd_comparison(ctx)

    Compare the semi-implicit (`:pd1`), accelerated semi-implicit (`:pd2`) and
    semi-smooth Newton (`:newton`) solvers of `denoise_pd` on
    `input.png` from `ctx.indir`. Writes to `ctx.outdir`:
    - log-thinned convergence tables, each with the best primal energy over all
      runs as column `energy_min`;
    - the input and result images;
    - the parameters (`data.tex`).
    """
    function experiment_pd_comparison(ctx)
        img = loadimg(joinpath(ctx.indir, "input.png"))

        algparams = (
            alpha1=0., alpha2=30., lambda=1., beta=0.,
            gamma1=1e-2, gamma2=1e-3,
            tol = 1e-10,
            max_iters = 10000,
        )

        # (algorithm, convergence table file, result image file), in run order
        runs = (
            (:pd1, "semi-implicit.csv", "semi-implicit.png"),
            (:pd2, "semi-implicit-accelerated.csv", "semi-implicit-accelerated.png"),
            (:newton, "newton.csv", "newton.png"),
        )

        logs = DataFrame[]
        states = []
        for (algorithm, _, _) in runs
            df = DataFrame()
            st = denoise_pd(ctx, img; name="test", algorithm, df, algparams...)
            push!(logs, df)
            push!(states, st)
        end

        energy_min = minimum(df -> minimum(df.primal_energy), logs)
        for df in logs
            df[!, :energy_min] .= energy_min
        end

        for ((_, csvfile, _), df) in zip(runs, logs)
            CSV.write(joinpath(ctx.outdir, csvfile), logfilter(df))
        end

        saveimg(joinpath(ctx.outdir, "input.png"), img)
        for ((_, _, pngfile), st) in zip(runs, states)
            saveimg(joinpath(ctx.outdir, pngfile), to_img(sample(st.u)))
        end

        return savedata(joinpath(ctx.outdir, "data.tex");
            energy_min, algparams...)
    end
    
    """
        denoise_approximation(ctx)

    Adaptive denoising experiment (T = I, S = I) driven by `ctx.params`
    (`name`, `mesh`, `img`, `df`, `adaptive`, the model weights and optionally
    `tol` and `max_iters`). It runs Newton steps, starting from `u = g`. Whenever
    the scaled step norm drops below `tol`, it:
    - logs statistics to `ctx.params.df`;
    - writes a VTK snapshot;
    - refines the mesh (Dörfler marking if `adaptive`, otherwise uniform).

    Stops after `max_iters` Newton steps if that key is present; otherwise it
    runs until interrupted. The `.pvd` collection is written when the
    do-block exits. Returns the final `L1L2TVContext`.
    """
    function denoise_approximation(ctx)
        domain_factor = 1 / sqrt(area(ctx.params.mesh))

        T(tdata, u) = u
        S(u, nablau) = u

        st = L1L2TVContext(ctx.params.name, ctx.params.mesh, 1;
            T, tdata = nothing, S,
            ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta,
            ctx.params.gamma1, ctx.params.gamma2)

        interpolate!(st.g, x -> interpolate_bilinear(ctx.params.img, x))
        st.u.data .= st.g.data

        save_vtk(st, i) =
            output(st,
                joinpath(ctx.outdir, "$(ctx.params.name)_$(lpad(i, 5, '0')).vtu"),
                st.g, st.u, st.p1, st.p2, st.est)

        log!(x::Nothing; kwargs...) = x
        function log!(df::DataFrame; kwargs...)
            row = NamedTuple(kwargs)
            push!(df, row)
            println(df)
        end

        pvd_path = joinpath(ctx.outdir, "$(ctx.params.name).pvd")
        paraview_collection(pvd_path) do pvd

            pvd[0] = save_vtk(st, 0)
            println("primal energy: $(primal_energy(st))")

            k = 0 # number of refinements (= snapshots after the initial one)
            iters = 0 # number of Newton steps
            while true
                step!(st)
                iters += 1

                norm_step_ = norm_step(st) * domain_factor
                if haskey(ctx.params, :tol) && norm_step_ < ctx.params.tol
                    k += 1
                    norm_residual_ = norm_residual(st) * domain_factor
                    estimate!(st)

                    pvd[k] = save_vtk(st, k)
                    log!(ctx.params.df; k,
                        ndofs = ndofs(st.u.space),
                        hmax = mesh_size(st.mesh),
                        norm_step = norm_step_, norm_residual = norm_residual_,
                        primal_energy = primal_energy(st),
                        est = sqrt(sum(st.est.data)) * domain_factor,
                    )

                    marked_cells = ctx.params.adaptive ?
                        mark(st; theta = 0.5) : # estimator + dörfler
                        Set(axes(st.mesh.cells, 2))

                    mesh, fs = refine(st.mesh, marked_cells;
                        st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
                    st = L1L2TVContext("run", mesh, st.d, st.m, T, nothing, S,
                        st.alpha1, st.alpha2, st.beta, st.lambda,
                        st.gamma1, st.gamma2,
                        fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
                end

                # without this the loop had no exit and `return st` was unreachable
                haskey(ctx.params, :max_iters) && iters >= ctx.params.max_iters && break
            end
        end

        return st
    end
    
    """
        experiment_approximation(ctx)

    Run `denoise_approximation` on a 2x2 test image containing a single bright
    pixel, with adaptive refinement and `tol = 1e-10`, logging into a fresh
    `DataFrame`.
    """
    function experiment_approximation(ctx)
        df = DataFrame()
        # alternative test image: from_img([1. 0.; 0. 0.])
        img = from_img([0. 0.; 1. 0.])
        # alternative mesh: init_grid(img; type=:vertex)
        mesh = init_grid(img)

        params = (;
            name = "test", df,
            img, mesh,
            alpha1 = 0., alpha2 = 30., lambda = 1., beta = 0.,
            # alternative: alpha1 = 0.5, alpha2 = 0., lambda = 0., beta = 1.,
            gamma1 = 1e-5, gamma2 = 1e-5,
            tol = 1e-10, adaptive = true,
        )
        return denoise_approximation(Util.Context(ctx; params...))
    end
    
    """
        inpaint(img, imgmask; name, params...)

    Inpainting experiment on a vertex grid of `img`. The data operator
    `T(tdata, u)` keeps `u` where the interpolated mask is one and returns zero
    elsewhere, so the data term only acts on known pixels. The data `g` is the
    image on known pixels and `0.` elsewhere. Runs three Newton steps with error
    estimation and writes VTK snapshots `output/<name>_*.vtu` plus the
    collection `output/<name>.pvd`. `params` are forwarded to the
    `L1L2TVContext` keyword constructor. Returns the context.

    Throws `ArgumentError` if `img` and `imgmask` differ in size.
    """
    function inpaint(img, imgmask; name, params...)
        size(img) == size(imgmask) ||
            throw(ArgumentError("non-matching dimensions"))

        m = 1
        img = from_img(img) # coord flip
        imgmask = from_img(imgmask) # coord flip
        mesh = init_grid(img; type=:vertex)

        # inpaint specific stuff
        Vg = FeSpace(mesh, P1(), (1,))
        mask = FeFunction(Vg, name="mask")

        T(tdata, u) = isone(tdata[begin]) ? u : zero(u)
        S(u, nablau) = u

        ctx = L1L2TVContext(name, mesh, m; T, tdata = mask, S, params...)

        # FIXME: currently dual grid only
        interpolate!(mask, x -> imgmask[round.(Int, x)...])
        # Known pixels are detected with the same `isone` test as in `T`, so
        # both Bool and 0/1-valued numeric masks work.
        interpolate!(ctx.g,
            x -> isone(imgmask[round.(Int, x)...]) ? img[round.(Int, x)...] : 0.)

        save_inpaint(i) =
            output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
                ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask)

        # do-block form: the collection file is saved when the block exits
        paraview_collection("output/$(ctx.name).pvd") do pvd
            pvd[0] = save_inpaint(0)
            for i in 1:3
                step!(ctx)
                estimate!(ctx)
                pvd[i] = save_inpaint(i)
                println()
            end
        end
        return ctx
    end
    
    """
        optflow(ctx)

    Adaptive optical-flow experiment between the frames `ctx.params.imgf0` and
    `ctx.params.imgf1` (m = 2 flow components, coarse grid of 1/16 of the image
    resolution). The data term is linearised around the current flow `u`:
    `T(tdata, u') = nablafw * u'` with `tdata = nabla(fw)`, `fw` = `imgf1` warped
    backwards by `u`, and `g = nablafw * u - (fw - f0)`, so that
    `T u' - g = nablafw * (u' - u) + fw - f0`.

    Each outer round does 4 times (5 Newton steps + re-warp) and then, unless
    `i >= 50`, refines the Dörfler-marked cells and re-projects the frames.
    Snapshots go to `ctx.outdir/output_*.vtu` and `output.pvd`. The flow is
    displayed with `colorflow` (presumably from OpticalFlowUtils) using the
    Dimetrodon `maxflow`. Returns the final context.

    NOTE(review): `st`, `mesh`, `f0`, `f1` and `fw` are reassigned while being
    captured by `warp!`/`reproject!`/`save_step`; the closures rely on seeing
    the updated bindings, so keep the statement order.
    """
    function optflow(ctx)
        # coord flip
        imgf0 = from_img(ctx.params.imgf0)
        imgf1 = from_img(ctx.params.imgf1)
        maxflow = 4.6157 # Dimetrodon
        size(imgf0) == size(imgf1) ||
    	throw(ArgumentError("non-matching dimensions"))
    
        m = 2
        #mesh = init_grid(imgf0; type=:vertex)
        mesh = init_grid(imgf0, (size(imgf0) .÷ 16)...)
        #mesh = init_grid(imgf0)
    
        # optflow specific stuff
        Vg = FeSpace(mesh, P1(), (1,))
        f0 = FeFunction(Vg, name="f0")
        f1 = FeFunction(Vg, name="f1")
        fw = FeFunction(Vg, name="fw")
    
        T(tdata, u) = tdata * u # tdata = nablafw
        S(u, nablau) = nablau
        #S(u, nablau) = u
    
        st = L1L2TVContext("run", mesh, m; T, tdata = nabla(fw), S,
            ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta,
            ctx.params.gamma1, ctx.params.gamma2)
    
        # Re-warp f1 with the current flow, rebuild the context with the new
        # gradient nabla(fw) as tdata, and recompute the linearised data g.
        function warp!()
            imgfw = warp_backwards(imgf1, sample(st.u))
            project_img!(fw, imgfw)
    
            # replace new tdata
            st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fw), S,
                st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
                st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
    
            g_optflow(x; u, f0, fw, nablafw) =
                nablafw * u - (fw - f0)
            interpolate!(st.g, g_optflow; st.u, f0, fw, nablafw = st.tdata)
        end
    
        # Project both frames onto the current mesh (uses project_img2!, while
        # warp! uses project_img! for fw).
        function reproject!()
            project_img2!(f0, imgf0)
            project_img2!(f1, imgf1)
        end
        reproject!()
        warp!()
    
        save_step(i) =
            output(st, joinpath(ctx.outdir, "output_$(lpad(i, 5, '0')).vtu"),
    	    st.g, st.u, st.p1, st.p2, st.est, f0, f1, fw)
    
        i = 0 # running snapshot index
        pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd
            pvd[i] = save_step(i)
            while true
                for j in 1:4
                    # NOTE(review): norm_g_old / norm_g are computed but unused
                    norm_g_old = norm_l2(st.g)
                    for k in 1:5
                        i += 1
                        step!(st)
                        estimate!(st)
                        pvd[i] = save_step(i)
                        println()
                    end
                    warp!()
                    norm_g = norm_l2(st.g)
                    i += 1
                    pvd[i] = save_step(i)
                    display(plot(colorflow(to_img(sample(st.u)); maxflow)))
                end
                i >= 50 && break
                #continue
    
                marked_cells = mark(st; theta = 0.5)
                #marked_cells = Set(axes(mesh.cells, 2))
    
                println("refining ...")
    
                # transfer solver fields and the frame functions to the refined mesh
                mesh, fs = refine(mesh, marked_cells;
                    st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2,
                    f0, f1, fw)
                st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fs.fw), S,
                    st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
                    fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
                f0, f1, fw = (fs.f0, fs.f1, fs.fw)
                    i += 1
                    pvd[i] = save_step(i)
    
                # NOTE(review): g is not recomputed here; it is next updated by warp!
                println("reprojecting ...")
                reproject!()
                    i += 1
                    pvd[i] = save_step(i)
            end
        end
        display(plot(colorflow(to_img(sample(st.u)); maxflow)))
    
        #CSV.write(joinpath(ctx.outdir, "energies.csv"), df)
        #saveimg(joinpath(ctx.outdir, "output_glob.png"), fetch_u(states.glob))
        #savedata(ctx, "data.tex"; lambda=λ, beta=β, tau=τ, maxiters, energymin,
        #    width=size(fo, 2), height=size(fo, 1))
        return st
    end
    
    # Load the Middlebury frame pair frame10.png/frame11.png from `ctx.indir`
    # and run `optflow` on them.
    function experiment_optflow_middlebury(ctx)
        imgf0, imgf1 = map(("frame10.png", "frame11.png")) do fname
            loadimg(joinpath(ctx.indir, fname))
        end
        return optflow(Util.Context(ctx; imgf0, imgf1))
    end