From 7ee0449d167e8a587667c98294b3e2dc8aced821 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Wed, 27 May 2020 19:38:45 +0200
Subject: [PATCH] add almost working surrogate

---
 src/DualTVDD.jl | 21 ++++++++++++++-------
 src/dualtvdd.jl | 33 ++++++++++++++++++++++++++++-----
 2 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index 450991b..1a08c0e 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -10,8 +10,8 @@ include("dualtvdd.jl")
 using Makie: heatmap
 
 function run()
-    g = zeros(20,20)
-    g[4:17,4:17] .= 1
+    g = ones(20,20)
+    #g[4:17,4:17] .= 1
     #g[:size(g, 1)÷2,:] .= 1
     #g = [0. 2; 1 0.]
     B = diagm(fill(100, length(g)))
@@ -44,19 +44,24 @@ end
 
 function rundd()
     β = 0
-    f = zeros(8,8)
-    f[1:4,:] .= 1
+    f = zeros(2,2)
+    f[1,:] .= 1
     #g = [0. 2; 1 0.]
-    A = diagm(vcat(fill(1/2, length(f)÷2), fill(2, length(f)÷2)))
+    A = diagm(vcat(fill(2, length(f)÷2), fill(1, length(f)÷2)))
+    A = rand(length(f), length(f))
+    display(A)
+    #A = diagm(fill(1/2, length(f)))
     B = inv(A'*A + β*I)
 
+    #println(norm(sqrt(B)))
+
     g = similar(f)
     vec(g) .= A' * vec(f)
 
     α = .25
 
     md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
-    alg = DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25)
+    alg = DualTVDD.DualTVDDAlgorithm(M=(1,1), overlap=(1,1), σ=0.25)
     ctx = DualTVDD.init(md, alg)
 
     md2 = DualTVDD.OpROFModel(g, B, α)
@@ -66,6 +71,8 @@ function rundd()
 
     for i in 1:1000
         step!(ctx)
+    end
+    for i in 1:10000
         step!(ctx2)
     end
 
@@ -90,7 +97,7 @@ function rundd()
 
     println("u result")
     display(recover_u!(ctx))
-    display((recover_u!(ctx2); ctx2.s))
+    display(recover_u!(ctx2))
 
     ctx, ctx2
 end
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index bc18ff2..1d75472 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -49,7 +49,8 @@ function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
     ptmp = extend(zeros(SVector{d,Float64}, size(md.f)), StaticKernels.ExtensionNothing())
 
     # precomputed global B
-    B = inv(md.A' * md.A + md.β * I)
+    #B = inv(md.A' * md.A + md.β * I)
+    B = diagm(ones(length(md.f))) + md.β * I
 
     # create subproblem contexts
     # TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too)
@@ -59,6 +60,9 @@ function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
     subalg = ChambolleAlgorithm()
     subctx = [init(submds[i], subalg) for i in CartesianIndices(subax)]
 
+    # subcontext B is identity
+    B = inv(md.A' * md.A + md.β * I)
+
     return DualTVDDContext(md, alg, g, p, ptmp, B, subax, subg, subctx)
 end
 
@@ -66,10 +70,13 @@ function step!(ctx::DualTVDDContext)
     d = ndims(ctx.p)
     ax = axes(ctx.p)
     overlap = ctx.algorithm.overlap
+    li = LinearIndices(size(ctx.model.f))
 
     @inline kfΛ(w) = @inbounds -divergence_global(w)
     kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
+    λ = 2*norm(sqrt(ctx.B))^2 # TODO: algorithm parameter
+
     # call run! on each cell (this can be threaded)
     for i in eachindex(ctx.subctx)
         sax = ctx.subax[i]
@@ -83,16 +90,32 @@ function step!(ctx::DualTVDDContext)
         #tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)))
         #ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...]
 
-        tmp2 = map(kΛ, ctx.ptmp)
-        ctx.subg[i] .= tmp2[sax...]
+        ctx.subg[i] .= map(kΛ, ctx.ptmp)[sax...]
         #map!(kΛ, ctx.subg[i], ctx.subctx[i].p)
 
         ctx.subg[i] .+= ctx.g[sax...]
         # set sensible starting value
         ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
 
-        for j in 1:100
-            step!(ctx.subctx[i])
+        # precomputed: B/λ * (A'f - Λ(1-θ_i)p^n)
+        gloc = similar(ctx.subg[i])
+        vec(gloc) .= ctx.subctx[i].model.B * vec(ctx.subg[i])
+
+        # v_0
+        ctx.ptmp .= theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)) .* ctx.p
+        ctx.subctx[i].p .= ctx.ptmp[sax...]
+
+        # subcontext B is identity!
+        subIB = I - ctx.B[vec(li[sax...]), vec(li[sax...])]./λ
+        subB = ctx.B[vec(li[sax...]), vec(li[sax...])]./λ
+
+        for j in 1:50
+            subΛp = map(kΛ, ctx.subctx[i].p)
+            vec(ctx.subg[i]) .= subIB * vec(subΛp) .+ subB * vec(gloc)
+
+            for k in 1:10
+                step!(ctx.subctx[i])
+            end
         end
     end
 
-- 
GitLab