From f0526296aadf408354d803336f9b921c54debd74 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Thu, 8 Oct 2020 20:23:31 +0200
Subject: [PATCH] add latent changes

---
 src/chambolle.jl | 16 ----------------
 src/dualtvdd.jl  |  5 ++---
 src/problems.jl  | 31 +++++++++++++++++++++++++++++--
 3 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/src/chambolle.jl b/src/chambolle.jl
index fd8d0b9..4cd2ae8 100644
--- a/src/chambolle.jl
+++ b/src/chambolle.jl
@@ -105,19 +105,3 @@ function step!(ctx::ChambolleState)
 end
 
 fetch(ctx::ChambolleState) = ctx.p
-
-function recover_u(p, md::DualTVL1ROFOpProblem)
-    d = ndims(md.g)
-    u = similar(md.g)
-    v = similar(md.g)
-
-    @inline kfΛ(w) = @inbounds divergence(w)
-    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
-
-    # v = div(p) + A'*f
-    map!(kΛ, v, StaticKernels.extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
-    v .+= md.g
-    # u = B * v
-    mul!(vec(u), md.B, vec(v))
-    return u
-end
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index 9745045..7de8732 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -70,14 +70,13 @@ end
 
 function step!(ctx::DualTVDDState)
     alg = ctx.algorithm
-    # σ = 1 takes care of sequential updates
-    σ = alg.parallel ? ctx.algorithm.σ : 1.
+    σ = ctx.algorithm.σ
     d = ndims(ctx.p)
     ax = axes(ctx.p)
     overlap = ctx.algorithm.overlap
 
     # call run! on each cell (this can be threaded)
-    Threads.@threads for i in eachindex(ctx.subax)
+    for i in eachindex(ctx.subax)
         sax = ctx.subax[i]
         li = LinearIndices(ctx.subax)[i]
         sg = ctx.subctx[i].algorithm.problem.g # julia-bug workaround
diff --git a/src/problems.jl b/src/problems.jl
index d37dfa3..be32f99 100644
--- a/src/problems.jl
+++ b/src/problems.jl
@@ -26,9 +26,36 @@ function energy(p, prob::DualTVL1ROFOpProblem)
     # v = div(p) + g
     v = map(kΛ, extend(p, ExtensionNothing()))
 
-    # |v|_B^2 / 2
+    # |v|_B^2
     u = prob.B * vec(v)
-    return sum(dot.(u, vec(v))) / 2
+    return sum(dot.(u, vec(v)))
+end
+
+function recover_u(p, prob::DualTVL1ROFOpProblem)
+    d = ndims(prob.g)
+    u = similar(prob.g)
+    v = similar(prob.g)
+
+    @inline kfΛ(w) = @inbounds divergence(w)
+    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+    # v = div(p) + A'*f
+    map!(kΛ, v, StaticKernels.extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
+    v .+= prob.g
+    # u = B * v
+    mul!(vec(u), prob.B, vec(v))
+    return u
+end
+
+function residual(p, prob::DualTVL1ROFOpProblem)
+    d = ndims(p)
+    grad = Kernel{ntuple(_->0:1, d)}(gradient)
+
+    u = recover_u(p, prob)
+    q = map(grad, StaticKernels.extend(u, StaticKernels.ExtensionReplicate()))
+
+    res = q .- norm.(q) .* p ./ prob.λ
+    return sum(dot.(res, res)) / length(p)
 end
 
 # operator norm of B
-- 
GitLab