# run_experiments.jl
using LinearAlgebra: I, det, dot, norm, normalize
using SparseArrays: sparse
using Statistics: mean
using Colors: Gray
# avoid world-age-issues by preloading ColorTypes
import ColorTypes
import CSV
import DataFrames: DataFrame
import FileIO
using ForwardDiff: jacobian
using ImageQualityIndexes: assess_psnr, assess_ssim
using OpticalFlowUtils
using WriteVTK: paraview_collection
using Plots
using StaticArrays: MVector, SArray, SMatrix, SVector, SA
using SemiSmoothNewton
using SemiSmoothNewton: HMesh, ncells, refine, area, mesh_size, ndofs, diam,
    elmap
using SemiSmoothNewton: project!, project_l2_lagrange!, project_qi_lagrange!,
    project_l2_pixel!
using SemiSmoothNewton: vtk_mesh, vtk_append!, vtk_save
include("util.jl")
isdefined(Main, :Revise) && Revise.track(joinpath(@__DIR__, "util.jl"))
using .Util
"Clamp `value` into the unit interval and wrap it as a `Gray` pixel."
function grayclamp(value)
    return Gray(clamp(value, 0., 1.))
end
"Load the image stored at `x` and convert its pixels to `Float64`."
function loadimg(x)
    raw = FileIO.load(x)
    return Float64.(raw)
end
"Save a grayscale image, clamping every pixel into [0, 1] first."
function saveimg(io, x::Array{<:Gray})
    return FileIO.save(io, grayclamp.(x))
end
"Save `x` as-is."
function saveimg(io, x)
    return FileIO.save(io, x)
end
"""
    saveimgdiff(io, f0, f1)

Save a color-coded image of the pointwise difference `f1 - f0` using a
256-entry `"RdBu"` colormap. Differences in `[-1, 1]` are mapped linearly
onto the colormap; values outside that range are clamped to its ends.
"""
function saveimgdiff(io, f0, f1)
    ncolors = 2^8
    cmap = colormap("RdBu", ncolors)
    function diffcolor(v0, v1)
        idx = ceil(Int, ncolors * ((v1 - v0) / 2 + 0.5))
        return cmap[clamp(idx, 1, ncolors)]
    end
    return saveimg(io, diffcolor.(f0, f1))
end
# convert image to/from image coordinate system
"""
    from_img(arr)

Convert an image-ordered matrix (first index: row from top to bottom, second
index: column) to cartesian ordering `out[x, y]` with `y` pointing upwards.
"""
function from_img(arr)
    flipped = reverse(arr; dims = 1)
    return permutedims(flipped)
end
"""
    to_img(arr)

Inverse of `from_img`: convert cartesian ordering back to image ordering.
"""
function to_img(arr)
    flipped = reverse(arr; dims = 2)
    return permutedims(flipped)
end
"""
    to_img(arr::AbstractArray{<:Any,3})

Convert a flow field `arr[component, x, y]` to image ordering. Besides the
spatial axes, the flow components are reordered and the first resulting
component is negated, since the vertical axis is flipped.
"""
function to_img(arr::AbstractArray{<:Any,3})
    # reverse component order (dim 1) and vertical axis (dim 3), then swap
    # the two spatial axes
    out = permutedims(reverse(arr; dims = (1, 3)), (1, 3, 2))
    # for flow fields handle flow direction in first dimension too
    first_component = view(out, 1, :, :)
    first_component .= .-first_component
    return out
end
"""
    logfilter(dt; a=20)

Reduce dataset such that the resulting dataset would have approximately
equidistant datapoints on a log-scale. `a` controls density.

A row with counter `k` (column `:k`) is kept when the rounded value of
`a * log` changes between `k` and `k + 1`.
"""
function logfilter(dt; a=20)
    startsbucket(k) = round(a * log(k + 1)) - round(a * log(k)) > 0
    return filter(:k => startsbucket, dt)
end
"""
    L1L2TVState{M, m, Ttype, Stype}

State of the primal-dual solvers for the model with primal energy

    alpha1 * huber(|T u - g|, gamma1) + alpha2 / 2 * |T u - g|^2 +
        beta / 2 * |S u|^2 + lambda * huber(|nabla u|, gamma2)

on `mesh`. The type parameter `m` is the number of components of the primal
variable `u` (see `ndims_u_codomain`).
"""
struct L1L2TVState{M, m, Ttype, Stype}
    mesh::M
    # forward operator, called as `T(tdata, u)`
    T::Ttype
    # auxiliary data passed to `T` (e.g. `nothing` or a discrete function)
    tdata
    # operator in the `beta` term, called as `S(u, nablau)`
    S::Stype
    alpha1::Float64
    alpha2::Float64
    beta::Float64
    lambda::Float64
    gamma1::Float64
    gamma2::Float64
    # cellwise error indicator
    est::FeFunction
    # data
    g::FeFunction
    # primal variable
    u::FeFunction
    # dual variables
    p1::FeFunction
    p2::FeFunction
    # increments of the last iteration
    du::FeFunction
    dp1::FeFunction
    dp2::FeFunction
end
# for swapping out mesh and discrete functions
# TODO: any of state, mesh or functions should probably be mutable instead
"""
    L1L2TVState(st::L1L2TVState; mesh, tdata = st.tdata,
        est, g, u, p1, p2, du, dp1, dp2)

Return a state with the operators and parameters of `st`, but living on
`mesh` with the given discrete functions (e.g. after refinement).

`tdata` defaults to `st.tdata`. If `tdata` is a discrete function, it must
live on `mesh` too. Throws `ArgumentError` if any discrete function lives on
a different mesh.
"""
function L1L2TVState(st::L1L2TVState; mesh, tdata = st.tdata,
        est, g, u, p1, p2, du, dp1, dp2)
    # `tdata` may be arbitrary data (e.g. `nothing` for denoising); only a
    # discrete function carries a mesh that has to match
    tdata_ok = !(tdata isa FeFunction) || mesh === tdata.space.mesh
    funcs_ok = all(f -> f.space.mesh == mesh,
        (est, g, u, p1, p2, du, dp1, dp2))
    (tdata_ok && funcs_ok) ||
        throw(ArgumentError("functions have different meshes"))
    return L1L2TVState{typeof(mesh),ndims_u_codomain(st),typeof(st.T),typeof(st.S)}(
        mesh, st.T, tdata, st.S,
        st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
        est, g, u, p1, p2, du, dp1, dp2)
end
# usual constructor
"""
    L1L2TVState{m}(mesh; T, tdata, S, alpha1, alpha2, beta, lambda, gamma1, gamma2)

Create a zero-initialized state on `mesh` with an `m`-component P1 primal
variable `u`, P1 data `g`, P1 dual `p1`, DP0 dual `p2` of size `m × d`
(`d` the domain dimension) and a DP0 error indicator `est`.

Throws `ArgumentError` if neither `alpha2` nor `beta` is positive, since the
operator `B` would then be singular.
"""
function L1L2TVState{m}(mesh; T, tdata, S,
        alpha1, alpha2, beta, lambda, gamma1, gamma2) where m
    if !(alpha2 > 0 || beta > 0)
        throw(ArgumentError("operator B is singular with these parameters"))
    end
    d = ndims_domain(mesh)
    Vest = FeSpace(mesh, DP0(), (1,))
    Vg = FeSpace(mesh, P1(), (1,))
    Vu = FeSpace(mesh, P1(), (m,))
    Vp1 = FeSpace(mesh, P1(), (1,))
    Vp2 = FeSpace(mesh, DP0(), (m, d))
    funcs = (
        FeFunction(Vest, name = "est"),
        FeFunction(Vg, name = "g"),
        FeFunction(Vu, name = "u"),
        FeFunction(Vp1, name = "p1"),
        FeFunction(Vp2, name = "p2"),
        FeFunction(Vu; name = "du"),
        FeFunction(Vp1; name = "dp1"),
        FeFunction(Vp2; name = "dp2"),
    )
    # start from an all-zero state
    for f in funcs
        f.data .= 0
    end
    return L1L2TVState{typeof(mesh),m,typeof(T),typeof(S)}(
        mesh, T, tdata, S,
        alpha1, alpha2, beta, lambda, gamma1, gamma2,
        funcs...)
