From da8b54db6dd149ba9aea1acf3b630b1b4945d329 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Fri, 24 Jul 2020 18:03:11 +0200
Subject: [PATCH] fix interface issues and tests

---
 src/DualTVDD.jl  | 16 ++++++++--------
 src/chambolle.jl |  5 ++---
 src/dualtvdd.jl  | 15 +++++++++------
 src/projgrad.jl  |  2 +-
 test/runtests.jl | 38 ++++++++++++++++++++++++++++----------
 5 files changed, 48 insertions(+), 28 deletions(-)

diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index af27db1..4fcafbd 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -71,13 +71,13 @@ end
 function alg_error(alg, pmin, niter, ninner=1)
     print("run ...")
     res = Float64[]
-    (p, ctx) = iterate(alg)
-    push!(res, error(ctx.p, pmin, ctx.algorithm.problem))
+    ctx = init(alg)
+    push!(res, error(fetch(ctx), pmin, ctx.algorithm.problem))
     for i in 1:niter
         for j in 1:ninner
-            (p, ctx) = iterate(alg, ctx)
+            step!(ctx)
         end
-        push!(res, error(ctx.p, pmin, ctx.algorithm.problem))
+        push!(res, error(fetch(ctx), pmin, ctx.algorithm.problem))
     end
     println(" finished")
     return res
@@ -87,9 +87,9 @@ function calc_energy(prob, niter)
     alg_ref = DualTVDD.ChambolleAlgorithm(prob)
     ctx = init(alg_ref)
     for i in 1:niter
-        (p, ctx) = iterate(alg_ref, ctx)
+        ctx = step!(ctx)
     end
-    return energy(ctx), p
+    return energy(ctx), fetch(ctx)
 end
 
 function rundd()
@@ -132,8 +132,8 @@ function rundd()
          lognan.(alg_error(alg_dd2, pmin, n)),
        ]
 
-    plot(y, xaxis=:log, yaxis=:log)
-
+    plt = plot(y, xaxis=:log, yaxis=:log)
+    display(plt)
 
     #display(energy(ctx))
     #display(ctx.p)
diff --git a/src/chambolle.jl b/src/chambolle.jl
index 0a183cd..1489379 100644
--- a/src/chambolle.jl
+++ b/src/chambolle.jl
@@ -90,8 +90,7 @@ function step!(ctx::ChambolleState)
     return ctx
 end
 
-fetch(ctx::ChambolleState) =
-    recover_u(ctx.p, ctx.algorithm.problem)
+fetch(ctx::ChambolleState) = ctx.p
 
 function recover_u(p, md::DualTVL1ROFOpProblem)
     d = ndims(md.g)
@@ -102,7 +101,7 @@ function recover_u(p, md::DualTVL1ROFOpProblem)
     kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
     # v = div(p) + A'*f
-    map!(kΛ, v, extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
+    map!(kΛ, v, StaticKernels.extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
     v .+= md.g
     # u = B * v
     mul!(vec(u), md.B, vec(v))
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index b54df22..f3925ff 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -29,7 +29,7 @@ struct DualTVDDState{A,d,V,SV,SAx,SC}
     subctx::Array{SC,d}
 end
 
-function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
+function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
     g = alg.problem.g
     d = ndims(g)
     ax = axes(g)
@@ -56,9 +56,9 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
     #subalg = [ProjGradAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
     subalg = [ChambolleAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
 
-    subctx = [iterate(x)[2] for x in subalg]
+    subctx = [init(x) for x in subalg]
 
-    return p, DualTVDDState(alg, p, q, subax, subctx)
+    return DualTVDDState(alg, p, q, subax, subctx)
 end
 
 function intersectin(a, b)
@@ -68,7 +68,8 @@ function intersectin(a, b)
     return (c, c .- az, c .- bz)
 end
 
-function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
+function step!(ctx::DualTVDDState)
+    alg = ctx.algorithm
     # σ = 1 takes care of sequential updates
     σ = alg.parallel ? ctx.algorithm.σ : 1.
     d = ndims(ctx.p)
@@ -103,7 +104,7 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
         map!(kΛ, sg, sq)
 
         for j in 1:alg.ninner
-            (_, ctx.subctx[i]) = iterate(ctx.subctx[i].algorithm, ctx.subctx[i])
+            step!(ctx.subctx[i])
         end
     end
 
@@ -113,9 +114,11 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
         ctx.p[sax...] .+= σ .* ctx.subctx[i].p
     end
 
-    return ctx.p, ctx
+    return ctx
 end
 
+fetch(ctx::DualTVDDState) = ctx.p
+
 @generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
     i0 = ntuple(_->0, N)
     i1(k) = ntuple(i->Int(k==i), N)
diff --git a/src/projgrad.jl b/src/projgrad.jl
index 6bbc2ed..e45080d 100644
--- a/src/projgrad.jl
+++ b/src/projgrad.jl
@@ -70,4 +70,4 @@ function step!(ctx::ProjGradState)
     return ctx
 end
 
-fetch(ctx::ProjGradState) = recover_u(ctx.p, ctx.algorithm.problem)
+fetch(ctx::ProjGradState) = ctx.p
diff --git a/test/runtests.jl b/test/runtests.jl
index fa69c23..e2debb3 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,17 +2,35 @@ using Test, BenchmarkTools
 using LinearAlgebra
 using DualTVDD:
     DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm,
-    init, step!, fetch
+    init, step!, fetch, recover_u
 
-g = Float64[0 2; 1 0]
-prob = DualTVL1ROFOpProblem(g, I, 1e-10)
+@testset "B = I" begin
+    g = Float64[0 2; 1 0]
+    prob = DualTVL1ROFOpProblem(g, I, 1e-10)
 
-@testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
-    ctx = init(alg)
-    @test 0 == @ballocated step!($ctx)
-    for i in 1:100
-        step!(ctx)
+    @testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
+        ctx = init(alg)
+        @test 0 == @ballocated step!($ctx)
+        for i in 1:100
+            step!(ctx)
+        end
+        u = recover_u(fetch(ctx), ctx.algorithm.problem)
+        @test u ≈ g
+    end
+end
+
+@testset "B = rand(...)" begin
+    g = Float64[0 2; 1 0]
+    B = rand(length(g), length(g))
+    prob = DualTVL1ROFOpProblem(g, B, 1e-10)
+
+    @testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
+        ctx = init(alg)
+        @test 0 == @ballocated step!($ctx)
+        for i in 1:100
+            step!(ctx)
+        end
+        u = recover_u(fetch(ctx), ctx.algorithm.problem)
+        @test vec(u) ≈ B * vec(g)
     end
-    u = fetch(ctx)
-    @test u ≈ g
 end
-- 
GitLab