From 37c041810d9edf8492f266bd35883479cb53d167 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Fri, 9 Jul 2021 13:45:50 +0200
Subject: [PATCH] refactor code

---
 src/function.jl |  14 ++-
 src/image.jl    |   2 +-
 src/mesh.jl     |   2 +-
 src/run.jl      | 274 ++++++++++++++++++++++++++++++++++++++++++++++--
 4 files changed, 279 insertions(+), 13 deletions(-)

diff --git a/src/function.jl b/src/function.jl
index 192f9b0..554792b 100644
--- a/src/function.jl
+++ b/src/function.jl
@@ -78,7 +78,7 @@ end
 # evaluate at local point
 @inline function evaluate(space::FeSpace, ldofs, xloc)
     bv = evaluate_basis(space.element, xloc)
-    ldofs_ = SArray{Tuple{prod(space.size),ndofs(space.element)}}(ldofs)
+    ldofs_ = SArray{Tuple{prod(space.size), ndofs(space.element)}}(ldofs)
     v = ldofs_ * bv
     return SArray{Tuple{space.size...}}(v)
 end
@@ -105,7 +105,8 @@ end
 Base.show(io::IO, ::MIME"text/plain", f::FeFunction) =
     print("$(nameof(typeof(f))), size $(f.space.size) with $(length(f.data)) dofs")
 
-interpolate!(dst::FeFunction, expr::Function; params...) = interpolate!(dst, dst.space.element, expr; params...)
+interpolate!(dst::FeFunction, expr::Function; params...) =
+    interpolate!(dst, dst.space.element, expr; params...)
 
 
 myvec(x) = vec(x)
@@ -121,7 +122,7 @@ function interpolate!(dst::FeFunction, ::P1, expr::Function; params...)
 	end
 	for eldof in axes(mesh.cells, 1)
 	    xid = mesh.cells[eldof, cell]
-	    x = mesh.vertices[:, xid]
+	    x = SArray{Tuple{ndims_domain(mesh)}}(mesh.vertices[:, xid])
 	    xloc = SA[0. 1. 0.; 0. 0. 1.][:, eldof]
 
 	    opvalues = map(f -> evaluate(f, xloc), params)
@@ -141,7 +142,7 @@ function interpolate!(dst::FeFunction, ::DP0, expr::Function; params...)
 	    bind!(f, cell)
 	end
 	vertices = mesh.vertices[:, mesh.cells[:, cell]]
-	centroid = reshape(mean(vertices, dims = 2), 2)
+	centroid = SArray{Tuple{ndims_domain(mesh)}}(mean(vertices, dims = 2))
 	lcentroid = SA[1/3, 1/3]
 
 	opvalues = map(f -> evaluate(f, lcentroid), params)
@@ -162,6 +163,11 @@ end
 # evaluate at local point (needs bind! call before)
 evaluate(f::FeFunction, x) = evaluate(f.space, f.ldata, x)
 
+# allow any non-function to act as a constant function
+bind!(c, cell) = nothing
+evaluate(c, xloc) = c
+
+# TODO: inherit from some abstract function type
 struct Derivative{F}
     f::F
 end
diff --git a/src/image.jl b/src/image.jl
index 969f555..ab14b37 100644
--- a/src/image.jl
+++ b/src/image.jl
@@ -11,7 +11,7 @@ function interpolate_bilinear(img, x)
 	cornerbool = Bool.(Tuple(idx))
 	λ = ifelse.(cornerbool, x .- x0, x1 .- x)
 	corner = ifelse.(cornerbool, x1, x0)
-	val += prod(λ) * eval_neumann(img, CartesianIndex(corner))
+	val += prod(λ) * eval_neumann(img, corner)
     end
     return val
 end
diff --git a/src/mesh.jl b/src/mesh.jl
index 8a03980..b5da461 100644
--- a/src/mesh.jl
+++ b/src/mesh.jl
@@ -35,7 +35,7 @@ function init_grid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
     return Mesh(vertices, cells)
 end
 
-init_grid(img::Array{<:Any, 2}, type=:vertex) =
+init_grid(img::Array{<:Any, 2}; type=:vertex) =
     type == :vertex ?
 	init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) :
 	init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5))
diff --git a/src/run.jl b/src/run.jl
index dc9a911..54d2201 100644
--- a/src/run.jl
+++ b/src/run.jl
@@ -1,8 +1,274 @@
-export myrun
+export myrun, denoise, inpaint, optflow
 
 using LinearAlgebra: norm
 
 
