Skip to content
Snippets Groups Projects
Commit 97157af2 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

got crude dd working

parent bd6205aa
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,84 @@ module DualTVDD ...@@ -2,6 +2,84 @@ module DualTVDD
include("types.jl") include("types.jl")
include("chambolle.jl") include("chambolle.jl")
include("dualtvdd.jl")
using Makie: heatmap
function run()
g = rand(50,50)
#g = [0. 2; 1 0.]
A = diagm(ones(length(g)))
α = 0.25
md = DualTVDD.OpROFModel(g, A, α)
alg = DualTVDD.ChambolleAlgorithm()
ctx = DualTVDD.init(md, alg)
scene = heatmap(ctx.s,
colorrange=(0,1), colormap=:gray, scale_plot=false)
display(scene)
hm = last(scene)
for i in 1:100
step!(ctx)
hm[1] = ctx.s
yield()
#sleep(0.2)
end
ctx
#hm[1] = ctx.s
#yield()
end
function rundd()
f = zeros(8)
f[1,:] .= 1
#g = [0. 2; 1 0.]
A = diagm(ones(length(f)))
α = 0.25
md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
alg = DualTVDD.DualTVDDAlgorithm(M=(2,), overlap=(2,), σ=0.25)
ctx = DualTVDD.init(md, alg)
md2 = DualTVDD.OpROFModel(f, A, α)
alg2 = DualTVDD.ChambolleAlgorithm()
ctx2 = DualTVDD.init(md2, alg2)
for i in 1:150
step!(ctx)
step!(ctx2)
end
#scene = heatmap(ctx.s,
# colorrange=(0,1), colormap=:gray, scale_plot=false)
#display(scene)
#hm = last(scene)
#for i in 1:100
# step!(ctx)
# hm[1] = ctx.s
# yield()
# #sleep(0.2)
#end
#hm[1] = ctx.s
#yield()
println("p result")
display(ctx.p)
display(ctx2.p)
println("u result")
display(recover_u!(ctx))
display((recover_u!(ctx2); ctx2.s))
ctx, ctx2
end
end # module end # module
...@@ -62,6 +62,7 @@ function init(md::OpROFModel, alg::ChambolleAlgorithm) ...@@ -62,6 +62,7 @@ function init(md::OpROFModel, alg::ChambolleAlgorithm)
k1 = Kernel{ntuple(_->-1:1, d)}(kf1) k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
@inline function kf2(pw, sw) @inline function kf2(pw, sw)
iszero(λ[pw.position]) && return zero(pw[z])
sgrad = alg.τ * gradient(sw) sgrad = alg.τ * gradient(sw)
return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / λ[pw.position]) return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / λ[pw.position])
end end
......
struct DualTVDDAlgorithm{d} <: Algorithm struct DualTVDDAlgorithm{d} <: Algorithm
"number of subdomains in each dimension" "number of subdomains in each dimension"
M::NTuple{d,Int} M::NTuple{d,Int}
"overlap in pixels per dimension"
overlap::NTuple{d,Int}
"inertia parameter" "inertia parameter"
σ::Float64 σ::Float64
function DualTVDDAlgorithm(; M, σ) function DualTVDDAlgorithm(; M, overlap, σ)
return new{length(M)}(M, σ) return new{length(M)}(M, overlap, σ)
end end
end end
struct DualTVDDContext{d,U,V,Vview,SC} struct DualTVDDContext{M,A,G,d,U,V,VV,SAx,Vview,SC}
model::M
algorithm::A
"precomputed A'f"
g::G
"global dual optimization variable" "global dual optimization variable"
p::V p::V
"(A'A + βI)^(-1)"
B::VV
"subdomain axes wrt global indices"
subax::SAx
"local views on p per subdomain" "local views on p per subdomain"
pviews::Array{Vview,d} pviews::Array{Vview,d}
"data for subproblems" "subproblem data, subg[i] == subctx[i].model.g"
g::Array{U,d} subg::Array{U,d}
"context for subproblems" "context for subproblems"
subctx::Array{SC,d} subctx::Array{SC,d}
end end
function solve(model::DualTVDDModel, algorithm::DualTVDDAlgorithm) function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
d = ndims(md.f)
ax = axes(md.f)
# subdomain axes
subax = subaxes(md.f, alg.M, alg.overlap)
# data for subproblems
subg = [Array{Float64, d}(undef, length.(subax[i])) for i in CartesianIndices(subax)]
# locally dependent tv parameter
subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i])) for i in CartesianIndices(subax)]
g = reshape(md.A' * vec(md.f), size(md.f))
p = zeros(SVector{d,Float64}, size(md.f))
#g[i] = md.f
# TODO: initialize g per subdomain with partition function
B = inv(md.A' * md.A + md.β * I)
# create models for subproblems
# TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too)
li = LinearIndices(size(md.f))
models = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i])
for i in CartesianIndices(subax)]
subalg = ChambolleAlgorithm()
subctx = [init(models[i], subalg) for i in CartesianIndices(subax)]
return DualTVDDContext(md, alg, g, p, B, subax, subg, subg, subctx)
end
function step!(ctx::DualTVDDContext)
d = ndims(ctx.p)
ax = axes(ctx.p)
overlap = ctx.algorithm.overlap
@inline kfΛ(w) = @inbounds divergence_global(w)
= Kernel{ntuple(_->-1:1, d)}(kfΛ)
println("global p")
display(ctx.p)
# call run! on each cell (this can be threaded)
for i in eachindex(ctx.subctx)
sax = ctx.subax[i]
ci = CartesianIndices(sax)
# g_i = (A*f - Λ(1-theta_i)p^n)|_{\Omega_i}
# subctx[i].p is used as a buffer
tmp = (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
#tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)))
#ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...]
tmp2 = map(, extend(tmp, StaticKernels.ExtensionNothing()))
ctx.subg[i] .= tmp2[sax...]
#map!(kΛ, ctx.subg[i], ctx.subctx[i].p)
println("### ITERATION $i ###")
display(tmp)
#display(ctx.subctx[i].p)
display(ctx.subg[i])
ctx.subg[i] .+= ctx.g[sax...]
# set sensible starting value
ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
for j in 1:100
step!(ctx.subctx[i])
end
end
# aggregate (not thread-safe!)
σ = ctx.algorithm.σ
ctx.p .*= 1 - σ
for i in CartesianIndices(ctx.subax)
ctx.p[ctx.subax[i]...] .+= σ .* ctx.subctx[i].p
end
end
@generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
i0 = ntuple(_->0, N)
i1(k) = ntuple(i->Int(k==i), N)
wi = (:(w[$(i0...)][$k] -
(isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k])) for k in 1:N)
return quote
Base.@_inline_meta
return @inbounds +($(wi...))
end
end
#FD.GridFunction(grid, (A'*A + β*I) \ (FD.divergence_z(p).data[:] .+ A'*f.data[:]))
function recover_u!(ctx::DualTVDDContext)
d = ndims(ctx.g)
u = similar(ctx.g)
v = similar(ctx.g)
@inline kfΛ(w) = @inbounds divergence(w)
= Kernel{ntuple(_->-1:1, d)}(kfΛ)
# u = div(p) + A'*f
map!(, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
v .+= ctx.g
# u = B * u
mul!(vec(u), ctx.B, vec(v))
return u
end
"""
theta(ax, sax, overlap, i)
Return value of the partition function at index `i` given global axes `ax`,
subdomain axes `sax` and overlap count `overlap`.
This assumes that subdomains have size at least 2 .* overlap.
"""
theta(ax, sax, overlap, i::CartesianIndex) = prod(theta.(ax, sax, overlap, Tuple(i)))
theta(ax, sax, overlap::Int, i::Int) =
max(0., min(1.,
first(ax) == first(sax) && i < first(ax) + overlap ? 1. : (i - first(sax)) / overlap,
last(ax) == last(sax) && i > last(ax) - overlap ? 1. : (last(sax) - i) / overlap))
"""
subaxes(domain, pnum, overlap)
Determine axes for all subdomains, given per dimension number of domains
`pnum` and overlap `overlap`
"""
function subaxes(domain, pnum, overlap)
overlap = 1 .+ overlap
d = ndims(domain)
tsize = size(domain) .+ (pnum .- 1) .* overlap
psize = tsize pnum
osize = tsize .- pnum .* psize
overhang(I, j) = I[j] == pnum[j] ? osize[j] : 0
indices = Array{NTuple{d, UnitRange{Int}}, d}(undef, pnum)
for I in CartesianIndices(pnum)
indices[I] = ntuple(j -> ((I[j] - 1) * psize[j] - (I[j] - 1) * overlap[j] + 1) :
( I[j] * psize[j] - (I[j] - 1) * overlap[j] + overhang(I, j)), d)
end
@assert all(length.(sax) >= 1 .* overlap for sax in indices)
return indices
end end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment