From 97157af22eb769238026b4c29221e09a4d39a8a6 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Tue, 19 May 2020 00:50:28 +0200
Subject: [PATCH] got crude dd working

---
 src/DualTVDD.jl  |  78 +++++++++++++++++++++
 src/chambolle.jl |   1 +
 src/dualtvdd.jl  | 172 +++++++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 245 insertions(+), 6 deletions(-)

diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index 78038b7..89225bf 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -2,6 +2,84 @@ module DualTVDD
 
 include("types.jl")
 include("chambolle.jl")
+include("dualtvdd.jl")
+
+
+using Makie: heatmap
+
+function run()
+    g = rand(50,50)
+    #g = [0. 2; 1 0.]
+    A = diagm(ones(length(g)))
+    α = 0.25
+
+    md = DualTVDD.OpROFModel(g, A, α)
+    alg = DualTVDD.ChambolleAlgorithm()
+    ctx = DualTVDD.init(md, alg)
+
+    scene = heatmap(ctx.s,
+        colorrange=(0,1), colormap=:gray, scale_plot=false)
+
+    display(scene)
+
+    hm = last(scene)
+    for i in 1:100
+        step!(ctx)
+        hm[1] = ctx.s
+        yield()
+        #sleep(0.2)
+    end
+    ctx
+    #hm[1] = ctx.s
+    #yield()
+end
+
+function rundd()
+    f = zeros(8)
+    f[1,:] .= 1
+    #g = [0. 2; 1 0.]
+    A = diagm(ones(length(f)))
+    α = 0.25
+
+    md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
+    alg = DualTVDD.DualTVDDAlgorithm(M=(2,), overlap=(2,), σ=0.25)
+    ctx = DualTVDD.init(md, alg)
+
+    md2 = DualTVDD.OpROFModel(f, A, α)
+    alg2 = DualTVDD.ChambolleAlgorithm()
+    ctx2 = DualTVDD.init(md2, alg2)
+
+
+    for i in 1:150
+        step!(ctx)
+        step!(ctx2)
+    end
+
+    #scene = heatmap(ctx.s,
+    #    colorrange=(0,1), colormap=:gray, scale_plot=false)
+
+    #display(scene)
+
+    #hm = last(scene)
+    #for i in 1:100
+    #    step!(ctx)
+    #    hm[1] = ctx.s
+    #    yield()
+    #    #sleep(0.2)
+    #end
+    #hm[1] = ctx.s
+    #yield()
+
+    println("p result")
+    display(ctx.p)
+    display(ctx2.p)
+
+    println("u result")
+    display(recover_u!(ctx))
+    display((recover_u!(ctx2); ctx2.s))
+
+    ctx, ctx2
+end
 
 
 end # module
diff --git a/src/chambolle.jl b/src/chambolle.jl
index 2844098..0c41a78 100644
--- a/src/chambolle.jl
+++ b/src/chambolle.jl
@@ -62,6 +62,7 @@ function init(md::OpROFModel, alg::ChambolleAlgorithm)
     k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
 
     @inline function kf2(pw, sw)
+        iszero(λ[pw.position]) && return zero(pw[z])
         sgrad = alg.τ * gradient(sw)
         return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / λ[pw.position])
     end
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index 0f5a781..21cc1ee 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -1,24 +1,184 @@
 struct DualTVDDAlgorithm{d} <: Algorithm
     "number of subdomains in each dimension"
     M::NTuple{d,Int}
+    "overlap in pixels per dimension"
+    overlap::NTuple{d,Int}
     "inertia parameter"
     σ::Float64
-    function DualTVDDAlgorithm(; M, σ)
-        return new{length(M)}(M, σ)
+    function DualTVDDAlgorithm(; M, overlap, σ)
+        return new{length(M)}(M, overlap, σ)
     end
 end
 
-struct DualTVDDContext{d,U,V,Vview,SC}
+struct DualTVDDContext{M,A,G,d,U,V,VV,SAx,Vview,SC}
+    model::M
+    algorithm::A
+    "precomputed A'f"
+    g::G
     "global dual optimization variable"
     p::V
+    "(A'A + βI)^(-1)"
+    B::VV
+    "subdomain axes wrt global indices"
+    subax::SAx
     "local views on p per subdomain"
     pviews::Array{Vview,d}
-    "data for subproblems"
-    g::Array{U,d}
+    "subproblem data, subg[i] == subctx[i].model.g"
+    subg::Array{U,d}
     "context for subproblems"
     subctx::Array{SC,d}
 end
 
-function solve(model::DualTVDDModel, algorithm::DualTVDDAlgorithm)
+function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
+    d = ndims(md.f)
+    ax = axes(md.f)
+    # subdomain axes
+    subax = subaxes(md.f, alg.M, alg.overlap)
 
+    # data for subproblems
+    subg = [Array{Float64, d}(undef, length.(subax[i])) for i in CartesianIndices(subax)]
+
+    # locally dependent tv parameter
+    subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i])) for i in CartesianIndices(subax)]
+
+    g = reshape(md.A' * vec(md.f), size(md.f))
+
+    p = zeros(SVector{d,Float64}, size(md.f))
+
+    #g[i] = md.f
+
+    # TODO: initialize g per subdomain with partition function
+
+    B = inv(md.A' * md.A + md.β * I)
+
+    # create models for subproblems
+    # TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too)
+    li = LinearIndices(size(md.f))
+    models = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i])
+        for i in CartesianIndices(subax)]
+
+    subalg = ChambolleAlgorithm()
+
+    subctx = [init(models[i], subalg) for i in CartesianIndices(subax)]
+
+    return DualTVDDContext(md, alg, g, p, B, subax, subg, subg, subctx)
+end
+
+function step!(ctx::DualTVDDContext)
+    d = ndims(ctx.p)
+    ax = axes(ctx.p)
+    overlap = ctx.algorithm.overlap
+
+    @inline kfΛ(w) = @inbounds divergence_global(w)
+    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+    println("global p")
+    display(ctx.p)
+
+    # call run! on each cell (this can be threaded)
+    for i in eachindex(ctx.subctx)
+        sax = ctx.subax[i]
+        ci = CartesianIndices(sax)
+
+        # g_i = (A*f - Λ(1-theta_i)p^n)|_{\Omega_i}
+        # subctx[i].p is used as a buffer
+
+        tmp = (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
+        #tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)))
+        #ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...]
+
+        tmp2 = map(kΛ, extend(tmp, StaticKernels.ExtensionNothing()))
+        ctx.subg[i] .= tmp2[sax...]
+        #map!(kΛ, ctx.subg[i], ctx.subctx[i].p)
+
+        println("### ITERATION $i ###")
+        display(tmp)
+        #display(ctx.subctx[i].p)
+        display(ctx.subg[i])
+
+        ctx.subg[i] .+= ctx.g[sax...]
+        # set sensible starting value
+        ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
+
+        for j in 1:100
+            step!(ctx.subctx[i])
+        end
+    end
+
+    # aggregate (not thread-safe!)
+    σ = ctx.algorithm.σ
+    ctx.p .*= 1 - σ
+    for i in CartesianIndices(ctx.subax)
+        ctx.p[ctx.subax[i]...] .+= σ .* ctx.subctx[i].p
+    end
+end
+
+@generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
+    i0 = ntuple(_->0, N)
+    i1(k) = ntuple(i->Int(k==i), N)
+
+    wi = (:(w[$(i0...)][$k] -
+            (isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k])) for k in 1:N)
+    return quote
+        Base.@_inline_meta
+        return @inbounds +($(wi...))
+    end
+end
+
+
+  #FD.GridFunction(grid, (A'*A + β*I) \ (FD.divergence_z(p).data[:] .+ A'*f.data[:]))
+
+function recover_u!(ctx::DualTVDDContext)
+    d = ndims(ctx.g)
+    u = similar(ctx.g)
+    v = similar(ctx.g)
+
+    @inline kfΛ(w) = @inbounds divergence(w)
+    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+    # u = div(p) + A'*f
+    map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
+    v .+= ctx.g
+    # u = B * u
+    mul!(vec(u), ctx.B, vec(v))
+    return u
+end
+
+"""
+    theta(ax, sax, overlap, i)
+
+Return value of the partition function at index `i` given global axes `ax`,
+subdomain axes `sax` and overlap count `overlap`.
+This assumes that subdomains have size at least 2 .* overlap.
+"""
+theta(ax, sax, overlap, i::CartesianIndex) = prod(theta.(ax, sax, overlap, Tuple(i)))
+theta(ax, sax, overlap::Int, i::Int) =
+    max(0., min(1.,
+        first(ax) == first(sax) && i < first(ax) + overlap ? 1. : (i - first(sax)) / overlap,
+        last(ax) == last(sax) && i > last(ax) - overlap ? 1. : (last(sax) - i) / overlap))
+
+
+"""
+    subaxes(domain, pnum, overlap)
+
+Determine axes for all subdomains, given per dimension number of domains
+`pnum` and overlap `overlap`
+"""
+function subaxes(domain, pnum, overlap)
+    overlap = 1 .+ overlap
+    d = ndims(domain)
+    tsize = size(domain) .+ (pnum .- 1) .* overlap
+
+    psize = tsize .÷ pnum
+    osize = tsize .- pnum .* psize
+
+    overhang(I, j) = I[j] == pnum[j] ? osize[j] : 0
+
+    indices = Array{NTuple{d, UnitRange{Int}}, d}(undef, pnum)
+    for I in CartesianIndices(pnum)
+        indices[I] = ntuple(j -> ((I[j] - 1) * psize[j] - (I[j] - 1) * overlap[j] + 1) :
+                                 (      I[j] * psize[j] - (I[j] - 1) * overlap[j] + overhang(I, j)), d)
+    end
+    @assert all(length.(sax) >= 1 .* overlap for sax in indices)
+    return indices
 end
-- 
GitLab