From 01bd00dac5b7a86b88034d57adb4d13817cb25ab Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Mon, 17 Jan 2022 18:21:42 +0100
Subject: [PATCH] get optical flow warping to work

the final missing piece was the correct space for data g, since tdata is
actually discontinuous.
---
 scripts/run_experiments.jl | 57 ++++++++++++++++++++++++++++----------
 1 file changed, 43 insertions(+), 14 deletions(-)

diff --git a/scripts/run_experiments.jl b/scripts/run_experiments.jl
index 5bbf6eb..8a6075f 100644
--- a/scripts/run_experiments.jl
+++ b/scripts/run_experiments.jl
@@ -1,4 +1,4 @@
-using LinearAlgebra: norm, dot
+using LinearAlgebra: norm, dot, I
 
 using Colors: Gray
 # avoid world-age-issues by preloading ColorTypes
@@ -77,7 +77,8 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
     d = ndims_domain(mesh)
 
     Vest = FeSpace(mesh, DP0(), (1,))
-    Vg = FeSpace(mesh, P1(), (1,))
+    # DP1 only for optical flow
+    Vg = FeSpace(mesh, DP1(), (1,))
     Vu = FeSpace(mesh, P1(), (m,))
     Vp1 = FeSpace(mesh, P1(), (1,))
     Vp2 = FeSpace(mesh, DP0(), (m, d))
@@ -376,6 +377,8 @@ end
 
 huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
 
+# this computes the primal-dual error indicator which is not really useful
+# if not computed on a finer mesh than `u` was solved on
 function estimate!(ctx::L1L2TVContext)
     function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
 	alpha1part =
@@ -870,11 +873,13 @@ function optflow(ctx)
     f0 = FeFunction(Vg, name="f0")
     f1 = FeFunction(Vg, name="f1")
     fw = FeFunction(Vg, name="fw")
+    Vdg = FeSpace(mesh, DP0(), (1, 2))
+    tdata = FeFunction(Vdg, name="tdata")
 
     T(tdata, u) = tdata * u # here tdata will be nabla(fw)
     S(u, nablau) = nablau
 
-    st = L1L2TVContext("run", mesh, m; T, tdata = nabla(fw), S,
+    st = L1L2TVContext("run", mesh, m; T, tdata, S,
         ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta,
         ctx.params.gamma1, ctx.params.gamma2)
 
@@ -882,13 +887,25 @@ function optflow(ctx)
     u_acc.data .= 0
 
     function warp!()
+        # warp image into imgfw / fw
         imgfw = warp_backwards(imgf1, sample(u_acc))
         project_l2_pixel!(fw, imgfw)
 
-        g_optflow(x; u, f0, fw, nablafw) =
-            -(fw - f0)
-            #nablafw * u - (fw - f0)
-        interpolate!(st.g, g_optflow; st.u, f0, fw, nablafw = st.tdata)
+        # recompute optflow operator T based on u0 and fw
+        # TODO: investigate what julia does for singular matrices here
+        function tdata_optflow(x; u0_deriv, nablafw)
+            res = nablafw / (I + u0_deriv')
+            all(isfinite, res) || throw(DivideError("singular optflow matrix"))
+            return res
+        end
+        interpolate!(tdata, tdata_optflow;
+            u0_deriv = nabla(st.u), nablafw = nabla(fw))
+
+        # recompute optflow data g
+        g_optflow(x; u0, f0, fw, tdata) =
+            #-(fw - f0)
+            T(tdata, u0) - (fw - f0)
+        interpolate!(st.g, g_optflow; u0 = st.u, f0, fw, st.tdata)
     end
 
     function interpolate_image_data!()
@@ -910,15 +927,17 @@ function optflow(ctx)
 
         #norm_g_old = norm_l2(st.g)
 
-        for j in 1:3
+        for j in 1:5
+            i += 1
 
+            # interior newton loop
             norm_step_ = Inf
-            while norm_step_ > tol_newton
+            k = 0
+            while norm_step_ > tol_newton && k < 20
+                k += 1
                 println()
-                i += 1
                 step!(st)
                 #estimate!(st)
-                pvd[i] = save_step(i)
 
                 norm_step_ = norm_step(st) / sqrt(mesh_area)
                 println("norm_step = $norm_step_")
@@ -931,13 +950,17 @@ function optflow(ctx)
                 #display(plot(colorflow(to_img(sample(st.u)); ctx.params.maxflow)))
             end
 
-            u_acc.data .+= st.u.data
+            #estimate!(st)
+            #pvd[i] = save_step(i)
+
+            u_acc.data .= st.u.data
             display(plot(colorflow(to_img(sample(u_acc)); ctx.params.maxflow)))
 
             # poor man's hash
             #println(integrate(st.mesh, (x; tdata) -> tdata; st.tdata))
 
             #println("refine ...")
+            ##marked_cells = mark(st; theta = 0.5)
             #marked_cells = Set(axes(mesh.cells, 2))
             #mesh, fs = refine(mesh, marked_cells;
             #    st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2,
@@ -947,13 +970,19 @@ function optflow(ctx)
             #    fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
             #f0, f1, fw, u_acc = (fs.f0, fs.f1, fs.fw, fs.u_acc)
 
-            ##i += 1
-            #pvd[i] = save_step(i)
+            #i += 1
 
             #println("interpolate image data ...")
             #interpolate_image_data!()
 
+
+
             println("warp ...")
+            ##warp!() # yay
+            estimate!(st)
+            pvd[i] = save_step(i)
+
+            #u_acc.data .= st.u.data
             warp!() # yay
         end
 
-- 
GitLab