+struct L1L2TVContext{M,Ttype,Stype}
+    name::String
+    mesh::M
+    d::Int # = ndims_domain(mesh)
+    m::Int
+
+    T::Ttype
+    tdata
+    S::Stype
+
+    alpha1::Float64
+    alpha2::Float64
+    beta::Float64
+    lambda::Float64
+    gamma1::Float64
+    gamma2::Float64
+
+    g::FeFunction
+    u::FeFunction
+    p1::FeFunction
+    p2::FeFunction
+    du::FeFunction
+    dp1::FeFunction
+    dp2::FeFunction
+    nablau
+    nabladu
+end
+
+function L1L2TVContext(name, mesh, m; T, tdata, S,
+	alpha1, alpha2, beta, lambda, gamma1, gamma2)
+    d = ndims_domain(mesh)
+
+    Vg = FeSpace(mesh, P1(), (1,))
+    Vu = FeSpace(mesh, P1(), (m,))
+    Vp1 = FeSpace(mesh, DP0(), (1,))
+    Vp2 = FeSpace(mesh, DP1(), (m, d))
+
+    g = FeFunction(Vg, name="g")
+    u = FeFunction(Vu, name="u")
+    p1 = FeFunction(Vp1, name="p1")
+    p2 = FeFunction(Vp2, name="p2")
+    du = FeFunction(Vu)
+    dp1 = FeFunction(Vp1)
+    dp2 = FeFunction(Vp2)
+    nablau = nabla(u)
+    nabladu = nabla(du)
+
+    g.data .= 0
+    u.data .= 0
+    p1.data .= 0
+    p2.data .= 0
+    du.data .= 0
+    dp1.data .= 0
+    dp2.data .= 0
+
+    return L1L2TVContext(name, mesh, d, m, T, tdata, S,
+	alpha1, alpha2, beta, lambda, gamma1, gamma2,
+	g, u, p1, p2, du, dp1, dp2, nablau, nabladu)
+end
+
+function step!(ctx::L1L2TVContext)
+    T = ctx.T
+    S = ctx.S
+    alpha1 = ctx.alpha1
+    alpha2 = ctx.alpha2
+    beta = ctx.beta
+    lambda = ctx.lambda
+    gamma1 = ctx.gamma1
+    gamma2 = ctx.gamma2
+
+    function du_a(x, du, nabladu, phi, nablaphi; g, u, nablau, p1, p2, tdata)
+	m1 = max(gamma1, norm(T(tdata, u) - g))
+	cond1 = norm(T(tdata, u) - g) > gamma1 ?
+	    dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 :
+	    zeros(size(p1))
+	a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) -
+	    dot(cond1, T(tdata, phi))
+
+	m2 = max(gamma2, norm(nablau))
+	cond2 = norm(nablau) > gamma2 ?
+	    dot(nablau, nabladu) / norm(nablau)^2 * p2 :
+	    zeros(size(p2))
+	a2 = lambda / m2 * dot(nabladu, nablaphi) -
+	    dot(cond2, nablaphi)
+
+	aB = alpha2 * dot(T(tdata, du), T(tdata, phi)) +
+	    beta * dot(S(du, nabladu), S(phi, nablaphi))
+
+	return a1 + a2 + aB
+    end
+
+    function du_l(x, phi, nablaphi; g, u, nablau, tdata)
+	aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) +
+	    beta * dot(S(u, nablau), S(phi, nablaphi))
+	m1 = max(gamma1, norm(T(tdata, u) - g))
+	p1part = alpha1 / m1 * dot(T(tdata, u) - g, T(tdata, phi))
+	m2 = max(gamma2, norm(nablau))
+	p2part = lambda / m2 * dot(nablau, nablaphi)
+	gpart = alpha2 * dot(g, T(tdata, phi))
+
+	return -aB - p1part - p2part + gpart
+    end
+
+    # solve du
+    print("assemble ... ")
+    A = assemble(ctx.du.space, du_a; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata)
+    b = assemble_rhs(ctx.du.space, du_l; ctx.g, ctx.u, ctx.nablau, ctx.tdata)
+    print("solve ... ")
+    ctx.du.data .= A \ b
+
+    # solve dp1
+    function dp1_update(x; g, u, p1, du, tdata)
+	m1 = max(gamma1, norm(T(tdata, u) - g))
+	cond = norm(T(tdata, u) - g) > gamma1 ?
+	    dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 :
+	    zeros(size(p1))
+	return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond
+    end
+    interpolate!(ctx.dp1, dp1_update; ctx.g, ctx.u, ctx.p1, ctx.du, ctx.tdata)
+
+    # solve dp2
+    function dp2_update(x; u, nablau, p2, du, nabladu)
+	m2 = max(gamma2, norm(nablau))
+	cond = norm(nablau) > gamma2 ?
+	    dot(nablau, nabladu) / norm(nablau)^2 * p2 :
+	    zeros(size(p2))
+	return -p2 + lambda / m2 * (nablau + nabladu) - cond
+    end
+    interpolate!(ctx.dp2, dp2_update; ctx.u, ctx.nablau, ctx.p2, ctx.du, ctx.nabladu)
+
+    # newton update
+    ctx.u.data .+= ctx.du.data
+    ctx.p1.data .+= ctx.dp1.data
+    ctx.p2.data .+= ctx.dp2.data
+
+    # reproject p1
+    function p1_project!(p1, alpha1)
+	p1.space.element::DP0
+	p1.data .= clamp.(p1.data, -alpha1, alpha1)
+    end
+    p1_project!(ctx.p1, ctx.alpha1)
+    # reproject p2
+    function p2_project!(p2, lambda)
+	p2.space.element::DP1
+	p2d = reshape(p2.data, prod(p2.space.size), :) # no copy
+	for i in axes(p2d, 2)
+	    p2in = norm(p2d[:, i])
+	    if p2in > lambda
+		p2d[:, i] .*= lambda ./ p2in
+	    end
+	end
+    end
+    p2_project!(ctx.p2, ctx.lambda)
+end
+
+function save(ctx::L1L2TVContext, filename, fs...)
+    print("save ... ")
+    vtk = vtk_mesh(filename, ctx.mesh)
+    vtk_append!(vtk, fs...)
+    vtk_save(vtk)
+    return vtk
+end
+
+function denoise(img; name, params...)
+    m = 1
+    mesh = init_grid(img; type=:vertex)
+
+    T(tdata, u) = u
+    S(u, nablau) = u
+
+    ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
+
+    interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
+
+    save_denoise(i) =
+	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
+	    ctx.g, ctx.u, ctx.p1, ctx.p2)
+
+    pvd = paraview_collection("$(ctx.name).pvd")
+    pvd[0] = save_denoise(0)
+    for i in 1:10
+	step!(ctx)
+	pvd[i] = save_denoise(i)
+	println()
+    end
+end
+
+function inpaint(img, imgmask; name, params...)
+    size(img) == size(imgmask) ||
+	throw(ArgumentError("non-matching dimensions"))
+
+    m = 1
+    mesh = init_grid(img; type=:vertex)
+
+    # inpaint specific stuff
+    Vg = FeSpace(mesh, P1(), (1,))
+    mask = FeFunction(Vg, name="mask")
+
+    T(tdata, u) = iszero(tdata) ? zero(u) : u
+    S(u, nablau) = u
+
+    ctx = L1L2TVContext(name, mesh, m; T, tdata = mask, S, params...)
+
+    # FIXME: currently dual grid only
+    interpolate!(mask, x -> imgmask[round.(Int, x)...])
+    #interpolate!(mask, x -> abs(x[2] - 0.5) > 0.1)
+    interpolate!(ctx.g, x -> imgmask[round.(Int, x)...] ? img[round.(Int, x)...] :  0.)
+
+    save_inpaint(i) =
+	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
+	    ctx.g, ctx.u, ctx.p1, ctx.p2, mask)
+
+    pvd = paraview_collection("$(ctx.name).pvd")
+    pvd[0] = save_inpaint(0)
+    for i in 1:10
+	step!(ctx)
+	pvd[i] = save_inpaint(i)
+	println()
+    end
+end
+
+function optflow(imgf0, imgf1; name, params...)
+    size(imgf0) == size(imgf1) ||
+	throw(ArgumentError("non-matching dimensions"))
+
+    m = 2
+    mesh = init_grid(imgf0; type=:vertex)
+
+    # optflow specific stuff
+    Vg = FeSpace(mesh, P1(), (1,))
+    f0 = FeFunction(Vg, name="f0")
+    f1 = FeFunction(Vg, name="f1")
+    fw = FeFunction(Vg, name="fw")
+    nablafw = nabla(fw)
+
+    T(tdata, u) = tdata * u
+    S(u, nablau) = nablau
+
+    ctx = L1L2TVContext(name, mesh, m; T, tdata = nablafw, S, params...)
+
+    # FIXME: currently dual grid only
+    interpolate!(f0, x -> imgf0[round.(Int, x)...])
+    interpolate!(f1, x -> imgf1[round.(Int, x)...])
+    fw.data .= f1.data
+
+    g_optflow(x; u, f0, fw, nablafw) =
+	nablafw * u - (fw - f0)
+    interpolate!(ctx.g, g_optflow; ctx.u, f0, fw, nablafw)
+
+    save_optflow(i) =
+	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
+	    ctx.g, ctx.u, ctx.p1, ctx.p2, f0, f1, fw)
+
+    pvd = paraview_collection("$(ctx.name).pvd")
+    pvd[0] = save_optflow(0)
+    for i in 1:10
+	step!(ctx)
+	pvd[i] = save_optflow(i)
+	println()
+    end
+end
+
+
+
+
+
 
 function myrun()
     name = "test"
@@ -17,11 +283,6 @@ function myrun()
 
     # inpainting
     mask = FeFunction(Vg, name="mask")
-    # optflow
-    f0 = FeFunction(Vg, name="f0")
-    f1 = FeFunction(Vg, name="f1")
-    fw = FeFunction(Vg, name="fw")
-    nablafw = nabla(fw)
 
     g = FeFunction(Vg, name="g")
     u = FeFunction(Vu, name="u")
@@ -49,7 +310,6 @@ function myrun()
     gamma2 = 1e-3
 
     interpolate!(g, x -> norm(x - SA[0.5, 0.5]) < 0.3)
-    interpolate!(mask, x -> abs(x[2] - 0.5) > 0.1)
 
     interpolate!(f0, x -> x[1])
     interpolate!(f1, x -> x[1] - 0.01)
-- 
GitLab