From b8c4ac1416537d3a48fff0a37f9966d4bdbc2150 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Mon, 17 Jan 2022 20:57:32 +0100
Subject: [PATCH] refactor

---
 scripts/run_experiments.jl | 142 +++++++++++++++++++++++--------------
 1 file changed, 88 insertions(+), 54 deletions(-)

diff --git a/scripts/run_experiments.jl b/scripts/run_experiments.jl
index 8a6075f..a8d7e39 100644
--- a/scripts/run_experiments.jl
+++ b/scripts/run_experiments.jl
@@ -22,15 +22,15 @@ isdefined(Main, :Revise) && Revise.track(joinpath(@__DIR__, "util.jl"))
 
 using .Util
 
-_toimg(arr) = Gray.(clamp.(arr, 0., 1.))
-#loadimg(x) = reverse(transpose(Float64.(FileIO.load(x))); dims = 2)
-#saveimg(io, x) = FileIO.save(io, transpose(reverse(_toimg(x); dims = 2)))
+grayclamp(value) = Gray(clamp(value, 0., 1.))
 loadimg(x) = Float64.(FileIO.load(x))
-saveimg(io, x) = FileIO.save(io, _toimg(x))
+saveimg(io, x) = FileIO.save(io, grayclamp.(x))
 
+# convert image to/from image coordinate system
 from_img(arr) = permutedims(reverse(arr; dims = 1))
 to_img(arr) = permutedims(reverse(arr; dims = 2))
 function to_img(arr::AbstractArray{<:Any,3})
+    # for flow fields handle flow direction in first dimension too
     out = permutedims(reverse(arr; dims = (1,3)), (1, 3, 2))
     out[1, :, :] .= .-out[1, :, :]
     return out
