Skip to content
Snippets Groups Projects
Commit 7ee0449d authored by Stephan Hilb's avatar Stephan Hilb
Browse files

add almost working surrogate

parent 5a3902ff
No related branches found
No related tags found
No related merge requests found
......@@ -10,8 +10,8 @@ include("dualtvdd.jl")
using Makie: heatmap
function run()
g = zeros(20,20)
g[4:17,4:17] .= 1
g = ones(20,20)
#g[4:17,4:17] .= 1
#g[:size(g, 1)÷2,:] .= 1
#g = [0. 2; 1 0.]
B = diagm(fill(100, length(g)))
......@@ -44,19 +44,24 @@ end
function rundd()
β = 0
f = zeros(8,8)
f[1:4,:] .= 1
f = zeros(2,2)
f[1,:] .= 1
#g = [0. 2; 1 0.]
A = diagm(vcat(fill(1/2, length(f)÷2), fill(2, length(f)÷2)))
A = diagm(vcat(fill(2, length(f)÷2), fill(1, length(f)÷2)))
A = rand(length(f), length(f))
display(A)
#A = diagm(fill(1/2, length(f)))
B = inv(A'*A + β*I)
#println(norm(sqrt(B)))
g = similar(f)
vec(g) .= A' * vec(f)
α = .25
md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
alg = DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25)
alg = DualTVDD.DualTVDDAlgorithm(M=(1,1), overlap=(1,1), σ=0.25)
ctx = DualTVDD.init(md, alg)
md2 = DualTVDD.OpROFModel(g, B, α)
......@@ -66,6 +71,8 @@ function rundd()
for i in 1:1000
step!(ctx)
end
for i in 1:10000
step!(ctx2)
end
......@@ -90,7 +97,7 @@ function rundd()
println("u result")
display(recover_u!(ctx))
display((recover_u!(ctx2); ctx2.s))
display(recover_u!(ctx2))
ctx, ctx2
end
......
......@@ -49,7 +49,8 @@ function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
ptmp = extend(zeros(SVector{d,Float64}, size(md.f)), StaticKernels.ExtensionNothing())
# precomputed global B
B = inv(md.A' * md.A + md.β * I)
#B = inv(md.A' * md.A + md.β * I)
B = diagm(ones(length(md.f))) + md.β * I
# create subproblem contexts
# TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too)
......@@ -59,6 +60,9 @@ function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
subalg = ChambolleAlgorithm()
subctx = [init(submds[i], subalg) for i in CartesianIndices(subax)]
# subcontext B is identity
B = inv(md.A' * md.A + md.β * I)
return DualTVDDContext(md, alg, g, p, ptmp, B, subax, subg, subctx)
end
......@@ -66,10 +70,13 @@ function step!(ctx::DualTVDDContext)
d = ndims(ctx.p)
ax = axes(ctx.p)
overlap = ctx.algorithm.overlap
li = LinearIndices(size(ctx.model.f))
@inline kfΛ(w) = @inbounds -divergence_global(w)
= Kernel{ntuple(_->-1:1, d)}(kfΛ)
λ = 2*norm(sqrt(ctx.B))^2 # TODO: algorithm parameter
# call run! on each cell (this can be threaded)
for i in eachindex(ctx.subctx)
sax = ctx.subax[i]
......@@ -83,16 +90,32 @@ function step!(ctx::DualTVDDContext)
#tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)))
#ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...]
tmp2 = map(, ctx.ptmp)
ctx.subg[i] .= tmp2[sax...]
ctx.subg[i] .= map(, ctx.ptmp)[sax...]
#map!(kΛ, ctx.subg[i], ctx.subctx[i].p)
ctx.subg[i] .+= ctx.g[sax...]
# set sensible starting value
ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
for j in 1:100
step!(ctx.subctx[i])
# precomputed: B/λ * (A'f - Λ(1-θ_i)p^n)
gloc = similar(ctx.subg[i])
vec(gloc) .= ctx.subctx[i].model.B * vec(ctx.subg[i])
# v_0
ctx.ptmp .= theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)) .* ctx.p
ctx.subctx[i].p .= ctx.ptmp[sax...]
# subcontext B is identity!
subIB = I - ctx.B[vec(li[sax...]), vec(li[sax...])]./λ
subB = ctx.B[vec(li[sax...]), vec(li[sax...])]./λ
for j in 1:50
subΛp = map(, ctx.subctx[i].p)
vec(ctx.subg[i]) .= subIB * vec(subΛp) .+ subB * vec(gloc)
for k in 1:10
step!(ctx.subctx[i])
end
end
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment