From 37957857460f4729a5a60134614383c48dab47c4 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Thu, 17 Dec 2020 11:54:22 +0100
Subject: [PATCH] implement global surrogate dd

---
 src/DualTVDD.jl           |   8 +-
 src/dualtvdd_surrogate.jl | 149 ++++++++++++++++++++++++++++++++++++++
 src/surrogate.jl          |   2 +-
 test/runtests.jl          |  35 ++++++++-
 4 files changed, 190 insertions(+), 4 deletions(-)
 create mode 100644 src/dualtvdd_surrogate.jl

diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index bb18d01..8ed0d67 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -1,16 +1,20 @@
 module DualTVDD
 
-export DualTVDDAlgorithm
+export DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm,
+    SurrogateAlgorithm, DualTVDDAlgorithm, DualTVDDSurrogateAlgorithm
+
 export init, step!, fetch, run
+export fetch_u, normB
 
 include("common.jl")
 include("problems.jl")
 
 include("nstep.jl")
 include("chambolle.jl")
-include("dualtvdd.jl")
 include("projgrad.jl")
 include("surrogate.jl")
+include("dualtvdd.jl")
+include("dualtvdd_surrogate.jl")
 #include("tvnewton.jl")
 
 end # module
diff --git a/src/dualtvdd_surrogate.jl b/src/dualtvdd_surrogate.jl
new file mode 100644
index 0000000..dbd2fcf
--- /dev/null
+++ b/src/dualtvdd_surrogate.jl
@@ -0,0 +1,149 @@
+using Distributed: workers
+using Outsource: Connector, outsource
+
+struct DualTVDDSurrogateAlgorithm{P,d} <: Algorithm{P}
+    problem::P
+
+    "number of subdomains in each dimension"
+    M::NTuple{d,Int}
+    "overlap in pixels per dimension"
+    overlap::NTuple{d,Int}
+    "use the non-sequential inertia update"
+    parallel::Bool
+    "inertia parameter (only when parallel)"
+    σ::Float64
+    "surrogate stepsize"
+    τ::Float64
+    "number of inner iterations"
+    ninner::Int
+    "prob -> Algorithm(::Problem, ...)"
+    subalg::Function
+    "worker ids used for distributed execution"
+    workers::Vector{Int}
+    function DualTVDDSurrogateAlgorithm(problem; M, overlap, parallel=true, σ=parallel ? 1/4 : 1., ninner=10, τ=2*normB(problem), subalg=x->ChambolleAlgorithm(x), workers=workers())
+        if parallel == true && σ > 1/4
+            @warn "parallel domain decomposition needs σ >= 1/4 for theoretical convergence"
+        end
+        return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, τ, ninner, subalg, workers)
+    end
+end
+
+struct DualTVDDSurrogateState{A,d,V,SAx,SC}
+    algorithm::A
+
+    "global variable"
+    p::V
+    "subdomain axes wrt global indices"
+    subax::SAx
+    "connectors to subworkers"
+    cons::Array{SC,d}
+end
+
+function init(alg::DualTVDDSurrogateAlgorithm{<:DualTVL1ROFOpProblem})
+    g = alg.problem.g
+    d = ndims(g)
+    ax = axes(g)
+
+    # subdomain axes
+    subax = subaxes(size(g), alg.M, alg.overlap)
+    # preallocated data for subproblems
+    subg = [Array{eltype(g), d}(undef, length.(x)) for x in subax]
+    # locally dependent tv parameter
+    subλ = [alg.problem.λ[subax[i]...] .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
+        for i in CartesianIndices(subax)]
+
+    # TODO: generalize to SArray
+    p1type(T::Type{<:Real}) = SVector{d,T}
+    p1type(::Type{SVector{m,T}}) where {m, T} = SMatrix{m,d,T,m*d}
+
+    # global dual variable
+    p = zeros(p1type(eltype(g)), ax)
+    # local dual variable
+    subp = [zeros(p1type(eltype(g)), length.(sax)) for sax in subax]
+
+    # create subproblem contexts
+    cids = alg.parallel ? [eachindex(subax)] :
+        chessboard_coloring(size(subax))
+    cons = Array{Connector, d}(undef, size(subax))
+    for (color, sidxs) in enumerate(cids)
+        for (i, sidx) in enumerate(sidxs)
+            sax = subax[sidx]
+
+            subprob = DualTVL1ROFOpProblem(subg[sidx], I, subλ[sidx])
+            wf = subworker(alg, alg.subalg(subprob))
+            wid = alg.workers[mod1(i, length(alg.workers))]
+            cons[sidx] = outsource(wf, wid)
+        end
+    end
+
+    return DualTVDDSurrogateState(alg, p, subax, cons)
+end
+
+function step!(ctx::DualTVDDSurrogateState)
+    alg = ctx.algorithm
+    σ = ctx.algorithm.σ
+    d = ndims(ctx.p)
+    ax = axes(ctx.p)
+    overlap = ctx.algorithm.overlap
+    g = alg.problem.g
+
+    p_rem = copy(ctx.p)
+    p_don = zeros(eltype(ctx.p), size(ctx.p))
+
+    # subdomain loop (in coloring order)
+    #cids = alg.parallel ? [eachindex(ctx.subax)] :
+        #chessboard_coloring(size(ctx.subax))
+    cids = chessboard_coloring(size(ctx.subax))
+
+    for (color, ids) in enumerate(cids)
+        p_current = extend(zeros(eltype(ctx.p), size(ctx.p)),
+            StaticKernels.ExtensionNothing())
+        for i in ids
+            sax = ctx.subax[i]
+            p_current .+=
+                theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ax)) .* ctx.p
+        end
+        p_rem .-= p_current
+
+        @inline kfΛ(p) = @inbounds -divergence(p)
+        kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+        if alg.parallel
+            p_other = ctx.p .- p_current
+        else
+            p_other = p_don .+ p_rem
+        end
+        p_other = extend(p_other, StaticKernels.ExtensionNothing())
+
+        r_current = map(kΛ, p_current)
+        r_other = map(kΛ, p_other)
+
+        s = r_current .+ r_other .- g
+
+        g_current = r_current
+        mul!(vec(g_current), alg.problem.B, vec(s), -1. / alg.τ, 1.)
+
+        for pids in Iterators.partition(ids, length(alg.workers))
+            for i in pids
+                sax = ctx.subax[i]
+                put!(ctx.cons[i], view(g_current, sax...))
+            end
+        end
+
+        p_don .+= (1 .- σ) .* p_current
+        for pids in Iterators.partition(ids, length(alg.workers))
+            for i in pids
+                sax = ctx.subax[i]
+                p_i = take!(ctx.cons[i])
+                view(p_don, sax...) .+= σ .* p_i
+            end
+        end
+    end
+
+    ctx.p .= p_don
+
+
+    return ctx
+end
+
+fetch(ctx::DualTVDDSurrogateState) = ctx.p
diff --git a/src/surrogate.jl b/src/surrogate.jl
index a4bcb48..b3ee405 100644
--- a/src/surrogate.jl
+++ b/src/surrogate.jl
@@ -8,7 +8,7 @@ struct SurrogateAlgorithm{P} <: Algorithm{P}
     τ::Float64
     function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem;
                                 ninner=1,
