Commit 5a3902ff authored by Stephan Hilb

fix bugs with operator B

parent f8d92720
module DualTVDD
include("types.jl")
include("chambolle.jl")
include("dualtvdd.jl")
@@ -8,44 +10,56 @@ include("dualtvdd.jl")
using Makie: heatmap
function run()
g = rand(50,50)
g = zeros(20,20)
g[4:17,4:17] .= 1
#g[:size(g, 1)÷2,:] .= 1
#g = [0. 2; 1 0.]
A = diagm(ones(length(g)))
α = 0.25
B = diagm(fill(100, length(g)))
α = 0.1
md = DualTVDD.OpROFModel(g, A, α)
md = DualTVDD.OpROFModel(g, B, α)
alg = DualTVDD.ChambolleAlgorithm()
ctx = DualTVDD.init(md, alg)
scene = heatmap(ctx.s,
colorrange=(0,1), colormap=:gray, scale_plot=false)
display(scene)
#scene = vbox(
# heatmap(ctx.s, colorrange=(0,1), colormap=:gray, scale_plot=false, show_axis=false),
# heatmap(ctx.s, colorrange=(0,1), colormap=:gray, scale_plot=false, show_axis=false),
# )
#display(scene)
hm = last(scene)
for i in 1:100
#hm = last(scene)
for i in 1:10000
step!(ctx)
hm[1] = ctx.s
yield()
#hm[1] = ctx.s
#yield()
#sleep(0.2)
end
display(ctx.p)
display(recover_u!(ctx))
ctx
#hm[1] = ctx.s
#yield()
end
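
A minimal usage sketch for the solver set up in run() above (not part of this commit; it assumes nothing is exported, so everything is reached through the DualTVDD module, and it uses the energy(::ChambolleContext) function added further down in this diff to monitor progress):

using LinearAlgebra: diagm

g = zeros(20, 20); g[4:17, 4:17] .= 1          # same test image as run()
B = diagm(fill(100., length(g)))
md = DualTVDD.OpROFModel(g, B, 0.1)
ctx = DualTVDD.init(md, DualTVDD.ChambolleAlgorithm())
for i in 1:10000
    DualTVDD.step!(ctx)
    i % 1000 == 0 && println("iteration $i: energy = ", DualTVDD.energy(ctx))
end
u = DualTVDD.recover_u!(ctx)   # primal reconstruction, as displayed at the end of run()
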
function rundd()
β = 0
f = zeros(8,8)
f[1,:] .= 1
f[1:4,:] .= 1
#g = [0. 2; 1 0.]
A = diagm(ones(length(f)))
α = 0.25
A = diagm(vcat(fill(1/2, length(f)÷2), fill(2, length(f)÷2)))
B = inv(A'*A + β*I)
g = similar(f)
vec(g) .= A' * vec(f)
α = .25
md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
alg = DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25)
ctx = DualTVDD.init(md, alg)
md2 = DualTVDD.OpROFModel(f, A, α)
md2 = DualTVDD.OpROFModel(g, B, α)
alg2 = DualTVDD.ChambolleAlgorithm()
ctx2 = DualTVDD.init(md2, alg2)
@@ -81,5 +95,46 @@ function rundd()
ctx, ctx2
end
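
For orientation (a reading of the code, not stated in the commit): with g = A'f and B = (A'A + βI)^(-1) precomputed in rundd() above and in init() further down, the energy functions below and recover_u! at the end of this diff evaluate

E(p) = (1/2) ⟨ B(div p + Aᵀf), div p + Aᵀf ⟩,   u = B(div p + Aᵀf),

which matches the "v = div(p) + A'*f" and "u = B * v" comments in recover_u!.
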
function energy(ctx::DualTVDDContext)
d = ndims(ctx.p)
@inline kfΛ(w) = @inbounds divergence(w)
kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
v = similar(ctx.g)
# v = div(p) + A'*f
map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
v .+= ctx.g
u = ctx.B * vec(v)
return sum(u .* vec(v)) / 2
end
function energy(ctx::ChambolleContext)
d = ndims(ctx.p)
@inline kfΛ(w) = @inbounds divergence(w)
kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
v = similar(ctx.model.g)
# v = div(p) + g
map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
v .+= ctx.model.g
u = ctx.model.B * vec(v)
return sum(u .* vec(v)) / 2
end
function energy(md::DualTVDDModel, u::AbstractMatrix)
@inline kf(w) = @inbounds 1/2 * (w[0,0] - md.g[w.position])^2 +
md.λ * sqrt((w[1,0] - w[0,0])^2 + (w[0,1] - w[0,0])^2)
k = Kernel{(0:1, 0:1)}(kf, StaticKernels.ExtensionReplicate())
return sum(k, u)
end
end # module
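
The kfΛ/kΛ pattern in the energy functions above (and in step!/recover_u! below) maps a StaticKernels window over the dual field; the divergence helper it calls is defined elsewhere in the repository. A self-contained sketch of the same pattern, assuming a backward-difference divergence and that ExtensionNothing hands the kernel nothing for out-of-bounds neighbours (the actual stencil may differ):

using StaticKernels, StaticArrays

# stand-in for the repository's divergence: backward differences per component,
# treating out-of-bounds neighbours (nothing under ExtensionNothing) as zero
@inline function divergence_sketch(w)
    left, below = w[-1, 0], w[0, -1]
    dx = w[0, 0][1] - (isnothing(left) ? 0.0 : left[1])
    dy = w[0, 0][2] - (isnothing(below) ? 0.0 : below[2])
    return dx + dy
end

p  = rand(SVector{2,Float64}, 5, 5)                        # dual field, as stored in the contexts
kΛ = Kernel{(-1:1, -1:1)}(divergence_sketch)
v  = map(kΛ, extend(p, StaticKernels.ExtensionNothing()))  # same size as p
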
@@ -10,19 +10,19 @@ struct DualTVDDAlgorithm{d} <: Algorithm
end
end
struct DualTVDDContext{M,A,G,d,U,V,VV,SAx,Vview,SC}
struct DualTVDDContext{M,A,G,d,U,V,Vtmp,VV,SAx,SC}
model::M
algorithm::A
"precomputed A'f"
g::G
"global dual optimization variable"
p::V
"(A'A + βI)^(-1)"
"global dual temporary variable"
ptmp::Vtmp
"precomputed (A'A + βI)^(-1)"
B::VV
"subdomain axes wrt global indices"
subax::SAx
"local views on p per subdomain"
pviews::Array{Vview,d}
"subproblem data, subg[i] == subctx[i].model.g"
subg::Array{U,d}
"context for subproblems"
@@ -32,36 +32,34 @@ end
function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
d = ndims(md.f)
ax = axes(md.f)
# subdomain axes
subax = subaxes(md.f, alg.M, alg.overlap)
# data for subproblems
# preallocated data for subproblems
subg = [Array{Float64, d}(undef, length.(subax[i])) for i in CartesianIndices(subax)]
# locally dependent tv parameter
subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i])) for i in CartesianIndices(subax)]
subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
for i in CartesianIndices(subax)]
# this is the global g, the local gs are getting initialized in step!()
g = reshape(md.A' * vec(md.f), size(md.f))
# global dual variables
p = zeros(SVector{d,Float64}, size(md.f))
ptmp = extend(zeros(SVector{d,Float64}, size(md.f)), StaticKernels.ExtensionNothing())
#g[i] = md.f
# TODO: initialize g per subdomain with partition function
# precomputed global B
B = inv(md.A' * md.A + md.β * I)
# create models for subproblems
# create subproblem contexts
# TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too)
li = LinearIndices(size(md.f))
models = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i])
submds = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i])
for i in CartesianIndices(subax)]
subalg = ChambolleAlgorithm()
subctx = [init(submds[i], subalg) for i in CartesianIndices(subax)]
subctx = [init(models[i], subalg) for i in CartesianIndices(subax)]
return DualTVDDContext(md, alg, g, p, B, subax, subg, subg, subctx)
return DualTVDDContext(md, alg, g, p, ptmp, B, subax, subg, subctx)
end
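
The B[vec(li[subax[i]...]), vec(li[subax[i]...])] slicing above extracts the sub-block of the global B that couples unknowns inside one subdomain (hence the TODO: this is only exact for B block diagonal with respect to the subdomains). A small worked example of which entries get picked, with hypothetical sizes:

f   = zeros(4, 4)                    # hypothetical 4×4 image
li  = LinearIndices(size(f))
sax = (1:2, 1:2)                     # one subdomain: the top-left 2×2 patch
idx = vec(li[sax...])                # == [1, 2, 5, 6] (column-major linear indices)
B   = rand(length(f), length(f))     # any 16×16 operator
Bsub = B[idx, idx]                   # 4×4 block passed to that subdomain's OpROFModel
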
function step!(ctx::DualTVDDContext)
@@ -72,30 +70,23 @@ function step!(ctx::DualTVDDContext)
@inline kfΛ(w) = @inbounds -divergence_global(w)
kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
println("global p")
display(ctx.p)
# call run! on each cell (this can be threaded)
for i in eachindex(ctx.subctx)
sax = ctx.subax[i]
ci = CartesianIndices(sax)
# TODO: make p computation local!
# g_i = (A*f - Λ(1-theta_i)p^n)|_{\Omega_i}
# subctx[i].p is used as a buffer
tmp = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
ctx.ptmp .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
#tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)))
#ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...]
tmp2 = map(kΛ, extend(tmp, StaticKernels.ExtensionNothing()))
tmp2 = map(kΛ, ctx.ptmp)
ctx.subg[i] .= tmp2[sax...]
#map!(kΛ, ctx.subg[i], ctx.subctx[i].p)
println("### ITERATION $i ###")
display(tmp)
#display(ctx.subctx[i].p)
display(ctx.subg[i])
ctx.subg[i] .+= ctx.g[sax...]
# set sensible starting value
ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
@@ -136,10 +127,10 @@ function recover_u!(ctx::DualTVDDContext)
@inline kfΛ(w) = @inbounds divergence(w)
kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
# u = div(p) + A'*f
# v = div(p) + A'*f
map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
v .+= ctx.g
# u = B * u
# u = B * v
mul!(vec(u), ctx.B, vec(v))
return u
end
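
A quick sanity check of the recovery above (hypothetical, not part of this commit): right after init the global dual variable p is all zeros, so div(p) should vanish and recover_u! should return the plain Tikhonov/least-squares solution (A'A + βI)^(-1) A'f:

using LinearAlgebra

f = zeros(8, 8); f[1:4, :] .= 1
A = diagm(vcat(fill(1/2, length(f) ÷ 2), fill(2, length(f) ÷ 2)))   # setup from rundd()
md  = DualTVDD.DualTVDDModel(f, A, 0.25, 0., 0.)
ctx = DualTVDD.init(md, DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25))
u0 = DualTVDD.recover_u!(ctx)
maximum(abs, vec(u0) - inv(A' * A) * (A' * vec(f)))   # expected ≈ 0 (here β = 0)
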