Skip to content
Snippets Groups Projects
Commit 5a3902ff authored by Stephan Hilb's avatar Stephan Hilb
Browse files

fix bugs with operator B

parent f8d92720
No related branches found
No related tags found
No related merge requests found
module DualTVDD module DualTVDD
include("types.jl") include("types.jl")
include("chambolle.jl") include("chambolle.jl")
include("dualtvdd.jl") include("dualtvdd.jl")
...@@ -8,44 +10,56 @@ include("dualtvdd.jl") ...@@ -8,44 +10,56 @@ include("dualtvdd.jl")
using Makie: heatmap using Makie: heatmap
function run() function run()
g = rand(50,50) g = zeros(20,20)
g[4:17,4:17] .= 1
#g[:size(g, 1)÷2,:] .= 1
#g = [0. 2; 1 0.] #g = [0. 2; 1 0.]
A = diagm(ones(length(g))) B = diagm(fill(100, length(g)))
α = 0.25 α = 0.1
md = DualTVDD.OpROFModel(g, A, α) md = DualTVDD.OpROFModel(g, B, α)
alg = DualTVDD.ChambolleAlgorithm() alg = DualTVDD.ChambolleAlgorithm()
ctx = DualTVDD.init(md, alg) ctx = DualTVDD.init(md, alg)
scene = heatmap(ctx.s, #scene = vbox(
colorrange=(0,1), colormap=:gray, scale_plot=false) # heatmap(ctx.s, colorrange=(0,1), colormap=:gray, scale_plot=false, show_axis=false),
# heatmap(ctx.s, colorrange=(0,1), colormap=:gray, scale_plot=false, show_axis=false),
display(scene) # )
#display(scene)
hm = last(scene) #hm = last(scene)
for i in 1:100 for i in 1:10000
step!(ctx) step!(ctx)
hm[1] = ctx.s #hm[1] = ctx.s
yield() #yield()
#sleep(0.2) #sleep(0.2)
end end
display(ctx.p)
display(recover_u!(ctx))
ctx ctx
#hm[1] = ctx.s #hm[1] = ctx.s
#yield() #yield()
end end
function rundd() function rundd()
β = 0
f = zeros(8,8) f = zeros(8,8)
f[1,:] .= 1 f[1:4,:] .= 1
#g = [0. 2; 1 0.] #g = [0. 2; 1 0.]
A = diagm(ones(length(f))) A = diagm(vcat(fill(1/2, length(f)÷2), fill(2, length(f)÷2)))
α = 0.25 B = inv(A'*A + β*I)
g = similar(f)
vec(g) .= A' * vec(f)
α = .25
md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.) md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
alg = DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25) alg = DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25)
ctx = DualTVDD.init(md, alg) ctx = DualTVDD.init(md, alg)
md2 = DualTVDD.OpROFModel(f, A, α) md2 = DualTVDD.OpROFModel(g, B, α)
alg2 = DualTVDD.ChambolleAlgorithm() alg2 = DualTVDD.ChambolleAlgorithm()
ctx2 = DualTVDD.init(md2, alg2) ctx2 = DualTVDD.init(md2, alg2)
...@@ -81,5 +95,46 @@ function rundd() ...@@ -81,5 +95,46 @@ function rundd()
ctx, ctx2 ctx, ctx2
end end
function energy(ctx::DualTVDDContext)
d = ndims(ctx.p)
@inline kfΛ(w) = @inbounds divergence(w)
= Kernel{ntuple(_->-1:1, d)}(kfΛ)
v = similar(ctx.g)
# v = div(p) + A'*f
map!(, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
v .+= ctx.g
u = ctx.B * vec(v)
return sum(u .* vec(v)) / 2
end
function energy(ctx::ChambolleContext)
d = ndims(ctx.p)
@inline kfΛ(w) = @inbounds divergence(w)
= Kernel{ntuple(_->-1:1, d)}(kfΛ)
v = similar(ctx.model.g)
# v = div(p) + g
map!(, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
v .+= ctx.model.g
u = ctx.model.B * vec(v)
return sum(u .* vec(v)) / 2
end
function energy(md::DualTVDDModel, u::AbstractMatrix)
@inline kf(w) = @inbounds 1/2 * (w[0,0] - md.g[w.position])^2 +
md.λ * sqrt((w[1,0] - w[0,0])^2 + (w[0,1] - w[0,0])^2)
k = Kernel{(0:1, 0:1)}(kf, StaticKernels.ExtensionReplicate())
return sum(k, u)
end
end # module end # module
...@@ -10,19 +10,19 @@ struct DualTVDDAlgorithm{d} <: Algorithm ...@@ -10,19 +10,19 @@ struct DualTVDDAlgorithm{d} <: Algorithm
end end
end end
struct DualTVDDContext{M,A,G,d,U,V,VV,SAx,Vview,SC} struct DualTVDDContext{M,A,G,d,U,V,Vtmp,VV,SAx,SC}
model::M model::M
algorithm::A algorithm::A
"precomputed A'f" "precomputed A'f"
g::G g::G
"global dual optimization variable" "global dual optimization variable"
p::V p::V
"(A'A + βI)^(-1)" "global dual temporary variable"
ptmp::Vtmp
"precomputed (A'A + βI)^(-1)"
B::VV B::VV
"subdomain axes wrt global indices" "subdomain axes wrt global indices"
subax::SAx subax::SAx
"local views on p per subdomain"
pviews::Array{Vview,d}
"subproblem data, subg[i] == subctx[i].model.g" "subproblem data, subg[i] == subctx[i].model.g"
subg::Array{U,d} subg::Array{U,d}
"context for subproblems" "context for subproblems"
...@@ -32,36 +32,34 @@ end ...@@ -32,36 +32,34 @@ end
function init(md::DualTVDDModel, alg::DualTVDDAlgorithm) function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
d = ndims(md.f) d = ndims(md.f)
ax = axes(md.f) ax = axes(md.f)
# subdomain axes # subdomain axes
subax = subaxes(md.f, alg.M, alg.overlap) subax = subaxes(md.f, alg.M, alg.overlap)
# preallocated data for subproblems
# data for subproblems
subg = [Array{Float64, d}(undef, length.(subax[i])) for i in CartesianIndices(subax)] subg = [Array{Float64, d}(undef, length.(subax[i])) for i in CartesianIndices(subax)]
# locally dependent tv parameter # locally dependent tv parameter
subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i])) for i in CartesianIndices(subax)] subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
for i in CartesianIndices(subax)]
# this is the global g, the local gs are getting initialized in step!()
g = reshape(md.A' * vec(md.f), size(md.f)) g = reshape(md.A' * vec(md.f), size(md.f))
# global dual variables
p = zeros(SVector{d,Float64}, size(md.f)) p = zeros(SVector{d,Float64}, size(md.f))
ptmp = extend(zeros(SVector{d,Float64}, size(md.f)), StaticKernels.ExtensionNothing())
#g[i] = md.f # precomputed global B
# TODO: initialize g per subdomain with partition function
B = inv(md.A' * md.A + md.β * I) B = inv(md.A' * md.A + md.β * I)
# create models for subproblems # create subproblem contexts
# TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too) # TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too)
li = LinearIndices(size(md.f)) li = LinearIndices(size(md.f))
models = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i]) submds = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i])
for i in CartesianIndices(subax)] for i in CartesianIndices(subax)]
subalg = ChambolleAlgorithm() subalg = ChambolleAlgorithm()
subctx = [init(submds[i], subalg) for i in CartesianIndices(subax)]
subctx = [init(models[i], subalg) for i in CartesianIndices(subax)] return DualTVDDContext(md, alg, g, p, ptmp, B, subax, subg, subctx)
return DualTVDDContext(md, alg, g, p, B, subax, subg, subg, subctx)
end end
function step!(ctx::DualTVDDContext) function step!(ctx::DualTVDDContext)
...@@ -72,30 +70,23 @@ function step!(ctx::DualTVDDContext) ...@@ -72,30 +70,23 @@ function step!(ctx::DualTVDDContext)
@inline kfΛ(w) = @inbounds -divergence_global(w) @inline kfΛ(w) = @inbounds -divergence_global(w)
= Kernel{ntuple(_->-1:1, d)}(kfΛ) = Kernel{ntuple(_->-1:1, d)}(kfΛ)
println("global p")
display(ctx.p)
# call run! on each cell (this can be threaded) # call run! on each cell (this can be threaded)
for i in eachindex(ctx.subctx) for i in eachindex(ctx.subctx)
sax = ctx.subax[i] sax = ctx.subax[i]
ci = CartesianIndices(sax) ci = CartesianIndices(sax)
# TODO: make p computation local!
# g_i = (A*f - Λ(1-theta_i)p^n)|_{\Omega_i} # g_i = (A*f - Λ(1-theta_i)p^n)|_{\Omega_i}
# subctx[i].p is used as a buffer # subctx[i].p is used as a buffer
tmp = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p ctx.ptmp .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
#tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) #tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)))
#ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...] #ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...]
tmp2 = map(, extend(tmp, StaticKernels.ExtensionNothing())) tmp2 = map(, ctx.ptmp)
ctx.subg[i] .= tmp2[sax...] ctx.subg[i] .= tmp2[sax...]
#map!(kΛ, ctx.subg[i], ctx.subctx[i].p) #map!(kΛ, ctx.subg[i], ctx.subctx[i].p)
println("### ITERATION $i ###")
display(tmp)
#display(ctx.subctx[i].p)
display(ctx.subg[i])
ctx.subg[i] .+= ctx.g[sax...] ctx.subg[i] .+= ctx.g[sax...]
# set sensible starting value # set sensible starting value
ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p))) ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
...@@ -136,10 +127,10 @@ function recover_u!(ctx::DualTVDDContext) ...@@ -136,10 +127,10 @@ function recover_u!(ctx::DualTVDDContext)
@inline kfΛ(w) = @inbounds divergence(w) @inline kfΛ(w) = @inbounds divergence(w)
= Kernel{ntuple(_->-1:1, d)}(kfΛ) = Kernel{ntuple(_->-1:1, d)}(kfΛ)
# u = div(p) + A'*f # v = div(p) + A'*f
map!(, v, extend(ctx.p, StaticKernels.ExtensionNothing())) map!(, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
v .+= ctx.g v .+= ctx.g
# u = B * u # u = B * v
mul!(vec(u), ctx.B, vec(v)) mul!(vec(u), ctx.B, vec(v))
return u return u
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment