From a98a243097834d48ba70c10a7bfa9078de09c234 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Thu, 22 Oct 2020 14:52:55 +0200
Subject: [PATCH] add surrogate algorithm and ddseq coloring

---
 src/DualTVDD.jl  |  1 +
 src/chambolle.jl |  6 +++-
 src/common.jl    |  1 +
 src/dualtvdd.jl  | 87 ++++++++++++++++++++++++++++++++++--------------
 src/problems.jl  |  6 +++-
 src/surrogate.jl | 78 +++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 152 insertions(+), 27 deletions(-)
 create mode 100644 src/surrogate.jl

diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index 98a625f..e4671fd 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -9,6 +9,7 @@ include("nstep.jl")
 include("chambolle.jl")
 include("dualtvdd.jl")
 include("projgrad.jl")
+include("surrogate.jl")
 #include("tvnewton.jl")
 
 #using Plots: heatmap
diff --git a/src/chambolle.jl b/src/chambolle.jl
index 4cd2ae8..882462e 100644
--- a/src/chambolle.jl
+++ b/src/chambolle.jl
@@ -93,13 +93,17 @@ function LinearAlgebra.mul!(Y::AbstractVector{<:SVector}, A, B::AbstractVector{<
 end
 
 function step!(ctx::ChambolleState)
+    display(ctx.p)
     alg = ctx.algorithm
     # r = div(p) + g
     map!(ctx.k1, ctx.r, ctx.p)
     # s = B * r
     mul!(ctx.sv, alg.problem.B, ctx.rv)
+    #display(ctx.algorithm.problem.g)
+    display(ctx.r)
     # p = (p + τ*grad(s)) / (1 + τ/λ|grad(s)|)
-    map!(ctx.k2, ctx.p, ctx.s)
+    ctx.p .= deepcopy(map!(ctx.k2, ctx.p, ctx.s))
+    ctx.p .+= 1
 
     return ctx
 end
diff --git a/src/common.jl b/src/common.jl
index 60958d8..7293ac0 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -51,6 +51,7 @@ function init end
 function step! end
 function fetch end
 
+fetch_u(st) = recover_u(fetch(st), st.algorithm.problem)
 
 Base.intersect(a::CartesianIndices{d}, b::CartesianIndices{d}) where d =
     CartesianIndices(intersect.(a.indices, b.indices))
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index 7de8732..191469e 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -1,3 +1,5 @@
+using Distributed: @everywhere, pmap
+
 struct DualTVDDAlgorithm{P,d} <: Algorithm{P}
     problem::P
 
@@ -68,6 +70,28 @@ function intersectin(a, b)
     return (c, c .- az, c .- bz)
 end
 
+function chessboard_coloring(sz)
+    binli = LinearIndices((2, 2))
+    coloring = [Int[] for _ in 1:4]
+
+    li = LinearIndices(sz)
+    for I in CartesianIndices(sz)
+        push!(coloring[binli[CartesianIndex(mod1.(Tuple(I), 2))]], li[I])
+    end
+
+    return coloring
+end
+
+function subrun!(subctx, maxiters)
+    #fetch(subctx) .= Ref(zero(eltype(fetch(subctx))) .+ 1)
+    display("uiae")
+    step!(subctx)
+    #for j in 1:maxiters
+    #    step!(subctx)
+    #end
+    return subctx
+end
+
 function step!(ctx::DualTVDDState)
     alg = ctx.algorithm
     σ = ctx.algorithm.σ
@@ -76,36 +100,49 @@ function step!(ctx::DualTVDDState)
     overlap = ctx.algorithm.overlap
 
     # call run! on each cell (this can be threaded)
-    for i in eachindex(ctx.subax)
-        sax = ctx.subax[i]
-        li = LinearIndices(ctx.subax)[i]
-        sg = ctx.subctx[i].algorithm.problem.g # julia-bug workaround
-        sq = ctx.q[i] # julia-bug workaround
-
-        sg .= view(alg.problem.g, sax...)
-        if alg.parallel
-            sq .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* view(ctx.p, sax...)
-        else
-            sq .= Ref(zero(eltype(sq)))
-            # contributions from previous domains
-            for (lj, saxj) in enumerate(ctx.subax)
-                ids, idsi, idsj = intersectin(CartesianIndices(sax), CartesianIndices(saxj))
-                if lj < li
-                    sq[idsi] .+= view(ctx.subctx[lj].p, idsj)
-                elseif lj > li
-                    sq[idsi] .+= theta.(Ref(ax), Ref(saxj), Ref(overlap), ids) .* view(ctx.p, ids)
+    cids = chessboard_coloring(size(ctx.subax))
+    for (color, ids) in enumerate(cids)
+
+        # prepare data g for subproblems
+        for i in ids
+            sax = ctx.subax[i]
+            li = LinearIndices(ctx.subax)[i]
+            sg = ctx.subctx[i].algorithm.problem.g # julia-bug workaround
+            sq = ctx.q[i] # julia-bug workaround
+
+            sg .= view(alg.problem.g, sax...)
+            if alg.parallel
+                sq .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* view(ctx.p, sax...)
+            else
+                sq .= Ref(zero(eltype(sq)))
+                # contributions from previous domains
+                for (pcolor, pids) in enumerate(cids)
+                    # TODO: only adjacent ones needed
+                    for lj in pids
+                        saxj = ctx.subax[lj]
+                        pids, pidsi, pidsj = intersectin(CartesianIndices(sax), CartesianIndices(saxj))
+                        if pcolor < color
+                            sq[pidsi] .+= view(ctx.subctx[lj].p, pidsj)
+                        elseif pcolor > color
+                            sq[pidsi] .+= theta.(Ref(ax), Ref(saxj), Ref(overlap), pids) .* view(ctx.p, pids)
+                        end
+                    end
                 end
             end
-        end
 
-        @inline kfΛ(pw) = @inbounds sg[pw.position] + divergence(pw)
-        kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+            @inline kfΛ(pw) = @inbounds sg[pw.position] + divergence(pw)
+            kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
-        map!(kΛ, sg, sq)
-
-        for j in 1:alg.ninner
-            step!(ctx.subctx[i])
+            map!(kΛ, sg, sq)
         end
+
+        # actually run subalgorithms
+        ctx.subctx[ids] .= map(subrun!, deepcopy(ctx.subctx[ids]), [1 for _ in ids])
+        #ctx.subctx[ids] .= map(subrun!, deepcopy(ctx.subctx[ids]), [alg.ninner for _ in ids])
+
+        #for i in ids
+        #    subrun!(ctx.subctx[i])
+        #end
     end
 
     # aggregate (not thread-safe!)
diff --git a/src/problems.jl b/src/problems.jl
index be32f99..716c76e 100644
--- a/src/problems.jl
+++ b/src/problems.jl
@@ -47,6 +47,10 @@ function recover_u(p, prob::DualTVL1ROFOpProblem)
     return u
 end
 
+mynorm(v) = norm(vec(v))
+@inline mynorm(v::SArray{<:Any,<:SArray{<:Any,Float64,1},1}) =
+    sqrt(sum(sum(v[i] .^ 2) for i in eachindex(v)))
+
 function residual(p, prob::DualTVL1ROFOpProblem)
     d = ndims(p)
     grad = Kernel{ntuple(_->0:1, d)}(gradient)
@@ -54,7 +58,7 @@ function residual(p, prob::DualTVL1ROFOpProblem)
     u = recover_u(p, prob)
     q = map(grad, StaticKernels.extend(u, StaticKernels.ExtensionReplicate()))
 
-    res = q .- norm.(q) .* p ./ prob.λ
+    res = q .- mynorm.(q) .* p ./ prob.λ
     return sum(dot.(res, res)) / length(p)
 end
 
diff --git a/src/surrogate.jl b/src/surrogate.jl
new file mode 100644
index 0000000..dcdacf9
--- /dev/null
+++ b/src/surrogate.jl
@@ -0,0 +1,78 @@
+using LinearAlgebra: I
+
+struct SurrogateAlgorithm{P} <: Algorithm{P}
+    problem::P
+    subalg::Function
+    τ::Float64
+    function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem;
+                                τ=2*normB(problem), subalg=x->ProjectedGradient(x))
+        return new{typeof(problem)}(problem, subalg, τ)
+    end
+end
+
+struct SurrogateState{A,SP,SS,Wv,R,S,K1} <: State
+    algorithm::A
+    subproblem::SP
+    substate::SS
+
+    "scalar temporary 1"
+    rv::Wv
+    "scalar temporary 2"
+    sv::Wv
+
+    "matrix view on rv"
+    r::R
+    "matrix view on sv"
+    s::S
+
+    k1::K1
+end
+
+function init(alg::SurrogateAlgorithm)
+    prob = alg.problem
+    g = prob.g
+    d = ndims(g)
+
+    rv = zeros(eltype(g), length(g))
+    sv = zero(rv)
+
+    r = reshape(rv, size(g))
+    s = reshape(sv, size(g))
+    #s = extend(reshape(sv, size(g)), StaticKernels.ExtensionReplicate())
+
+    @inline kf1(pw) = @inbounds -divergence(pw)
+    k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
+
+    subprob = DualTVL1ROFOpProblem(copy(prob.g), I, prob.λ)
+    substate = init(alg.subalg(subprob))
+
+    return SurrogateState(alg, subprob, substate, rv, sv, r, s, k1)
+end
+
+function update_g!(subg, st::SurrogateState)
+    alg = st.algorithm
+    g = alg.problem.g
+    p = extend(fetch(st.substate), StaticKernels.ExtensionReplicate())
+
+    # r = Λ*p
+    map!(st.k1, st.r, p)
+    # s = r - g
+    st.s .= st.r - g
+    # r = r + 1/τ * B * s
+    mul!(st.rv, alg.problem.B, st.sv, -1. / alg.τ, 1.)
+
+    subg .= st.r
+end
+
+
+function step!(st::SurrogateState)
+    update_g!(st.subproblem.g, st)
+
+    for i in 1:10
+        step!(st.substate)
+    end
+
+    return st
+end
+
+fetch(st::SurrogateState) = fetch(st.substate)
-- 
GitLab