From 7daf94ff57f9bf58d962d8e6395257ff15c50b1b Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Mon, 19 Jul 2021 18:30:32 +0200
Subject: [PATCH] fix energy calculation

---
 .gitignore      |  1 +
 Manifest.toml   | 24 ++++++++++++
 Project.toml    |  2 +
 src/function.jl |  3 +-
 src/run.jl      | 99 ++++++++++++++-----------------------------------
 5 files changed, 57 insertions(+), 72 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1fcac93..67c77d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 *.vtu
 *.pvd
+data/*
diff --git a/Manifest.toml b/Manifest.toml
index 09db26a..11a9867 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -21,6 +21,12 @@ git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
 uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
 version = "0.7.0"
 
+[[ColorTypes]]
+deps = ["FixedPointNumbers", "Random"]
+git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
+uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
+version = "0.11.0"
+
 [[CommonSubexpressions]]
 deps = ["MacroTools", "Test"]
 git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
@@ -71,12 +77,24 @@ version = "0.8.5"
 deps = ["ArgTools", "LibCURL", "NetworkOptions"]
 uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
 
+[[FileIO]]
+deps = ["Pkg", "Requires", "UUIDs"]
+git-tree-sha1 = "256d8e6188f3f1ebfa1a5d17e072a0efafa8c5bf"
+uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
+version = "1.10.1"
+
 [[FillArrays]]
 deps = ["LinearAlgebra", "Random", "SparseArrays"]
 git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
 uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
 version = "0.11.7"
 
+[[FixedPointNumbers]]
+deps = ["Statistics"]
+git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc"
+uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
+version = "0.8.4"
+
 [[ForwardDiff]]
 deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
 git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
@@ -193,6 +211,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
 deps = ["Serialization"]
 uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
+[[Requires]]
+deps = ["UUIDs"]
+git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
+uuid = "ae029012-a4dd-5104-9daa-d747884805df"
+version = "1.1.3"
+
 [[SHA]]
 uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
 
diff --git a/Project.toml b/Project.toml
index 343cbf9..c3c675c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,6 +4,8 @@ authors = ["Stephan Hilb <stephan@ecshi.net>"]
 version = "0.1.0"
 
 [deps]
+ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
+FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
diff --git a/src/function.jl b/src/function.jl
index b4acb7a..c454677 100644
--- a/src/function.jl
+++ b/src/function.jl
@@ -202,7 +202,7 @@ function integrate(mesh::Mesh, expr; params...)
 	for k in axes(qx, 2)
 	    xhat = qx[:, k]
 	    x = elmap(mesh, cell)(xhat)
-	    opvalues = map(f -> evaluate(f, x), opparams)
+	    opvalues = map(f -> evaluate(f, xhat), opparams)
 
 	    val += qw[k] * expr(x; opvalues...) * intel
 	end
@@ -211,6 +211,7 @@ function integrate(mesh::Mesh, expr; params...)
 end
 
 function bind!(f::FeFunction, cell)
+    # TODO: make this non-allocating
     f.ldata .= vec(f.data[f.space.dofmap[:, :, cell]])
     return f
 end
diff --git a/src/run.jl b/src/run.jl
index 39639fc..a6ec678 100644
--- a/src/run.jl
+++ b/src/run.jl
@@ -1,7 +1,11 @@
-export myrun, denoise, inpaint, optflow, solve_primal!, estimate!
+export myrun, denoise, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg
 
 using LinearAlgebra: norm
 
+# avoid world-age-issues by preloading ColorTypes
+import ColorTypes
+import FileIO
+
 
 struct L1L2TVContext{M, Ttype, Stype}
     name::String
@@ -179,9 +183,9 @@ function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
     u.data .= A \ b
 end
 
-function estimate!(ctx::L1L2TVContext)
-    huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
+huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
 
+function estimate!(ctx::L1L2TVContext)
     function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
 	alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * (
 	    huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
@@ -252,7 +256,7 @@ function mark(ctx::L1L2TVContext; theta=0.5)
 end
 
 
-function save(ctx::L1L2TVContext, filename, fs...)
+function output(ctx::L1L2TVContext, filename, fs...)
     print("save \"$filename\" ... ")
     vtk = vtk_mesh(filename, ctx.mesh)
     vtk_append!(vtk, fs...)
@@ -260,67 +264,18 @@ function save(ctx::L1L2TVContext, filename, fs...)
     return vtk
 end
 
-function test_refine(mesh = nothing)
-    m = 1
+toimg(arr) = Gray.(clamp.(arr, 0., 1.))
+loadimg(x) = reverse(transpose(Float64.(FileIO.load(x))); dims = 2)
+saveimg(io, x) = FileIO.save(io, toimg(x))
 
-    if isnothing(mesh)
-	mesh = init_grid(2, 2)
+function primal_energy(ctx::L1L2TVContext)
+    function integrand(x; g, u, nablau, tdata)
+	return ctx.alpha1 * huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
+	    ctx.alpha2 / 2 * norm(ctx.T(tdata, u) - g)^2 +
+	    ctx.beta / 2 * norm(ctx.S(u, nablau))^2 +
+	    ctx.lambda * huber(norm(nablau), ctx.gamma2)
     end
-
-    T(tdata, u) = u
-    S(u, nablau) = u
-    ctx = L1L2TVContext("test", mesh, m; T, tdata = nothing, S,
-	alpha1=0., alpha2=5., lambda=0.1, beta=1e-3, gamma1=1e-3, gamma2=1e-3)
-
-    interpolate!(ctx.g, x -> x[1] + x[2] > 1.01)
-
-    mysave(ctx, i) =
-	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
-	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
-
-    pvd = paraview_collection("$(ctx.name).pvd")
-    for i in 1:10
-	step!(ctx)
-    end
-    pvd[0] = mysave(ctx, 0)
-    ctx2 = refine(ctx, [1, 2, 3, 4, 5, 6, 7, 8])
-    #interpolate!(ctx2.g, x -> x[1] + x[2] < 2.01)
-    #interpolate!(ctx2.g, x -> norm(x .- m) < norm(m) / 3)
-    #ctx2.est.data .= 0
-    #ctx2.g.data .= 0
-    #ctx2.u.data .= 0
-    #ctx2.p1.data .= 0
-    #ctx2.p2.data .= 0
-    #ctx2.du.data .= 0
-    #ctx2.dp1.data .= 0
-    #ctx2.dp2.data .= 0
-    for i in 1:10
-	step!(ctx2)
-    end
-    pvd[1] = mysave(ctx2, 1)
-
-    test_mesh(ctx2.mesh)
-
-    return ctx, ctx2
-    for i in 1:5
-	step!(ctx)
-	estimate!(ctx)
-	pvd[i] = mysave(i)
-	println()
-    end
-    estmean = mean(ctx.est.data)
-    marked_cells = findall(>(2*estmean), ctx.est.data)
-    println(marked_cells)
-    ctx = refine(ctx, marked_cells)
-    interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
-    for i in 6:11
-	step!(ctx)
-	estimate!(ctx)
-	pvd[i] = mysave(i)
-	println()
-    end
-
-    return ctx
+    return integrate(ctx.mesh, integrand; ctx.g, ctx.u, ctx.nablau, ctx.tdata)
 end
 
 norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
@@ -336,16 +291,17 @@ function denoise(img; name, params...)
 
     interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
     m = (size(img) .- 1) ./ 2 .+ 1
-    interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
+    #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
 
     save_denoise(ctx, i) =
-	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
+	output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
 	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
 
-    pvd = paraview_collection("$(ctx.name).pvd")
+    pvd = paraview_collection("output/$(ctx.name).pvd")
     pvd[0] = save_denoise(ctx, 0)
 
     k = 0
+    println("primal energy: $(primal_energy(ctx))")
     while true
 	while true
 	    k += 1
@@ -354,6 +310,7 @@ function denoise(img; name, params...)
 	    pvd[k] = save_denoise(ctx, k)
 	    println()
 	    println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
+	    println("primal energy: $(primal_energy(ctx))")
 
 	    norm_step = sqrt(norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2))
             norm_step <= 1e-3 && break
@@ -363,7 +320,7 @@ function denoise(img; name, params...)
 	println("refining ...")
 	ctx = refine(ctx, marked_cells)
 	test_mesh(ctx.mesh)
-	interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
+	#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
 	k >= 200 && break
     end
     vtk_save(pvd)
@@ -394,10 +351,10 @@ function inpaint(img, imgmask; name, params...)
     interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
 
     save_inpaint(i) =
-	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
+	output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
 	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask)
 
-    pvd = paraview_collection("$(ctx.name).pvd")
+    pvd = paraview_collection("output/$(ctx.name).pvd")
     pvd[0] = save_inpaint(0)
     for i in 1:3
 	step!(ctx)
@@ -437,10 +394,10 @@ function optflow(imgf0, imgf1; name, params...)
     interpolate!(ctx.g, g_optflow; ctx.u, f0, fw, nablafw)
 
     save_optflow(i) =
-	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
+	output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
 	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, f0, f1, fw)
 
-    pvd = paraview_collection("$(ctx.name).pvd")
+    pvd = paraview_collection("output/$(ctx.name).pvd")
     pvd[0] = save_optflow(0)
     for i in 1:10
 	step!(ctx)
-- 
GitLab