end
"""
    OptFlowState(mesh; alpha1, alpha2, beta, lambda, gamma1, gamma2)

Create a zero-initialized `L1L2TVState` for optical flow on `mesh`.

- `u` is a 2-component P1 flow field.
- `T(tdata, u) = tdata * u`, with adjoint `T(adjoint, tdata, v) = tdata' * v`.
- `S(u, nablau) = nablau`.
- `tdata` is a DP0 function of size `1 × d`, intended to hold something like
  `nabla(fw)`.
- The data `g` uses DP1 (discontinuous) elements.

Throws `ArgumentError` if neither `alpha2` nor `beta` is positive.
"""
function OptFlowState(mesh;
        alpha1, alpha2, beta, lambda, gamma1, gamma2)
    alpha2 > 0 || beta > 0 ||
        throw(ArgumentError("operator B is singular with these parameters"))
    d = ndims_domain(mesh)
    m = 2
    Vest = FeSpace(mesh, DP0(), (1,))
    # DP1 only for optical flow
    Vg = FeSpace(mesh, DP1(), (1,))
    Vu = FeSpace(mesh, P1(), (m,))
    Vp1 = FeSpace(mesh, P1(), (1,))
    Vp2 = FeSpace(mesh, DP0(), (m, d))
    Vdg = FeSpace(mesh, DP0(), (1, d))
    # tdata will be something like nabla(fw)
    T(tdata, u) = tdata * u
    T(::typeof(adjoint), tdata, v) = tdata' * v
    S(u, nablau) = nablau
    est = FeFunction(Vest, name = "est")
    g = FeFunction(Vg, name = "g")
    u = FeFunction(Vu, name = "u")
    p1 = FeFunction(Vp1, name = "p1")
    p2 = FeFunction(Vp2, name = "p2")
    du = FeFunction(Vu; name = "du")
    dp1 = FeFunction(Vp1; name = "dp1")
    dp2 = FeFunction(Vp2; name = "dp2")
    tdata = FeFunction(Vdg, name = "tdata")
    est.data .= 0
    g.data .= 0
    u.data .= 0
    p1.data .= 0
    p2.data .= 0
    du.data .= 0
    dp1.data .= 0
    dp2.data .= 0
    # zero `tdata` as well, so the whole state is initialized
    tdata.data .= 0
    return L1L2TVState{typeof(mesh),m,typeof(T),typeof(S)}(
        mesh, T, tdata, S,
        alpha1, alpha2, beta, lambda, gamma1, gamma2,
        est, g, u, p1, p2, du, dp1, dp2)
end
"Return the number of components `m` of the primal variable `u` of a state."
ndims_u_codomain(::L1L2TVState{<:Any,m}) where {m} = m
"""
    p1_ref_l1_expr(ldofs)

Calculate the L1 norm of a P1 function on the reference triangle from its
three Lagrange dofs `ldofs = (u0, u1, u2)`.

The denominator is the expanded form of
`6 * (u0 - u1) * (u0 - u2) * (u1 - u2)`. The expression is therefore
undefined (0/0) whenever two dofs coincide.
"""
function p1_ref_l1_expr(ldofs)
    u0, u1, u2 = ldofs
    d01 = u0 - u1
    d02 = u0 - u2
    d12 = u1 - u2
    numerator = d01 * abs(u2) * u2 ^ 2 -
        (d02 * abs(u1) * u1 ^ 2 - d12 * abs(u0) * u0 ^ 2)
    denominator = 6 * ((u0 ^ 2 + u1 * u2) * d12 - (u1 ^ 2 - u2 ^ 2) * u0)
    return numerator / denominator
end
"""
    p1_project!(p1, alpha1)

Project the dual variable `p1` onto `[-alpha1, alpha1]` by clamping its dofs
in place, dispatching on the element of `p1.space`. Only `DP0`, `P1` and
`DP1` elements are supported; other elements throw `ArgumentError`.
"""
p1_project!(p1, alpha1) = p1_project!(p1, alpha1, p1.space.element)
p1_project!(p1, alpha1, _) = throw(ArgumentError("element unsupported"))
# FIXME: probably only correct for DP0?
# TODO: finish correct discrete projection?!
function p1_project!(p1, alpha1, ::Union{DP0, P1, DP1})
    return clamp!(p1.data, -alpha1, alpha1)
end
#function p1_project!(p1, alpha1, ::Union{P1, DP1})
#    mesh = p1.space.mesh
#    for cell in cells(mesh)
#        bind!(p1, cell)
#        F = elmap(mesh, cell)
##        ForwardDiff.gradient(p1_ref_l1_expr, p1.ldata)
#
#    end
#end
"""
    p2_project!(p2, lambda)

Project the dual variable `p2` pointwise onto the Euclidean ball of radius
`lambda`. `p2.data` is viewed as columns of length `prod(p2.space.size)`;
every column with norm above `lambda` is rescaled to norm `lambda` in place.
Throws `ArgumentError` for elements other than `DP0`, `DP1` and `P1`.
"""
function p2_project!(p2, lambda)
    element = p2.space.element
    if !(element == DP0() || element == DP1() || element == P1())
        throw(ArgumentError("element unsupported"))
    end
    # one column per dof location; reshape shares memory with `p2.data`
    columns = reshape(p2.data, prod(p2.space.size), :)
    for col in eachcol(columns)
        colnorm = norm(col)
        colnorm > lambda && (col .*= lambda ./ colnorm)
    end
    return nothing
end
"""
    step!(st::L1L2TVState)

Perform one semismooth Newton step on the primal-dual system of `st`.

1. The primal increment `du` is obtained by assembling and solving a linear
   system.
2. The dual increments `dp1` and `dp2` are interpolated from closed-form
   update formulas.
3. All increments are applied with step length `theta = 1`.
4. `p1` is clamped to `[-alpha1, alpha1]` and `p2` is projected onto the
   ball of radius `lambda`.

Mutates `st.u`, `st.p1`, `st.p2`, `st.du`, `st.dp1` and `st.dp2`.
"""
function step!(st::L1L2TVState)
    T = st.T
    S = st.S
    alpha1 = st.alpha1
    alpha2 = st.alpha2
    beta = st.beta
    lambda = st.lambda
    gamma1 = st.gamma1
    gamma2 = st.gamma2
    # bilinear form of the linearized system for du; `cond1` / `cond2` are
    # only active where `|T u - g| > gamma1` resp. `|nabla u| > gamma2`
    function du_a(x_, du, nabladu, phi, nablaphi; g, u, nablau, p1, p2, tdata)
        m1 = max(gamma1, norm(T(tdata, u) - g))
        cond1 = norm(T(tdata, u) - g) > gamma1 ?
            dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 :
            zero(p1)
        a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) -
            dot(cond1, T(tdata, phi))
        m2 = max(gamma2, norm(nablau))
        cond2 = norm(nablau) > gamma2 ?
            dot(nablau, nabladu) / norm(nablau)^2 * p2 :
            zero(p2)
        a2 = lambda / m2 * dot(nabladu, nablaphi) -
            dot(cond2, nablaphi)
        aB = alpha2 * dot(T(tdata, du), T(tdata, phi)) +
            beta * dot(S(du, nabladu), S(phi, nablaphi))
        return a1 + a2 + aB
    end
    # right hand side of the linearized system for du
    function du_l(x_, phi, nablaphi; g, u, nablau, p1, p2, tdata)
        aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) +
            beta * dot(S(u, nablau), S(phi, nablaphi))
        m1 = max(gamma1, norm(T(tdata, u) - g))
        p1part = alpha1 / m1 * dot(T(tdata, u) - g, T(tdata, phi))
        m2 = max(gamma2, norm(nablau))
        p2part = lambda / m2 * dot(nablau, nablaphi)
        gpart = alpha2 * dot(g, T(tdata, phi))
        return -aB - p1part - p2part + gpart
    end
    # solve du
    print("assemble ... ")
    A, b = assemble(st.du.space, du_a, du_l;
        st.g, st.u, nablau = nabla(st.u), st.p1, st.p2, st.tdata)
    print("solve ... ")
    st.du.data .= A \ b
    # solve dp1
    function dp1_update(x_; g, u, p1, du, tdata)
        m1 = max(gamma1, norm(T(tdata, u) - g))
        cond = norm(T(tdata, u) - g) > gamma1 ?
            dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 :
            zero(p1)
        return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond
    end
    interpolate!(st.dp1, dp1_update;
        st.g, st.u, st.p1, st.du, st.tdata)
    # solve dp2
    function dp2_update(x_; u, nablau, p2, du, nabladu)
        m2 = max(gamma2, norm(nablau))
        cond = norm(nablau) > gamma2 ?
            dot(nablau, nabladu) / norm(nablau)^2 * p2 :
            zero(p2)
        return -p2 + lambda / m2 * (nablau + nabladu) - cond
    end
    interpolate!(st.dp2, dp2_update;
        st.u, nablau = nabla(st.u), st.p2, st.du, nabladu = nabla(st.du))
    # newton update
    theta = 1.
    st.u.data .+= theta * st.du.data
    st.p1.data .+= theta * st.dp1.data
    st.p2.data .+= theta * st.dp2.data
    # reproject p1, p2
    # FIXME: the p1 projection destroys the primal reconstruction
    p1_project!(st.p1, st.alpha1)
    p2_project!(st.p2, st.lambda)