@@ -45,8 +45,7 @@ equidistant datapoints on a log-scale. `a` controls density.
 """
 logfilter(dt; a=20) = filter(:k => k -> round(a*log(k+1)) - round(a*log(k)) > 0, dt)
 
-struct L1L2TVContext{M, Ttype, Stype}
-    name::String
+struct L1L2TVState{M, Ttype, Stype}
     mesh::M
     d::Int # = ndims_domain(mesh)
     m::Int
@@ -72,25 +71,65 @@ struct L1L2TVContext{M, Ttype, Stype}
     dp2::FeFunction
 end
 
-function L1L2TVContext(name, mesh, m; T, tdata, S,
+function L1L2TVState(mesh, m; T, tdata, S,
 	alpha1, alpha2, beta, lambda, gamma1, gamma2)
     d = ndims_domain(mesh)
 
+    Vest = FeSpace(mesh, DP0(), (1,))
+    Vg = FeSpace(mesh, P1(), (1,))
+    Vu = FeSpace(mesh, P1(), (m,))
+    Vp1 = FeSpace(mesh, P1(), (1,))
+    Vp2 = FeSpace(mesh, DP0(), (m, d))
+
+    est = FeFunction(Vest, name = "est")
+    g = FeFunction(Vg, name = "g")
+    u = FeFunction(Vu, name = "u")
+    p1 = FeFunction(Vp1, name = "p1")
+    p2 = FeFunction(Vp2, name = "p2")
+    du = FeFunction(Vu; name = "du")
+    dp1 = FeFunction(Vp1; name = "dp1")
+    dp2 = FeFunction(Vp2; name = "dp2")
+
+    est.data .= 0
+    g.data .= 0
+    u.data .= 0
+    p1.data .= 0
+    p2.data .= 0
+    du.data .= 0
+    dp1.data .= 0
+    dp2.data .= 0
+
+    return L1L2TVState(mesh, d, m, T, tdata, S,
+	alpha1, alpha2, beta, lambda, gamma1, gamma2,
+	est, g, u, p1, p2, du, dp1, dp2)
+end
+
+function OptFlowState(mesh;
+	alpha1, alpha2, beta, lambda, gamma1, gamma2)
+    d = ndims_domain(mesh)
+    m = 2
+
     Vest = FeSpace(mesh, DP0(), (1,))
     # DP1 only for optical flow
     Vg = FeSpace(mesh, DP1(), (1,))
     Vu = FeSpace(mesh, P1(), (m,))
     Vp1 = FeSpace(mesh, P1(), (1,))
     Vp2 = FeSpace(mesh, DP0(), (m, d))
+    Vdg = FeSpace(mesh, DP0(), (1, d))
 
-    est = FeFunction(Vest, name="est")
-    g = FeFunction(Vg, name="g")
-    u = FeFunction(Vu, name="u")
-    p1 = FeFunction(Vp1, name="p1")
-    p2 = FeFunction(Vp2, name="p2")
+    # tdata will be something like nabla(fw)
+    T(tdata, u) = tdata * u
+    S(u, nablau) = nablau
+
+    est = FeFunction(Vest, name = "est")
+    g = FeFunction(Vg, name = "g")
+    u = FeFunction(Vu, name = "u")
+    p1 = FeFunction(Vp1, name = "p1")
+    p2 = FeFunction(Vp2, name = "p2")
     du = FeFunction(Vu; name = "du")
     dp1 = FeFunction(Vp1; name = "dp1")
     dp2 = FeFunction(Vp2; name = "dp2")
+    tdata = FeFunction(Vdg, name = "tdata")
 
     est.data .= 0
     g.data .= 0
@@ -101,7 +140,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
     dp1.data .= 0
     dp2.data .= 0
 
-    return L1L2TVContext(name, mesh, d, m, T, tdata, S,
+    return L1L2TVState(mesh, d, m, T, tdata, S,
 	alpha1, alpha2, beta, lambda, gamma1, gamma2,
 	est, g, u, p1, p2, du, dp1, dp2)
 end
@@ -126,7 +165,7 @@ function p2_project!(p2, lambda)
     end
 end
 
-function step!(ctx::L1L2TVContext)
+function step!(ctx::L1L2TVState)
     T = ctx.T
     S = ctx.S
     alpha1 = ctx.alpha1
@@ -214,7 +253,7 @@ end
 2010, Chambolle and Pock: primal-dual semi-implicit algorithm
 2017, Alkämper and Langer: fem dualisation
 "
-function step_pd!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
+function step_pd!(ctx::L1L2TVState; sigma, tau, theta = 1.)
     # note: ignores gamma1, gamma2, beta and uses T = I, lambda = 1, m = 1!
     # changed alpha2 -> alpha2 / 2
     # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
@@ -286,7 +325,7 @@ end
 "
 2010, Chambolle and Pock: accelerated primal-dual semi-implicit algorithm
 "
-function step_pd2!(ctx::L1L2TVContext; sigma, tau, theta = 1.)
+function step_pd2!(ctx::L1L2TVState; sigma, tau, theta = 1.)
     # chambolle-pock require: sigma * tau * L^2 <= 1, L = |grad|
 
     # u is P1
@@ -350,7 +389,7 @@ end
 "
 2004, Chambolle: dual semi-implicit algorithm
 "
-function step_d!(ctx::L1L2TVContext; tau)
+function step_d!(ctx::L1L2TVState; tau)
     # u is P1
     # p2 is essentially DP0 (technically may be DP1)
 
@@ -359,7 +398,7 @@ function step_d!(ctx::L1L2TVContext; tau)
     return ctx
 end
 
-function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
+function solve_primal!(u::FeFunction, ctx::L1L2TVState)
     u_a(x, u, nablau, phi, nablaphi; g, p1, p2, tdata) =
 	ctx.alpha2 * dot(ctx.T(tdata, u), ctx.T(tdata, phi)) +
 	    ctx.beta * dot(ctx.S(u, nablau), ctx.S(phi, nablaphi))
@@ -379,7 +418,7 @@ huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
 
 # this computes the primal-dual error indicator which is not really useful
 # if not computed on a finer mesh than `u` was solved on
-function estimate!(ctx::L1L2TVContext)
+function estimate!(ctx::L1L2TVState)
     function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
 	alpha1part =
             ctx.alpha1 * huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) -
@@ -410,11 +449,11 @@ function estimate!(ctx::L1L2TVContext)
         nablau = nabla(ctx.u), w, nablaw = nabla(w), ctx.tdata)
 end
 
-estimate_error(st::L1L2TVContext) =
+estimate_error(st::L1L2TVState) =
     sum(st.est.data) / area(st.mesh)
 
 # minimal Dörfler marking
-function mark(ctx::L1L2TVContext; theta=0.5)
+function mark(ctx::L1L2TVState; theta=0.5)
     n = ncells(ctx.mesh)
     esttotal = sum(ctx.est.data)
 
@@ -432,7 +471,7 @@ function mark(ctx::L1L2TVContext; theta=0.5)
 end
 
 
-function output(ctx::L1L2TVContext, filename, fs...)
+function output(ctx::L1L2TVState, filename, fs...)
     print("save \"$filename\" ... ")
     vtk = vtk_mesh(filename, ctx.mesh)
     vtk_append!(vtk, fs...)
@@ -440,7 +479,7 @@ function output(ctx::L1L2TVContext, filename, fs...)
     return vtk
 end
 
-function primal_energy(ctx::L1L2TVContext)
+function primal_energy(ctx::L1L2TVState)
     function integrand(x; g, u, nablau, tdata)
 	return ctx.alpha1 * huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
 	    ctx.alpha2 / 2 * norm(ctx.T(tdata, u) - g)^2 +
@@ -451,7 +490,7 @@ function primal_energy(ctx::L1L2TVContext)
         nablau = nabla(ctx.u), ctx.tdata)
 end
 
-function dual_energy(st::L1L2TVContext)
+function dual_energy(st::L1L2TVState)
     # primal reconstruction
     w = FeFunction(st.u.space)
     solve_primal!(w, st)
@@ -473,10 +512,10 @@ end
 
 norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
 
-norm_step(ctx::L1L2TVContext) =
+norm_step(ctx::L1L2TVState) =
     sqrt((norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)^2) / area(ctx.mesh))
 
-function norm_residual(ctx::L1L2TVContext)
+function norm_residual(ctx::L1L2TVState)
     w = FeFunction(ctx.u.space)
     solve_primal!(w, ctx)
     w.data .-= ctx.u.data
@@ -495,7 +534,7 @@ function norm_residual(ctx::L1L2TVContext)
     return sqrt(upart2 + ppart2)
 end
 
-function denoise(img; name, params...)
+function denoise(img; params...)
     m = 1
     img = from_img(img) # coord flip
     #mesh = init_grid(img; type=:vertex)
@@ -504,7 +543,7 @@ function denoise(img; name, params...)
     T(tdata, u) = u
     S(u, nablau) = u
 
-    ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
+    ctx = L1L2TVState(mesh, m; T, tdata = nothing, S, params...)
 
     project_l2_lagrange!(ctx.g, img)
     #interpolate!(ctx.g, x -> evaluate_bilinear(img, x))
@@ -552,7 +591,7 @@ function denoise(img; name, params...)
 end
 
 
-function denoise_pd(st, img; df=nothing, name, algorithm, params_...)
+function denoise_pd(st, img; df=nothing, algorithm, params_...)
     params = NamedTuple(params_)
     m = 1
     img = from_img(img) # coord flip
@@ -562,7 +601,7 @@ function denoise_pd(st, img; df=nothing, name, algorithm, params_...)
     T(tdata, u) = u
     S(u, nablau) = u
 
-    st = L1L2TVContext(name, mesh, m;
+    st = L1L2TVState(mesh, m;
         T, tdata = nothing, S,
         params.alpha1, params.alpha2, params.lambda, params.beta,
         params.gamma1, params.gamma2)
@@ -647,11 +686,11 @@ function experiment_pd_comparison(ctx)
     df2 = DataFrame()
     df3 = DataFrame()
 
-    st1 = denoise_pd(ctx, img; name="test",
+    st1 = denoise_pd(ctx, img;
         algorithm=:pd1, df = df1, algparams...);
-    st2 = denoise_pd(ctx, img; name="test",
+    st2 = denoise_pd(ctx, img;
         algorithm=:pd2, df = df2, algparams...);
-    st3 = denoise_pd(ctx, img; name="test",
+    st3 = denoise_pd(ctx, img;
         algorithm=:newton, df = df3, algparams...);
 
     energy_min = min(minimum(df1.primal_energy), minimum(df2.primal_energy),
@@ -687,7 +726,7 @@ function denoise_approximation(ctx)
     S(u, nablau) = u
     #S(u, nablau) = nablau
 
-    st = L1L2TVContext(ctx.params.name, ctx.params.mesh, 1;
+    st = L1L2TVState(ctx.params.mesh, 1;
         T, tdata = nothing, S,
         ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta,
         ctx.params.gamma1, ctx.params.gamma2)
@@ -721,7 +760,7 @@ function denoise_approximation(ctx)
             marked_cells = Set(axes(st.mesh.cells, 2))
             mesh, fs = refine(st.mesh, marked_cells;
                 st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
-            st = L1L2TVContext("run", mesh, st.d, st.m, T, nothing, S,
+            st = L1L2TVState(mesh, st.d, st.m, T, nothing, S,
                 st.alpha1, st.alpha2, st.beta, st.lambda,
                 st.gamma1, st.gamma2,
                 fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
@@ -729,7 +768,7 @@ function denoise_approximation(ctx)
         #marked_cells = Set(axes(st.mesh.cells, 2))
         #mesh2, fs = refine(st.mesh, marked_cells;
         #    st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
-        #st2 = L1L2TVContext("run", mesh2, st.d, st.m, T, nothing, S,
+        #st2 = L1L2TVState(mesh2, st.d, st.m, T, nothing, S,
         #    st.alpha1, st.alpha2, st.beta, st.lambda,
         #    st.gamma1, st.gamma2,
         #    fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
@@ -773,7 +812,7 @@ function denoise_approximation(ctx)
                     Set(axes(st.mesh.cells, 2))
                 mesh, fs = refine(st.mesh, marked_cells;
                     st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
-                st = L1L2TVContext("run", mesh, st.d, st.m, T, nothing, S,
+                st = L1L2TVState(mesh, st.d, st.m, T, nothing, S,
                     st.alpha1, st.alpha2, st.beta, st.lambda,
                     st.gamma1, st.gamma2,
                     fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
@@ -813,7 +852,7 @@ function experiment_approximation(ctx)
     ))
 end
 
-function inpaint(img, imgmask; name, params...)
+function inpaint(img, imgmask; params...)
     size(img) == size(imgmask) ||
 	throw(ArgumentError("non-matching dimensions"))
 
@@ -824,12 +863,12 @@ function inpaint(img, imgmask; name, params...)
 
     # inpaint specific stuff
     Vg = FeSpace(mesh, P1(), (1,))
-    mask = FeFunction(Vg, name="mask")
+    mask = FeFunction(Vg, name = "mask")
 
     T(tdata, u) = isone(tdata[begin]) ? u : zero(u)
     S(u, nablau) = u
 
-    ctx = L1L2TVContext(name, mesh, m; T, tdata = mask, S, params...)
+    ctx = L1L2TVState(mesh, m; T, tdata = mask, S, params...)
 
     # FIXME: currently dual grid only
     interpolate!(mask, x -> imgmask[round.(Int, x)...])
@@ -870,20 +909,15 @@ function optflow(ctx)
 
     # optflow specific stuff
     Vg = FeSpace(mesh, P1(), (1,))
-    f0 = FeFunction(Vg, name="f0")
-    f1 = FeFunction(Vg, name="f1")
-    fw = FeFunction(Vg, name="fw")
-    Vdg = FeSpace(mesh, DP0(), (1, 2))
-    tdata = FeFunction(Vdg, name="tdata")
-
-    T(tdata, u) = tdata * u # here tdata will be nabla(fw)
-    S(u, nablau) = nablau
+    f0 = FeFunction(Vg, name = "f0")
+    f1 = FeFunction(Vg, name = "f1")
+    fw = FeFunction(Vg, name = "fw")
 
-    st = L1L2TVContext("run", mesh, m; T, tdata, S,
+    st = OptFlowState(mesh;
         ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta,
         ctx.params.gamma1, ctx.params.gamma2)
 
-    u_acc = FeFunction(st.u.space, name="u_acc")
+    u_acc = FeFunction(st.u.space, name = "u_acc")
     u_acc.data .= 0
 
     function warp!()
@@ -898,13 +932,13 @@ function optflow(ctx)
             all(isfinite, res) || throw(DivideError("singular optflow matrix"))
             return res
         end
-        interpolate!(tdata, tdata_optflow;
+        interpolate!(st.tdata, tdata_optflow;
             u0_deriv = nabla(st.u), nablafw = nabla(fw))
 
         # recompute optflow data g
         g_optflow(x; u0, f0, fw, tdata) =
             #-(fw - f0)
-            T(tdata, u0) - (fw - f0)
+            st.T(tdata, u0) - (fw - f0)
         interpolate!(st.g, g_optflow; u0 = st.u, f0, fw, st.tdata)
     end
 
@@ -965,7 +999,7 @@ function optflow(ctx)
             #mesh, fs = refine(mesh, marked_cells;
             #    st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2,
             #    f0, f1, fw, u_acc)
-            #st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fs.fw), S,
+            #st = L1L2TVState(mesh, st.d, st.m, T, nabla(fs.fw), S,
             #    st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
             #    fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
             #f0, f1, fw, u_acc = (fs.f0, fs.f1, fs.fw, fs.u_acc)
@@ -1002,7 +1036,7 @@ function optflow(ctx)
         #mesh, fs = refine(mesh, marked_cells;
         #    st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2,
         #    f0, f1, fw)
-        #st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fs.fw), S,
+        #st = L1L2TVState(mesh, st.d, st.m, T, nabla(fs.fw), S,
         #    st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
         #    fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
         #f0, f1, fw = (fs.f0, fs.f1, fs.fw)
-- 
GitLab