From 97157af22eb769238026b4c29221e09a4d39a8a6 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Tue, 19 May 2020 00:50:28 +0200
Subject: [PATCH] got crude dd working
---
src/DualTVDD.jl | 78 +++++++++++++++++++++
src/chambolle.jl | 1 +
src/dualtvdd.jl | 172 +++++++++++++++++++++++++++++++++++++++++++++--
3 files changed, 245 insertions(+), 6 deletions(-)
diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index 78038b7..89225bf 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -2,6 +2,84 @@ module DualTVDD
include("types.jl")
include("chambolle.jl")
+include("dualtvdd.jl")
+
+
+using Makie: heatmap
+
+function run()
+ g = rand(50,50)
+ #g = [0. 2; 1 0.]
+ A = diagm(ones(length(g)))
+ α = 0.25
+
+ md = DualTVDD.OpROFModel(g, A, α)
+ alg = DualTVDD.ChambolleAlgorithm()
+ ctx = DualTVDD.init(md, alg)
+
+ scene = heatmap(ctx.s,
+ colorrange=(0,1), colormap=:gray, scale_plot=false)
+
+ display(scene)
+
+ hm = last(scene)
+ for i in 1:100
+ step!(ctx)
+ hm[1] = ctx.s
+ yield()
+ #sleep(0.2)
+ end
+ ctx
+ #hm[1] = ctx.s
+ #yield()
+end
+
+function rundd()
+ f = zeros(8)
+ f[1,:] .= 1
+ #g = [0. 2; 1 0.]
+ A = diagm(ones(length(f)))
+ α = 0.25
+
+ md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
+ alg = DualTVDD.DualTVDDAlgorithm(M=(2,), overlap=(2,), σ=0.25)
+ ctx = DualTVDD.init(md, alg)
+
+ md2 = DualTVDD.OpROFModel(f, A, α)
+ alg2 = DualTVDD.ChambolleAlgorithm()
+ ctx2 = DualTVDD.init(md2, alg2)
+
+
+ for i in 1:150
+ step!(ctx)
+ step!(ctx2)
+ end
+
+ #scene = heatmap(ctx.s,
+ # colorrange=(0,1), colormap=:gray, scale_plot=false)
+
+ #display(scene)
+
+ #hm = last(scene)
+ #for i in 1:100
+ # step!(ctx)
+ # hm[1] = ctx.s
+ # yield()
+ # #sleep(0.2)
+ #end
+ #hm[1] = ctx.s
+ #yield()
+
+ println("p result")
+ display(ctx.p)
+ display(ctx2.p)
+
+ println("u result")
+ display(recover_u!(ctx))
+ display((recover_u!(ctx2); ctx2.s))
+
+ ctx, ctx2
+end
end # module
diff --git a/src/chambolle.jl b/src/chambolle.jl
index 2844098..0c41a78 100644
--- a/src/chambolle.jl
+++ b/src/chambolle.jl
@@ -62,6 +62,7 @@ function init(md::OpROFModel, alg::ChambolleAlgorithm)
k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
@inline function kf2(pw, sw)
+ iszero(λ[pw.position]) && return zero(pw[z])
sgrad = alg.τ * gradient(sw)
return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / λ[pw.position])
end
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index 0f5a781..21cc1ee 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -1,24 +1,184 @@
struct DualTVDDAlgorithm{d} <: Algorithm
"number of subdomains in each dimension"
M::NTuple{d,Int}
+ "overlap in pixels per dimension"
+ overlap::NTuple{d,Int}
"inertia parameter"
σ::Float64
- function DualTVDDAlgorithm(; M, σ)
- return new{length(M)}(M, σ)
+ function DualTVDDAlgorithm(; M, overlap, σ)
+ return new{length(M)}(M, overlap, σ)
end
end
-struct DualTVDDContext{d,U,V,Vview,SC}
+struct DualTVDDContext{M,A,G,d,U,V,VV,SAx,Vview,SC}
+ model::M
+ algorithm::A
+ "precomputed A'f"
+ g::G
"global dual optimization variable"
p::V
+ "(A'A + βI)^(-1)"
+ B::VV
+ "subdomain axes wrt global indices"
+ subax::SAx
"local views on p per subdomain"
pviews::Array{Vview,d}
- "data for subproblems"
- g::Array{U,d}
+ "subproblem data, subg[i] == subctx[i].model.g"
+ subg::Array{U,d}
"context for subproblems"
subctx::Array{SC,d}
end
-function solve(model::DualTVDDModel, algorithm::DualTVDDAlgorithm)
+function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
+ d = ndims(md.f)
+ ax = axes(md.f)
+ # subdomain axes
+ subax = subaxes(md.f, alg.M, alg.overlap)
+ # data for subproblems
+ subg = [Array{Float64, d}(undef, length.(subax[i])) for i in CartesianIndices(subax)]
+
+ # locally dependent tv parameter
+ subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i])) for i in CartesianIndices(subax)]
+
+ g = reshape(md.A' * vec(md.f), size(md.f))
+
+ p = zeros(SVector{d,Float64}, size(md.f))
+
+ #g[i] = md.f
+
+ # TODO: initialize g per subdomain with partition function
+
+ B = inv(md.A' * md.A + md.β * I)
+
+ # create models for subproblems
+ # TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too)
+ li = LinearIndices(size(md.f))
+ models = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i])
+ for i in CartesianIndices(subax)]
+
+ subalg = ChambolleAlgorithm()
+
+ subctx = [init(models[i], subalg) for i in CartesianIndices(subax)]
+
+ return DualTVDDContext(md, alg, g, p, B, subax, subg, subg, subctx)
+end
+
+function step!(ctx::DualTVDDContext)
+ d = ndims(ctx.p)
+ ax = axes(ctx.p)
+ overlap = ctx.algorithm.overlap
+
+ @inline kfΛ(w) = @inbounds divergence_global(w)
+ kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+ println("global p")
+ display(ctx.p)
+
+ # call run! on each cell (this can be threaded)
+ for i in eachindex(ctx.subctx)
+ sax = ctx.subax[i]
+ ci = CartesianIndices(sax)
+
+ # g_i = (A*f - Λ(1-theta_i)p^n)|_{\Omega_i}
+ # subctx[i].p is used as a buffer
+
+ tmp = (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
+ #tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)))
+ #ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...]
+
+ tmp2 = map(kΛ, extend(tmp, StaticKernels.ExtensionNothing()))
+ ctx.subg[i] .= tmp2[sax...]
+ #map!(kΛ, ctx.subg[i], ctx.subctx[i].p)
+
+ println("### ITERATION $i ###")
+ display(tmp)
+ #display(ctx.subctx[i].p)
+ display(ctx.subg[i])
+
+ ctx.subg[i] .+= ctx.g[sax...]
+ # set sensible starting value
+ ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
+
+ for j in 1:100
+ step!(ctx.subctx[i])
+ end
+ end
+
+ # aggregate (not thread-safe!)
+ σ = ctx.algorithm.σ
+ ctx.p .*= 1 - σ
+ for i in CartesianIndices(ctx.subax)
+ ctx.p[ctx.subax[i]...] .+= σ .* ctx.subctx[i].p
+ end
+end
+
+@generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
+ i0 = ntuple(_->0, N)
+ i1(k) = ntuple(i->Int(k==i), N)
+
+ wi = (:(w[$(i0...)][$k] -
+ (isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k])) for k in 1:N)
+ return quote
+ Base.@_inline_meta
+ return @inbounds +($(wi...))
+ end
+end
+
+
+ #FD.GridFunction(grid, (A'*A + β*I) \ (FD.divergence_z(p).data[:] .+ A'*f.data[:]))
+
+function recover_u!(ctx::DualTVDDContext)
+ d = ndims(ctx.g)
+ u = similar(ctx.g)
+ v = similar(ctx.g)
+
+ @inline kfΛ(w) = @inbounds divergence(w)
+ kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+ # u = div(p) + A'*f
+ map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
+ v .+= ctx.g
+ # u = B * u
+ mul!(vec(u), ctx.B, vec(v))
+ return u
+end
+
+"""
+ theta(ax, sax, overlap, i)
+
+Return value of the partition function at index `i` given global axes `ax`,
+subdomain axes `sax` and overlap count `overlap`.
+This assumes that subdomains have size at least 2 .* overlap.
+"""
+theta(ax, sax, overlap, i::CartesianIndex) = prod(theta.(ax, sax, overlap, Tuple(i)))
+theta(ax, sax, overlap::Int, i::Int) =
+ max(0., min(1.,
+ first(ax) == first(sax) && i < first(ax) + overlap ? 1. : (i - first(sax)) / overlap,
+ last(ax) == last(sax) && i > last(ax) - overlap ? 1. : (last(sax) - i) / overlap))
+
+
+"""
+ subaxes(domain, pnum, overlap)
+
+Determine axes for all subdomains, given per dimension number of domains
+`pnum` and overlap `overlap`
+"""
+function subaxes(domain, pnum, overlap)
+ overlap = 1 .+ overlap
+ d = ndims(domain)
+ tsize = size(domain) .+ (pnum .- 1) .* overlap
+
+ psize = tsize .÷ pnum
+ osize = tsize .- pnum .* psize
+
+ overhang(I, j) = I[j] == pnum[j] ? osize[j] : 0
+
+ indices = Array{NTuple{d, UnitRange{Int}}, d}(undef, pnum)
+ for I in CartesianIndices(pnum)
+ indices[I] = ntuple(j -> ((I[j] - 1) * psize[j] - (I[j] - 1) * overlap[j] + 1) :
+ ( I[j] * psize[j] - (I[j] - 1) * overlap[j] + overhang(I, j)), d)
+ end
+ @assert all(length.(sax) >= 1 .* overlap for sax in indices)
+ return indices
end
--
GitLab