end
"
    step_pd!(st::L1L2TVState; sigma, tau, theta = 1.)

One iteration of a primal-dual semi-implicit algorithm for scalar `u`
(requires `lambda = 1` and `beta = 0`; `gamma1`, `gamma2` are ignored and
`T` is treated as the identity).

1. Dual step at `u + theta * du`: `p2 <- proj(p2 + sigma * nabla(u + theta * du))`,
   where `proj` projects onto the ball of radius `lambda`.
2. Primal step: compute the L2 projection `z` of
   `(u + tau * alpha2 * g + tau * div(p2)) / (1 + tau * alpha2)` (weak form),
   then shrink `z` towards `g` by `tau * alpha1 / (1 + tau * alpha2)` dofwise.
3. Store the primal increment in `st.du`.

Throws an error for unsupported parameters or if NaNs appear. Returns `st`.

2010, Chambolle and Pock: primal-dual semi-implicit algorithm
2017, Alkämper and Langer: fem dualisation
"
function step_pd!(st::L1L2TVState; sigma, tau, theta = 1.)
    # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1!
    # changed alpha2 -> alpha2 / 2
    # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
    if ndims_u_codomain(st) != 1 || st.lambda != 1. || st.beta != 0.
	error("unsupported parameters")
    end
    # shrinkage threshold of the L1 data term in the primal step (shadows
    # nothing: `st.beta` is required to be zero above)
    beta = tau * st.alpha1 / (1 + tau * st.alpha2)
    # u is P1
    # p2 is essentially DP0 (technically may be DP1)
    # 1. y update
    # extrapolated primal point u + theta * du; `u_new` is reused below to
    # hold the new primal iterate
    u_new = FeFunction(st.u.space) # x_bar only used here
    u_new.data .= st.u.data .+ theta .* st.du.data
    p2_next = FeFunction(st.p2.space)
    function p2_update(x_; p2, nablau)
	return p2 + sigma * nablau
    end
    interpolate!(p2_next, p2_update; st.p2, nablau = nabla(u_new))
    p2_project!(p2_next, st.lambda)
    st.dp2.data .= p2_next.data .- st.p2.data
    st.p2.data .+= st.dp2.data
    # 2.
    # mass matrix (L2 projection) for z
    u_a(x, z, nablaz, phi, nablaphi; g, u, p2) =
	dot(z, phi)
    # div(p2) enters weakly as -<p2, nabla phi>
    u_l(x, phi, nablaphi; u, g, p2) =	(dot(u + tau * st.alpha2 * g, phi) + tau * dot(p2, -nablaphi)) /
	    (1 + tau * st.alpha2)
    # z = 1 / (1 + tau * alpha2) *
    #   (u + tau * alpha2 * g + tau * div(p))
    z = FeFunction(st.u.space)
    A, b = assemble(z.space, u_a, u_l; st.g, st.u, st.p2)
    z.data .= A \ b
    # dofwise shrinkage of z towards g by `beta`; the type assertions ensure
    # dofs of u, z and g correspond (P1 only)
    function u_update!(u, z, g, beta)
	u.space.element::P1
	g.space.element::P1
	for i in eachindex(u.data)
	    if z.data[i] - beta >= g.data[i]
		u.data[i] = z.data[i] - beta
	    elseif z.data[i] + beta <= g.data[i]
		u.data[i] = z.data[i] + beta
	    else
		u.data[i] = g.data[i]
	    end
	end
    end
    u_update!(u_new, z, st.g, beta)
    # 3.
    # note: step-size control not implemented, since \nabla G^* is not 1/gamma continuous
    st.du.data .= u_new.data .- st.u.data
    st.u.data .= u_new.data
    #st.du.data .= z.data
    any(isnan, st.u.data) && error("encountered nan data")
    any(isnan, st.p2.data) && error("encountered nan data")
    return st
end
"""
    step_pd2!(st::L1L2TVState; sigma, tau, theta = 1.)

One iteration of the primal-dual semi-implicit algorithm for the general
model of `st`.

1. Dual steps, both at the extrapolated primal point `u_bar = u + theta * du`:
   - `p1 <- proj((p1 + sigma * (T u_bar - g)) / (1 + gamma1 * sigma / alpha1))`
   - `p2 <- proj((p2 + sigma * nabla u_bar) / (1 + gamma2 * sigma / lambda))`

   `proj` clamps to `|p1| <= alpha1` resp. `|p2| <= lambda`. A dual
   variable is set to zero when its weight (`alpha1`, `lambda`) is zero.
2. Primal step: solve for the new iterate `w`

       <w, phi> + tau * (alpha2 <T w, T phi> + beta <S w, S phi>) =
           <u, phi> - tau * (<p1, T phi> + <p2, nabla phi> - alpha2 <g, T phi>)

The step sizes should satisfy `sigma * tau * L^2 <= 1` with `L = |grad|`.
Any step-size adaptation is left to the caller. The increments are stored
in `st.du`, `st.dp1` and `st.dp2`. Throws an error if NaNs appear.
Returns `st`.

2010, Chambolle and Pock: accelerated primal-dual semi-implicit algorithm
"""
function step_pd2!(st::L1L2TVState; sigma, tau, theta = 1.)
    # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
    # u is P1
    # p2 is essentially DP0 (technically may be DP1)
    # 1. y update at the extrapolated point; `u_new` later holds the new
    #    primal iterate
    u_new = FeFunction(st.u.space)
    u_new.data .= st.u.data .+ theta .* st.du.data
    p1_next = FeFunction(st.p1.space)
    p2_next = FeFunction(st.p2.space)
    function p1_update(x_; p1, g, u, tdata)
        st.alpha1 == 0 && return zero(p1)
        return (p1 + sigma * (st.T(tdata, u) - g)) /
            (1 + st.gamma1 * sigma / st.alpha1)
    end
    # like p2, p1 is evaluated at the extrapolated primal point `u_new`
    interpolate!(p1_next, p1_update; st.p1, st.g, u = u_new, st.tdata)
    function p2_update(x_; p2, nablau)
        st.lambda == 0 && return zero(p2)
        return (p2 + sigma * nablau) /
            (1 + st.gamma2 * sigma / st.lambda)
    end
    interpolate!(p2_next, p2_update; st.p2, nablau = nabla(u_new))
    # reproject p1, p2
    p1_project!(p1_next, st.alpha1)
    p2_project!(p2_next, st.lambda)
    st.dp1.data .= p1_next.data .- st.p1.data
    st.dp2.data .= p2_next.data .- st.p2.data
    st.p1.data .+= st.dp1.data
    st.p2.data .+= st.dp2.data
    # 2. x update: (I + tau * B) w = u - tau * (T^* p1 + nabla^* p2 - alpha2 * T^* g)
    u_a(x, w, nablaw, phi, nablaphi; g, u, p1, p2, tdata) =
        dot(w, phi) +
        tau * st.alpha2 * dot(st.T(tdata, w), st.T(tdata, phi)) +
        tau * st.beta * dot(st.S(w, nablaw), st.S(phi, nablaphi))
    u_l(x, phi, nablaphi; g, u, p1, p2, tdata) =
        dot(u, phi) - tau * (
            dot(p1, st.T(tdata, phi)) +
            dot(p2, nablaphi) -
            st.alpha2 * dot(g, st.T(tdata, phi)))
    A, b = assemble(u_new.space, u_a, u_l; st.g, st.u, st.p1, st.p2, st.tdata)
    u_new.data .= A \ b
    st.du.data .= u_new.data .- st.u.data
    st.u.data .= u_new.data
    any(isnan, st.u.data) && error("encountered nan data")
    any(isnan, st.p1.data) && error("encountered nan data")
    any(isnan, st.p2.data) && error("encountered nan data")
    return st
end
"""
    step_d!(st::L1L2TVState; tau)

Placeholder for the dual semi-implicit algorithm; currently a no-op that
returns `st` unchanged.

2004, Chambolle: dual semi-implicit algorithm
"""
function step_d!(st::L1L2TVState; tau)
    # u is P1, p2 is essentially DP0 (technically may be DP1).
    # TODO: this might not be implementable without higher order elements,
    # since it needs grad(div(p)).
    return st
end
"""
    solve_primal!(u::FeFunction, st::L1L2TVState)

Reconstruct the primal variable from the dual variables of `st` by solving

    <B u, phi> = -<p1, T phi> - <p2, nabla phi> + alpha2 * <g, T phi>

for all test functions `phi`, where
`<B u, phi> = alpha2 * <T u, T phi> + beta * <S u, S phi>`. This is
`u = -B^{-1} * (T^* p1 + nabla^* p2 - alpha2 * T^* g)`. The solution is
written into `u.data`.
"""
function solve_primal!(u::FeFunction, st::L1L2TVState)
    alpha2, beta, T, S = st.alpha2, st.beta, st.T, st.S
    bform(x, v, nablav, phi, nablaphi; g, p1, p2, tdata) =
        alpha2 * dot(T(tdata, v), T(tdata, phi)) +
            beta * dot(S(v, nablav), S(phi, nablaphi))
    lform(x, phi, nablaphi; g, p1, p2, tdata) =
        -dot(p1, T(tdata, phi)) - dot(p2, nablaphi) +
            alpha2 * dot(g, T(tdata, phi))
    A, b = assemble(u.space, bform, lform; st.g, st.p1, st.p2, st.tdata)
    u.data .= A \ b
end
"""
    huber(x, gamma)

Huber function: `x^2 / (2 * gamma)` for `abs(x) < gamma`, otherwise
`abs(x) - gamma / 2`. The two branches agree at `abs(x) == gamma`.
"""
function huber(x, gamma)
    ax = abs(x)
    if ax < gamma
        return x^2 / (2 * gamma)
    end
    return ax - gamma / 2
end
# TODO: finish!
"""
    refine_and_estimate_pd(st::L1L2TVState)

Unfinished. This function:

1. refines every cell of `st.mesh`,
2. moves the discrete functions of `st` onto the refined mesh,
3. computes the primal-dual error indicators there.

Nothing is transferred back to `st` yet. The function returns the value of
`estimate_pd!` on the refined state.
"""
function refine_and_estimate_pd(st::L1L2TVState)
    # globally refine
    allcells = Set(axes(st.mesh.cells, 2))
    fine_mesh, fine = refine(st.mesh, allcells;
        st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
    # NOTE(review): no `tdata` is passed here; verify the copy constructor
    # accepts that for this state.
    fine_st = L1L2TVState(st; mesh = fine_mesh,
        fine.est, fine.g, fine.u, fine.p1, fine.p2,
        fine.du, fine.dp1, fine.dp2)
    # compute primal dual error indicators on the finer mesh
    # TODO: transfer data to old state
    return estimate_pd!(fine_st)
end
# this computes the primal-dual error indicator which is not really useful
# if not computed on a finer mesh than `u` was solved on and has problems with
# alpha1 > 0
#
# The cellwise indicator is the square root of the cell integral of `estf`:
# a gap term for the alpha1 part, a gap term for the lambda part and
# 1/2 * ||w - u||_B^2, where `w` is the primal reconstruction from the current
# duals (`solve_primal!`). The result is written into `st.est.data`.
function estimate_pd!(st::L1L2TVState)
    function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
        # each part has the form F(v) + F^*(p) - <v, p>; presumably a
        # Fenchel-Young gap, nonnegative in exact arithmetic
	alpha1part =
            st.alpha1 * huber(norm(st.T(tdata, u) - g), st.gamma1) -
	    dot(st.T(tdata, u) - g, p1) +
            (iszero(st.alpha1) ? 0. :
                st.gamma1 / (2 * st.alpha1) * norm(p1)^2)
	lambdapart =
            st.lambda * huber(norm(nablau), st.gamma2) -
	    dot(nablau, p2) +
            (iszero(st.lambda) ? 0. :
                st.gamma2 / (2 * st.lambda) * norm(p2)^2)
        # clamp small negative values caused by rounding errors
        alpha1part = max(0, alpha1part)
        lambdapart = max(0, lambdapart)
        # 1/2 * ||w - u||_B^2 with B = alpha2 * T^* T + beta * S^* S
	bpart = 1 / 2 * (
	    st.alpha2 * dot(st.T(tdata, w - u), st.T(tdata, w - u)) +
	    st.beta * dot(st.S(w, nablaw) - st.S(u, nablau),
                st.S(w, nablaw) - st.S(u, nablau)))
        res = alpha1part + lambdapart + bpart
        @assert isfinite(res)
	return res
    end
    # primal reconstruction from the current duals
    w = FeFunction(st.u.space)
    solve_primal!(w, st)
    #w.data .= .-w.data
    # TODO: find better name: is actually a cell-wise integration
    project!(st.est, estf; st.g, st.u, st.p1, st.p2,
        nablau = nabla(st.u), w, nablaw = nabla(w), st.tdata)
    # squared indicators -> indicators
    st.est.data .= sqrt.(st.est.data)
end
# this computes the residual error indicator which has missing theoretical
# support for alpha1 > 0
# TODO: finish!
"""
    estimate_res!(st::L1L2TVState)

Compute residual-based cellwise error indicators into `st.est.data`.

- Cell term: the strong residual `alpha2 * T^*(T u - g) + T^* p1`, evaluated
  at the cell centroid. This is consistent with the optimality condition
  `B u + T^* p1 + nabla^* p2 - alpha2 * T^* g = 0` used by `solve_primal!`.
- Facet term: the summed normal contributions of `beta * S(0, nabla u) + p2`
  from both adjacent cells.

Requires `st.T` to support `T(adjoint, tdata, v)`. Only 2d triangles are
handled.
"""
function estimate_res!(st::L1L2TVState)
    # FIXME: dirty hack
    isSnabla = st.S(0, 1) == 1
    norm2(v) = dot(v, v)
    # strong cell residual; second order terms vanish for P1 `u` and DP0 `p2`
    # NOTE(review): for S = identity the `beta * u` contribution is missing
    cellf(hcell; g, u, nablau, p1, p2, tdata) =
        st.alpha2 * st.T(adjoint, tdata, st.T(tdata, u) - g) +
            st.T(adjoint, tdata, p1)
    # NOTE(review): for an m×d `p2`, `n' * M` contracts over the first
    # (component) index; a normal flux is usually `M * n` - verify
    facetf(hfacet, n; g, u, nablau, p1, p2, tdata) =
        # norm2 is calculated later after both facet contributions are
        # consolidated
        n' * (st.beta * st.S(zero(u), nablau) + p2)
    # manual method
    mesh = st.mesh
    fs = (;st.g, st.u, nablau = nabla(st.u), st.p1, st.p2, st.tdata)
    space = st.est.space
    facetV = Dict{Tuple{Int, Int}, MVector{ndims_u_codomain(st), Float64}}()
    # 2d rotation by +90 degrees, used to obtain facet normals
    cross(v) = SA[-v[2], v[1]]
    for cell in cells(mesh)
        A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}(
            view(mesh.vertices, :, view(mesh.cells, :, cell)))
        for f in fs
            bind!(f, cell)
        end
        # cell contribution
        hcell = diam(mesh, cell)
        delmap = jacobian(elmap(mesh, cell), SA[0., 0.])
        intel = abs(det(delmap))
        lcentroid = SA[1/3, 1/3]
        opvalues = map(f -> evaluate(f, lcentroid), fs)
        cellres = (isSnabla ? hcell^2 : hcell) *
            norm2(cellf(hcell; opvalues...)) .* intel
        gdofs = space.dofmap[:, 1, cell]
        st.est.data[gdofs] .= cellres
        # facet contributions
        for (i, j) in ((i, mod1(i + 1, 3)) for i in 1:3)
            fi = mesh.cells[i, cell]
            fj = mesh.cells[j, cell]
            v = A[:, j] - A[:, i]
            hfacet = norm(v)
            normal = -normalize(cross(v))
            intel = hfacet
            # we don't need to recompute values since only piecewise
            # constant data is used in `facetf`
            # we use sqrt here since this value will be squared later
            facetres = (isSnabla ? sqrt(hfacet) : inv(sqrt(hfacet))) *
                facetf(hfacet, normal; opvalues...) .* sqrt(intel)
            # accumulate the contributions of both adjacent cells (for
            # consistently oriented cells the normals are opposite, giving a
            # jump on interior facets)
            v = SVector(facetres)
            facetV[(fi, fj)] = get(facetV, (fi, fj), zero(v)) + SVector(v)
            facetV[(fj, fi)] = get(facetV, (fj, fi), zero(v)) + SVector(v)
        end
    end
    # add facet contributions to cells
    for cell in cells(mesh)
        for (i, j) in ((i, mod1(i + 1, 3)) for i in 1:3)
            fi = mesh.cells[i, cell]
            fj = mesh.cells[j, cell]
            gdofs = space.dofmap[:, 1, cell]
            st.est.data[gdofs] .+= norm2(facetV[(fi, fj)])
        end
    end
    st.est.data .= sqrt.(st.est.data)
