From 7ee0449d167e8a587667c98294b3e2dc8aced821 Mon Sep 17 00:00:00 2001 From: Stephan Hilb <stephan@ecshi.net> Date: Wed, 27 May 2020 19:38:45 +0200 Subject: [PATCH] add almost working surrogate --- src/DualTVDD.jl | 21 ++++++++++++++------- src/dualtvdd.jl | 33 ++++++++++++++++++++++++++++----- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl index 450991b..1a08c0e 100644 --- a/src/DualTVDD.jl +++ b/src/DualTVDD.jl @@ -10,8 +10,8 @@ include("dualtvdd.jl") using Makie: heatmap function run() - g = zeros(20,20) - g[4:17,4:17] .= 1 + g = ones(20,20) + #g[4:17,4:17] .= 1 #g[:size(g, 1)÷2,:] .= 1 #g = [0. 2; 1 0.] B = diagm(fill(100, length(g))) @@ -44,19 +44,24 @@ end function rundd() β = 0 - f = zeros(8,8) - f[1:4,:] .= 1 + f = zeros(2,2) + f[1,:] .= 1 #g = [0. 2; 1 0.] - A = diagm(vcat(fill(1/2, length(f)÷2), fill(2, length(f)÷2))) + A = diagm(vcat(fill(2, length(f)÷2), fill(1, length(f)÷2))) + A = rand(length(f), length(f)) + display(A) + #A = diagm(fill(1/2, length(f))) B = inv(A'*A + β*I) + #println(norm(sqrt(B))) + g = similar(f) vec(g) .= A' * vec(f) α = .25 md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.) - alg = DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25) + alg = DualTVDD.DualTVDDAlgorithm(M=(1,1), overlap=(1,1), σ=0.25) ctx = DualTVDD.init(md, alg) md2 = DualTVDD.OpROFModel(g, B, α) @@ -66,6 +71,8 @@ function rundd() for i in 1:1000 step!(ctx) + end + for i in 1:10000 step!(ctx2) end @@ -90,7 +97,7 @@ function rundd() println("u result") display(recover_u!(ctx)) - display((recover_u!(ctx2); ctx2.s)) + display(recover_u!(ctx2)) ctx, ctx2 end diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl index bc18ff2..1d75472 100644 --- a/src/dualtvdd.jl +++ b/src/dualtvdd.jl @@ -49,7 +49,8 @@ function init(md::DualTVDDModel, alg::DualTVDDAlgorithm) ptmp = extend(zeros(SVector{d,Float64}, size(md.f)), StaticKernels.ExtensionNothing()) # precomputed global B - B = inv(md.A' * md.A + md.β * I) + #B = inv(md.A' * md.A + md.β * I) + B = diagm(ones(length(md.f))) + md.β * I # create subproblem contexts # TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too) @@ -59,6 +60,9 @@ function init(md::DualTVDDModel, alg::DualTVDDAlgorithm) subalg = ChambolleAlgorithm() subctx = [init(submds[i], subalg) for i in CartesianIndices(subax)] + # subcontext B is identity + B = inv(md.A' * md.A + md.β * I) + return DualTVDDContext(md, alg, g, p, ptmp, B, subax, subg, subctx) end @@ -66,10 +70,13 @@ function step!(ctx::DualTVDDContext) d = ndims(ctx.p) ax = axes(ctx.p) overlap = ctx.algorithm.overlap + li = LinearIndices(size(ctx.model.f)) @inline kfΛ(w) = @inbounds -divergence_global(w) kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ) + λ = 2*norm(sqrt(ctx.B))^2 # TODO: algorithm parameter + # call run! on each cell (this can be threaded) for i in eachindex(ctx.subctx) sax = ctx.subax[i] @@ -83,16 +90,32 @@ function step!(ctx::DualTVDDContext) #tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) #ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...] - tmp2 = map(kΛ, ctx.ptmp) - ctx.subg[i] .= tmp2[sax...] + ctx.subg[i] .= map(kΛ, ctx.ptmp)[sax...] #map!(kΛ, ctx.subg[i], ctx.subctx[i].p) ctx.subg[i] .+= ctx.g[sax...] # set sensible starting value ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p))) - for j in 1:100 - step!(ctx.subctx[i]) + # precomputed: B/λ * (A'f - Λ(1-θ_i)p^n) + gloc = similar(ctx.subg[i]) + vec(gloc) .= ctx.subctx[i].model.B * vec(ctx.subg[i]) + + # v_0 + ctx.ptmp .= theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)) .* ctx.p + ctx.subctx[i].p .= ctx.ptmp[sax...] + + # subcontext B is identity! + subIB = I - ctx.B[vec(li[sax...]), vec(li[sax...])]./λ + subB = ctx.B[vec(li[sax...]), vec(li[sax...])]./λ + + for j in 1:50 + subΛp = map(kΛ, ctx.subctx[i].p) + vec(ctx.subg[i]) .= subIB * vec(subΛp) .+ subB * vec(gloc) + + for k in 1:10 + step!(ctx.subctx[i]) + end end end -- GitLab