-                                τ=2*normB(problem), subalg=x->ProjGradAlgorithm(x))
+                                τ=2*normB(problem), subalg=x->ChambolleAlgorithm(x))
         return new{typeof(problem)}(problem, subalg, ninner, τ)
     end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 3a0b97e..ff1eb51 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,7 +2,7 @@ using Test, BenchmarkTools
 using LinearAlgebra
 using DualTVDD:
     DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm, DualTVDDAlgorithm,
-    init, step!, fetch_u
+    init, step!, fetch_u, normB
 
 @testset "B = I" begin
     g = Float64[0 2; 1 0]
@@ -57,3 +57,36 @@ end
     end
     @test fetch_u(st) ≈ fetch_u(stref)
 end
+
+@testset "DualTVDDSurrogateAlgorithm" begin
+    d = 2
+    n = 5
+    ninner = 100
+    g = rand(n, n)
+    #B = Diagonal(rand(n^2))
+    B = rand(n^2, n^2) ./ n^2
+    B = I + 0.5 * (B + B')
+    M = ntuple(_->1, d)
+    overlap = ntuple(_->2, d)
+    parallel = true
+    σ = 1.0
+    # big λ is ok, since we test for inter-subdomain communication
+    prob = DualTVL1ROFOpProblem(g, B, 1000.)
+
+    algref = ChambolleAlgorithm(prob)
+    #algref = DualTVDDAlgorithm(prob; M, overlap, ninner=1, parallel, σ,
+    #    subalg = x -> SurrogateAlgorithm(x; τ=2*normB(prob), ninner))
+    alg = DualTVDDSurrogateAlgorithm(prob; M, overlap, ninner, parallel, σ,
+        subalg = x -> ChambolleAlgorithm(x))
+
+    stref = init(algref)
+    st = init(alg)
+    #@test 0 == @ballocated step!($ctx)
+    for i in 1:1000
+        step!(stref)
+    end
+    for i in 1:1000
+        step!(st)
+    end
+    @test fetch_u(st) ≈ fetch_u(stref)
+end
-- 
GitLab