end
"""
    estimate_error(st::L1L2TVState)

Return the global error estimate: the Euclidean norm of the cell indicators
`st.est.data`, divided by the square root of the domain area.
"""
function estimate_error(st::L1L2TVState)
    total = sum(abs2, st.est.data)
    return sqrt(total / area(st.mesh))
end
# minimal Dörfler marking
"""
    mark(st::L1L2TVState; theta=0.5)

Return the indices of a minimal set of cells whose indicators `st.est.data`
sum to at least `theta` times the total. Cells are picked greedily in order
of decreasing indicator.

Note: the indicators are summed as stored, not squared.
"""
function mark(st::L1L2TVState; theta=0.5)
    threshold = theta * sum(st.est.data)
    # (cell => indicator) pairs, largest indicator first (stable for ties)
    cellerrors_sorted = sort!(collect(pairs(st.est.data)); by = last, rev = true)
    marked_cells = Int[]
    estacc = 0.
    for (cell, cellerror) in cellerrors_sorted
        estacc >= threshold && break
        push!(marked_cells, cell)
        estacc += cellerror
    end
    return marked_cells
end
"""
    output(st::L1L2TVState, filename, fs...)

Write the mesh of `st` and the discrete functions `fs` to the VTK file
`filename`. Returns the VTK file object.
"""
function output(st::L1L2TVState, filename, fs...)
    println("save \"$filename\" ... ")
    vtkfile = vtk_mesh(filename, st.mesh)
    vtk_append!(vtkfile, fs...)
    vtk_save(vtkfile)
    return vtkfile
end
"""
    primal_energy(st::L1L2TVState)

Return the primal energy of the current iterate `st.u`:

    ∫ alpha1 * huber(|T u - g|, gamma1) + alpha2 / 2 * |T u - g|^2 +
        beta / 2 * |S u|^2 + lambda * huber(|nabla u|, gamma2)
"""
function primal_energy(st::L1L2TVState)
    T, S = st.T, st.S
    function density(x; g, u, nablau, tdata)
        misfit = norm(T(tdata, u) - g)
        return st.alpha1 * huber(misfit, st.gamma1) +
            st.alpha2 / 2 * misfit^2 +
            st.beta / 2 * norm(S(u, nablau))^2 +
            st.lambda * huber(norm(nablau), st.gamma2)
    end
    return integrate(st.mesh, density; st.g, st.u,
        nablau = nabla(st.u), st.tdata)
end
"""
    dual_energy(st::L1L2TVState)

Return the dual energy at the current duals `p1`, `p2` of `st`. It uses the
primal reconstruction `w` computed by `solve_primal!`:

    ∫ 1/2 * |w|_B^2 - alpha2 / 2 * |g|^2 + <g, p1> +
        gamma1 / (2 * alpha1) * |p1|^2 + gamma2 / (2 * lambda) * |p2|^2

Here `|w|_B^2 = alpha2 * |T w|^2 + beta * |S w|^2`. A `gamma` term is
dropped when its weight `alpha1` resp. `lambda` is zero.
"""
function dual_energy(st::L1L2TVState)
    # primal reconstruction
    w = FeFunction(st.u.space)
    solve_primal!(w, st)
    function integrand(x; g, tdata, w, nablaw, p1, p2)
        wB2 = st.alpha2 * dot(st.T(tdata, w), st.T(tdata, w)) +
            st.beta * dot(st.S(w, nablaw), st.S(w, nablaw))
        # float zeros keep both branches (and the integrand) type-stable
        p1part = iszero(st.alpha1) ? 0. :
            st.gamma1 / 2 / st.alpha1 * dot(p1, p1)
        p2part = iszero(st.lambda) ? 0. :
            st.gamma2 / 2 / st.lambda * dot(p2, p2)
        return 1 / 2 * wB2 - st.alpha2 / 2 * dot(g, g) + dot(g, p1) +
            p1part + p2part
    end
    return integrate(st.mesh, integrand; st.g, st.tdata,
        w, nablaw = nabla(w), st.p1, st.p2)
end
"""
    norm_l2(f)
    norm_l2(f::Function, mesh; params...)

Return the L2 norm of the discrete function `f` over its mesh. The second
form returns the L2 norm over `mesh` of the pointwise expression
`f(x; params...)`.
"""
function norm_l2(f)
    value_of(x; f) = f
    return norm_l2(value_of, f.space.mesh; f)
end
function norm_l2(f::Function, mesh; params...)
    function squared(x; kw...)
        fx = f(x; kw...)
        return dot(fx, fx)
    end
    return sqrt(integrate(mesh, squared; params...))
end
"""
    norm_step(st::L1L2TVState)

Return the combined L2 norm of the last increments `du`, `dp1` and `dp2`,
divided by the square root of the domain area.
"""
function norm_step(st::L1L2TVState)
    squares = norm_l2(st.du)^2 + norm_l2(st.dp1)^2 + norm_l2(st.dp2)^2
    return sqrt(squares / area(st.mesh))
end
"""
    norm_residual(st::L1L2TVState)

Return the residual norm of the optimality system at the current iterate:

    sqrt(|w - u|^2 + ∫ |p1 * max(gamma1, |T u - g|) - alpha1 * (T u - g)|^2
        + |p2 * max(gamma2, |nabla u|) - lambda * nabla u|^2)

where `w` is the primal reconstruction computed by `solve_primal!` from the
current duals.
"""
function norm_residual(st::L1L2TVState)
    # primal part: distance of u to the primal reconstruction
    w = FeFunction(st.u.space)
    solve_primal!(w, st)
    w.data .-= st.u.data
    upart2 = norm_l2(w)^2
    # dual part: residuals of the dual update equations
    function integrand(x; g, u, nablau, p1, p2, tdata)
        p1part = p1 * max(st.gamma1, norm(st.T(tdata, u) - g)) -
            st.alpha1 * (st.T(tdata, u) - g)
        p2part = p2 * max(st.gamma2, norm(nablau)) -
            st.lambda * nablau
        return norm(p1part)^2 + norm(p2part)^2
    end
    ppart2 = integrate(st.mesh, integrand; st.g, st.u,
        nablau = nabla(st.u), st.p1, st.p2, st.tdata)
    return sqrt(upart2 + ppart2)
end
"""
    denoise(img; params...)

Legacy adaptive denoising driver.

- Converts `img` to cartesian coordinates and starts on the grid from
  `init_grid(img, 5, 5)`.
- Runs semismooth Newton steps until `norm_step <= 0.1`, then marks
  (Dörfler) and refines.
- Stops after the first refinement at which the total number of Newton
  steps is at least 100.
- Writes VTK output to `output/`.

Returns the final state.
"""
function denoise(img; params...)
    m = 1
    img = from_img(img) # coord flip
    #mesh = init_grid(img; type=:vertex)
    mesh = init_grid(img, 5, 5)
    T(tdata, u) = u
    S(u, nablau) = u
    st = L1L2TVState{m}(mesh; T, tdata = nothing, S, params...)
    project_l2_lagrange!(st.g, img)
    #interpolate!(st.g, x -> evaluate_bilinear(img, x))
    # NOTE(review): `L1L2TVState` has no `name` field, so `st.name` throws;
    # this driver needs another source for the output name.
    save_denoise(st, i) =
        output(st, "output/$(st.name)_$(lpad(i, 5, '0')).vtu",
            st.g, st.u, st.p1, st.p2, st.est)
    pvd = paraview_collection("output/$(st.name).pvd")
    pvd[0] = save_denoise(st, 0)
    k = 0
    println("primal energy: $(primal_energy(st))")
    while true
        # Newton iterations on the current mesh
        while true
            k += 1
            step!(st)
            estimate_pd!(st)
            pvd[k] = save_denoise(st, k)
            println()
            norm_step_ = norm_step(st)
            norm_residual_ = norm_residual(st)
            println("ndofs: $(ndofs(st.u.space)), est: $(norm_l2(st.est))")
            println("primal energy: $(primal_energy(st))")
            println("norm_step: $(norm_step_)")
            println("norm_residual: $(norm_residual_)")
            norm_step_ <= 1e-1 && break
        end
        marked_cells = mark(st; theta = 0.5)
        println("refining ...")
        # NOTE(review): elsewhere `refine` takes a mesh and returns
        # `(mesh, functions)`; verify this call on the state.
        st, _ = refine(st, marked_cells)
        test_mesh(st.mesh)
        project_l2_lagrange!(st.g, img)
        k >= 100 && break
    end
    vtk_save(pvd)
    return st
end
"""
    denoise_pd(ctx)

Denoise `ctx.params.g_arr` on `ctx.params.mesh` with `T = S = identity` using
the algorithm `ctx.params.algorithm`:

- `:pd1`: primal-dual steps with fixed step sizes,
- `:pd2`: accelerated primal-dual steps with adaptive `theta`, `tau`, `sigma`,
- `:newton`: semismooth Newton steps.

Each iteration logs `k`, the area-normalized `norm_step` and the primal
energy to `ctx.params.df` (skipped if it is `nothing`). The loop stops when
`norm_step < tol` or `k >= max_iters` (if those params are given). Throws
`ArgumentError` for an unknown algorithm. Returns the final state.
"""
function denoise_pd(ctx)
    params = ctx.params
    m = 1
    mesh = params.mesh
    T(tdata, u) = u
    S(u, nablau) = u
    st = L1L2TVState{m}(mesh;
        T, tdata = nothing, S,
        params.alpha1, params.alpha2, params.lambda, params.beta,
        params.gamma1, params.gamma2)
    # semi-implicit primal dual parameters
    mu = st.alpha2 + st.beta # T = I, S = I
    tau = 1e-1
    L = 100
    sigma = inv(tau * L^2)
    theta = 1.
    interpolate!(st.g, x -> evaluate_bilinear(params.g_arr, x))
    st.u.data .= st.g.data
    log!(x::Nothing; kwargs...) = x
    function log!(df::DataFrame; kwargs...)
        row = NamedTuple(kwargs)
        push!(df, row)
        println(row)
    end
    # the mesh is fixed, so the normalization is loop invariant
    domain_factor = 1 / sqrt(area(mesh))
    k = 0
    println("primal energy: $(primal_energy(st))")
    while true
        k += 1
        if params.algorithm == :pd1
            # no step size control
            step_pd2!(st; sigma, tau, theta)
        elseif params.algorithm == :pd2
            theta = 1 / sqrt(1 + 2 * mu * tau)
            tau *= theta
            sigma /= theta
            step_pd2!(st; sigma, tau, theta)
        elseif params.algorithm == :newton
            step!(st)
        else
            throw(ArgumentError("unknown algorithm $(params.algorithm)"))
        end
        norm_step_ = norm_step(st) * domain_factor
        log!(params.df; k,
            norm_step = norm_step_,
            primal_energy = primal_energy(st))
        haskey(params, :tol) && norm_step_ < params.tol && break
        haskey(params, :max_iters) && k >= params.max_iters && break
    end
    return st
end
"""
    experiment_convergence_rate(ctx)

Denoise `ctx.indir/input.png` with the semi-implicit, accelerated
semi-implicit and Newton algorithms. For each algorithm, writes the
log-filtered iteration logs (CSV, including the smallest primal energy
observed across all runs) and the result image to `ctx.outdir`. Also
writes the input image and the parameters (`data.tex`).
"""
function experiment_convergence_rate(ctx)
    img = loadimg(joinpath(ctx.indir, "input.png"))
    g_arr = from_img(img)
    mesh = init_grid(g_arr;)
    algparams = (
        alpha1=0., alpha2=30., lambda=1., beta=0.,
        gamma1=1e-2, gamma2=1e-3,
        tol = 1e-10,
        max_iters = 10000,
    )
    algctx = Util.Context(ctx; g_arr, mesh, algparams...)
    # (algorithm, csv file, image file), in execution order
    runs = (
        (:pd1, "semi-implicit.csv", "semi-implicit.png"),
        (:pd2, "semi-implicit-accelerated.csv", "semi-implicit-accelerated.png"),
        (:newton, "newton.csv", "newton.png"),
    )
    results = map(runs) do (algorithm, csvfile, pngfile)
        df = DataFrame()
        st = denoise_pd(Util.Context(algctx; algorithm, df))
        return (; csvfile, pngfile, df, st)
    end
    energy_min = minimum(minimum(r.df.primal_energy) for r in results)
    for r in results
        r.df[!, :energy_min] .= energy_min
        CSV.write(joinpath(ctx.outdir, r.csvfile), logfilter(r.df))
    end
    saveimg(joinpath(ctx.outdir, "input.png"), img)
    for r in results
        saveimg(joinpath(ctx.outdir, r.pngfile), to_img(sample(r.st.u)))
    end
    return savedata(joinpath(ctx.outdir, "data.tex");
        energy_min, algparams...)
end
# Run semismooth Newton steps (at most 50 iterations) on `ctx.params.mesh`
# with T = S = identity. Every time `norm_step` drops below `ctx.params.tol`,
# the mesh is refined (Dörfler marking if `ctx.params.adaptive`, otherwise
# uniformly) and all functions are transferred. Statistics are logged to
# `ctx.params.df` and VTK snapshots are written to `ctx.outdir`. Returns the
# final (possibly refined) state.
function denoise_approximation(ctx)
    # normalization by the square root of the (initial) domain area
    domain_factor = 1 / sqrt(area(ctx.params.mesh))
    m = 1
    T(tdata, u) = u
    S(u, nablau) = u
    #S(u, nablau) = nablau
    st = L1L2TVState{m}(ctx.params.mesh;
        T, tdata = nothing, S,
        ctx.params.alpha1, ctx.params.alpha2,
        ctx.params.lambda, ctx.params.beta,        ctx.params.gamma1, ctx.params.gamma2)
    # primal reconstruction
    w = FeFunction(st.u.space, name = "w")
    w.data .= 0
    #project_l2_lagrange!(st.g, img)
    interpolate!(st.g, x -> evaluate_bilinear(ctx.params.g_arr, x))
    st.u.data .= st.g.data
    #st.u.data .= rand(size(st.u.data))
    # `w` is captured by this closure, so reassigning `w` after refinement
    # is picked up by later snapshots
    save_vtk(st, i) =
        output(st,
            joinpath(ctx.outdir, "$(ctx.params.name)_$(lpad(i, 5, '0')).vtu"),
	    st.g, st.u, st.p1, st.p2, st.est, w)
    log!(x::Nothing; kwargs...) = x
    function log!(df::DataFrame; kwargs...)
        row = NamedTuple(kwargs)
        push!(df, row)
        println(df)
	#println(row)
    end
    pvd_path = joinpath(ctx.outdir, "$(ctx.params.name).pvd")
    # note: assignments to `st` and `w` inside this do-block update the
    # enclosing function's variables, so `return st` below sees the refined
    # state
    pvd = paraview_collection(pvd_path) do pvd
        pvd[0] = save_vtk(st, 0)
        println("primal energy: $(primal_energy(st))")
        # initial refinement
        # disabled: `1:0` is empty; raise the upper bound to enable initial
        # uniform refinement steps
        for _ in 1:0
            marked_cells = Set(axes(st.mesh.cells, 2))
            mesh, fs = refine(st.mesh, marked_cells;
                st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2, w)
            st = L1L2TVState(st; mesh,
                fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
            w = fs.w
        end
        #marked_cells = Set(axes(st.mesh.cells, 2))
        #mesh2, fs = refine(st.mesh, marked_cells;
        #    st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
        #st2 = L1L2TVState(st; mesh = mesh2,
        #    fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
        k = 0
        while true && k < 50
            step!(st)
            solve_primal!(w, st)
            norm_step_ = norm_step(st) * domain_factor
            k += 1
            norm_residual_ = norm_residual(st) * domain_factor
            estimate_pd!(st)
            pvd[k] = save_vtk(st, k)
            log!(ctx.params.df; k,
                ndofs = ndofs(st.u.space),
                hmax = mesh_size(st.mesh),
                norm_step = norm_step_, norm_residual = norm_residual_,
                primal_energy = primal_energy(st),
                dual_energy = dual_energy(st),
                est = estimate_error(st),
            )
            # converged on the current mesh: log once more, refine and log
            # the transferred state
            if haskey(ctx.params, :tol) && norm_step_ < ctx.params.tol
                k += 1
                norm_residual_ = norm_residual(st) * domain_factor
                estimate_pd!(st)
                pvd[k] = save_vtk(st, k)
                log!(ctx.params.df; k,
                    ndofs = ndofs(st.u.space),
                    hmax = mesh_size(st.mesh),
                    norm_step = norm_step_, norm_residual = norm_residual_,
                    primal_energy = primal_energy(st),
                    dual_energy = dual_energy(st),
                    est = estimate_error(st),
                )
                marked_cells = ctx.params.adaptive ?
                    mark(st; theta = 0.5) : # estimator + dörfler
                    Set(axes(st.mesh.cells, 2))
                mesh, fs = refine(st.mesh, marked_cells;
                    st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2, w)
                # NOTE(review): no `tdata` is passed to the copy constructor
                # here; verify it accepts that
                st = L1L2TVState(st; mesh,
                    fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
                w = fs.w
                norm_residual_ = norm_residual(st) * domain_factor
                estimate_pd!(st)
                # NOTE(review): same index `k` as the pre-refinement snapshot
                # above, so that snapshot is overwritten - confirm intended
                pvd[k] = save_vtk(st, k)
                log!(ctx.params.df; k,
                    ndofs = ndofs(st.u.space),
                    hmax = mesh_size(st.mesh),
                    norm_step = norm_step_, norm_residual = norm_residual_,
                    primal_energy = primal_energy(st),
                    dual_energy = dual_energy(st),
                    est = estimate_error(st),
                )
            end
        end
    end
    return st
end
"""
    experiment_approximation(ctx)

Run `denoise_approximation` with adaptive refinement on a tiny 2×2 test
image. The iteration log goes into a fresh `DataFrame`.
"""
function experiment_approximation(ctx)
    # alternative test image: from_img([1. 0.; 0. 0.])
    g_arr = from_img([0. 0.; 1. 0.])
    mesh = init_grid(g_arr;)
    df = DataFrame()
    # alternative parameter sets:
    #   alpha1 = 0., alpha2 = 30., lambda = 1., beta = 0.
    #   alpha1 = 0.5, alpha2 = 0., lambda = 0., beta = 1.
    modelparams = (
        alpha1 = 1., alpha2 = 0., lambda = 1., beta = 1e-5,
        gamma1 = 1e-3, gamma2 = 1e-3,
        tol = 1e-10, adaptive = true,
    )
    algctx = Util.Context(ctx; name = "test", df, g_arr, mesh, modelparams...)
    return denoise_approximation(algctx)
end
"""
    inpaint(img, imgmask; params...)

Legacy inpainting driver.

- `T` keeps `u` where the mask is one and is zero elsewhere.
- The data `g` holds the image values inside the mask and zero outside.
- Performs 3 semismooth Newton steps and writes VTK output to `output/`.

Throws `ArgumentError` if `img` and `imgmask` differ in size. Returns the
final state.
"""
function inpaint(img, imgmask; params...)
    size(img) == size(imgmask) ||
        throw(ArgumentError("non-matching dimensions"))
    m = 1
    img = from_img(img) # coord flip
    imgmask = from_img(imgmask) # coord flip
    mesh = init_grid(img; type=:vertex)
    # inpaint specific stuff
    Vg = FeSpace(mesh, P1(), (1,))
    mask = FeFunction(Vg, name = "mask")
    # restrict to the known region
    T(tdata, u) = isone(tdata[begin]) ? u : zero(u)
    S(u, nablau) = u
    st = L1L2TVState{m}(mesh; T, tdata = mask, S, params...)
    # FIXME: currently dual grid only
    interpolate!(mask, x -> imgmask[round.(Int, x)...])
    # image data inside the mask, zero elsewhere
    interpolate!(st.g, x -> imgmask[round.(Int, x)...] ? img[round.(Int, x)...] : 0.)
    # NOTE(review): `L1L2TVState` has no `name` field, so `st.name` throws;
    # this driver needs another source for the output name.
    save_inpaint(i) =
        output(st, "output/$(st.name)_$(lpad(i, 5, '0')).vtu",
            st.g, st.u, st.p1, st.p2, st.est, mask)
    pvd = paraview_collection("output/$(st.name).pvd")
    pvd[0] = save_inpaint(0)
    for i in 1:3
        step!(st)
        estimate_pd!(st)
        pvd[i] = save_inpaint(i)
        println()
    end
    return st
end
"""
    optflow(ctx)

Estimate the optical flow from `ctx.params.imgf0` to `ctx.params.imgf1` on an
adaptively refined finite-element mesh, using image warping.

The main loop alternates three phases:

1. Inner Newton steps (`step!`) run until the step norm, scaled by
   `1 / sqrt(area(mesh))`, drops below `eps_newton` or 10 steps were taken.
2. `imgf1` is re-warped with the current flow and the data `g` and operator data
   `tdata` are recomputed. Newton/warping is repeated while the relative change
   of the `f0`–`fw` distance is below `-eps_warp`.
3. The mesh is refined adaptively (`mark` with `theta = 0.5`), at most
   `n_refine` times.

Parameters are read from `ctx.params`: `imgf0`, `imgf1`, `alpha1`, `alpha2`,
`lambda`, `beta`, `gamma1`, `gamma2` and `maxflow`. VTU/PVD snapshots, PNG
images and `data.tex` are written to `ctx.outdir`. Returns the final state.

Throws an `ArgumentError` if the two images differ in size.
"""
function optflow(ctx)
    size(ctx.params.imgf0) == size(ctx.params.imgf1) ||
        throw(ArgumentError("non-matching image sizes"))
    project_image! = project_l2_lagrange!
    eps_newton = 1e-3 # cauchy criterion for inner newton loop
    eps_warp = 0.05 # keep warping while the misfit drops by more than this
    n_refine = 6 # maximal number of adaptive refinements
    # convert to cartesian coordinates
    imgf0 = from_img(ctx.params.imgf0)
    imgf1 = from_img(ctx.params.imgf1)
    mesh = init_grid(imgf0, floor.(Int, size(imgf0) ./ 2^(n_refine / 2))...)
    mesh_area = area(mesh)
    # optflow specific stuff
    Vg = FeSpace(mesh, P1(), (1,))
    f0 = FeFunction(Vg, name = "f0")
    f1 = FeFunction(Vg, name = "f1")
    fw = FeFunction(Vg, name = "fw")
    st = OptFlowState(mesh;
        ctx.params.alpha1, ctx.params.alpha2,
        ctx.params.lambda, ctx.params.beta,
        ctx.params.gamma1, ctx.params.gamma2)
    # `st`, `mesh`, `f0`, `f1` and `fw` are rebound after every refinement.
    # The helper closures below deliberately capture these bindings (not their
    # current values), so they always act on the latest mesh and state.
    function warp!()
        println("warp and reproject ...")
        # warp image into imgfw / fw
        imgfw = warp_backwards(imgf1, sample(st.u))
        project_image!(fw, imgfw)
        # recompute optflow operator T based on u0 and fw
        function tdata_optflow(x; u0_deriv, nablafw)
            #res = nablafw / (I + u0_deriv')
            res = nablafw
            # `DivideError` takes no message (calling it with one is a
            # MethodError); report the offending value via `DomainError`
            all(isfinite, res) ||
                throw(DomainError(res, "singular optflow matrix"))
            return res
        end
        interpolate!(st.tdata, tdata_optflow;
            u0_deriv = nabla(st.u), nablafw = nabla(fw))
        # recompute optflow data g
        # note that it is important here that the space of g is general enough
        # to avoid any information loss that could screw up image warping
        g_optflow(x; u0, f0, fw, tdata) =
            st.T(tdata, u0) - (fw - f0)
        interpolate!(st.g, g_optflow; u0 = st.u, f0, fw, st.tdata)
    end
    function interpolate_image_data!()
        println("interpolate image data ...")
        project_image!(f0, imgf0)
        project_image!(f1, imgf1)
    end
    # `norm_l2` of `f0 - fw` over the current mesh
    calc_fdist() = norm_l2((x; f0, fw) -> f0 - fw, mesh; f0, fw)
    save_step(i) =
        output(st, joinpath(ctx.outdir, "output_$(lpad(i, 5, '0')).vtu"),
            st.g, st.u, st.p1, st.p2, st.est, f0, f1, fw)
    paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd
        interpolate_image_data!()
        warp!() # in the first step just to fill st.g
        pvd[0] = save_step(0)
        i = 0
        k_newton = 0
        k_refine = 0
        while true
            # interior newton
            k_newton += 1
            step!(st)
            norm_step_ = norm_step(st) / sqrt(mesh_area)
            println("norm_step = $norm_step_")
            # interior newton stop criterion
            norm_step_ > eps_newton && k_newton < 10 && continue
            k_newton = 0
            # plot
            i += 1
            display(plot(colorflow(to_img(sample(st.u)); ctx.params.maxflow)))
            pvd[i] = save_step(i)
            # warp
            fdist0 = calc_fdist()
            warp!()
            fdist1 = calc_fdist()
            rel_datachange = (fdist1 - fdist0) / fdist0
            println("rel data change: $(rel_datachange)")
            # warping stop criterion
            rel_datachange < -eps_warp && continue
            # refinement stop criterion
            k_refine += 1
            k_refine > n_refine && break
            println("refine ...")
            estimate_res!(st)
            marked_cells = mark(st; theta = 0.5)
            #marked_cells = Set(axes(mesh.cells, 2))
            mesh, fs = refine(mesh, marked_cells;
                st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2,
                st.tdata, f0, f1, fw)
            st = L1L2TVState(st; mesh, fs.tdata,
                fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
            f0, f1, fw = (fs.f0, fs.f1, fs.fw)
            interpolate_image_data!()
        end
    end
    #CSV.write(joinpath(ctx.outdir, "energies.csv"), df)
    u_sampled = sample(st.u)
    saveimg(joinpath(ctx.outdir, "f0.png"), to_img(imgf0))
    saveimg(joinpath(ctx.outdir, "f1.png"), to_img(imgf1))
    saveimgdiff(joinpath(ctx.outdir, "g.png"), to_img(imgf0), to_img(imgf1))
    saveimg(joinpath(ctx.outdir, "output.png"), colorflow(to_img(u_sampled); ctx.params.maxflow))
    imgfw = warp_backwards(imgf1, u_sampled)
    saveimg(joinpath(ctx.outdir, "fw.png"), to_img(imgfw))
    savedata(joinpath(ctx.outdir, "data.tex");
        eps_newton, eps_warp, n_refine,
        st.alpha1, st.alpha2, st.lambda, st.beta, st.gamma1, st.gamma2,
        width=size(u_sampled, 1), height=size(u_sampled, 2))
    return st
end
"""
    experiment_optflow_middlebury(ctx)

Run `optflow` on one Middlebury sample located in `ctx.indir`.

Reads `frame10.png`, `frame11.png` and the ground-truth flow `flow10.flo`. The
ground truth's `maxflow` is used to colour both the ground truth (saved as
`ground_truth.png` in the output directory) and the estimated flow. Returns the
state produced by `optflow`.
"""
function experiment_optflow_middlebury(ctx)
    indir = ctx.indir
    imgf0 = loadimg(joinpath(indir, "frame10.png"))
    imgf1 = loadimg(joinpath(indir, "frame11.png"))
    gtflow = FileIO.load(joinpath(indir, "flow10.flo"))
    maxflow = OpticalFlowUtils.maxflow(gtflow)
    flow_ctx = Util.Context(ctx; imgf0, imgf1, maxflow)
    gt_png = joinpath(flow_ctx.outdir, "ground_truth.png")
    saveimg(gt_png, colorflow(gtflow; maxflow))
    return optflow(flow_ctx)
end
"""
    experiment_optflow_middlebury_all(ctx)

Run `experiment_optflow_middlebury` through `ctx` for each of the eight
Middlebury sequences below, all with the same regularisation parameters.
Returns `nothing`.
"""
function experiment_optflow_middlebury_all(ctx)
    sequences = ("Dimetrodon", "Grove2", "Grove3", "Hydrangea",
        "RubberWhale", "Urban2", "Urban3", "Venus")
    params = (alpha1 = 10., alpha2 = 0., lambda = 1., beta = 1e-5,
        gamma1 = 1e-3, gamma2 = 1e-3)
    foreach(sequences) do sequence
        ctx(experiment_optflow_middlebury, sequence; params...)
    end
end
"""
    test_image(n = 2^6; supersample_factor = 16)

Return an `n × n` `Float64` matrix sampling `(sin(3 / (a^2 + b^2 + 0.1)) + 1) / 2`.

The coordinates `(a, b)` lie in the open unit square. Each entry is the mean
over a `supersample_factor × supersample_factor` grid of sub-pixel centres,
which acts as a box-filter anti-aliasing.
"""
function test_image(n = 2^6; supersample_factor = 16)
    q = supersample_factor
    pattern(a, b) = (sin(0.5 * 6 / (a^2 + b^2 + 1e-1)) + 1) / 2
    subpixels = CartesianIndices((q, q))
    function pixel(P)
        # left fold from 0.0: same summation order as accumulating in place
        total = foldl(subpixels; init = 0.0) do acc, S
            acc + pattern(((Tuple(P * q - S) .+ 0.5) ./ (n * q))...)
        end
        return total / q^2
    end
    return [pixel(P) for P in CartesianIndices((n, n))]
end
"""
    experiment_image_mesh_interpolation(ctx)

Compare ways of transferring `input.png` from `ctx.indir` into a P1
finite-element function on grids created by `init_grid(img, s)` for
`s in (32, 16, 13)`.

For every grid/method pair, the coefficients (`<s>_<method>.csv`) and the
resampled image (`<s>_<method>.png`) are written to `ctx.outdir`. PSNR and SSIM
against the input are collected into `psnr.csv` and `ssim.csv`, with one row per
grid and one column per method. Returns the value of the final `CSV.write`.
"""
function experiment_image_mesh_interpolation(ctx)
    imgf = from_img(loadimg(joinpath(ctx.indir, "input.png")))
    df_psnr = DataFrame()
    df_ssim = DataFrame()
    # all methods use bilinear interpolation for image evaluations
    transfers = [
        "nodal" => interpolate!,
        "l2_lagrange" => project_l2_lagrange!,
        #"clement" => project_clement!,
        "qi_lagrange" => project_qi_lagrange!,
        #"qi_lagrange_avg" => project_qi_lagrange!,
        "l2_pixel" => project_l2_pixel!,
    ]
    for mesh_size in (32, 16, 13)
        u = FeFunction(FeSpace(init_grid(imgf, mesh_size), P1(), (1,)))
        psnr_row = Pair{Symbol,Any}[]
        ssim_row = Pair{Symbol,Any}[]
        for (method, transfer!) in transfers
            # start every method from a zeroed coefficient vector
            u.data .= false
            transfer!(u, imgf)
            save_csv(joinpath(ctx.outdir, "$(mesh_size)_$(method).csv"), u)
            imgu = sample(u)
            saveimg(joinpath(ctx.outdir, "$(mesh_size)_$(method).png"), to_img(imgu))
            push!(psnr_row, Symbol(method) => assess_psnr(imgu, imgf))
            push!(ssim_row, Symbol(method) => assess_ssim(imgu, imgf))
        end
        push!(df_psnr, (; mesh_size, psnr_row...))
        push!(df_ssim, (; mesh_size, ssim_row...))
    end
    CSV.write(joinpath(ctx.outdir, "psnr.csv"), df_psnr)
    CSV.write(joinpath(ctx.outdir, "ssim.csv"), df_ssim)
    #savedata(joinpath(ctx.outdir, "data.tex"); energy_min, algparams...)
end