From 97157af22eb769238026b4c29221e09a4d39a8a6 Mon Sep 17 00:00:00 2001 From: Stephan Hilb <stephan@ecshi.net> Date: Tue, 19 May 2020 00:50:28 +0200 Subject: [PATCH] got crude dd working --- src/DualTVDD.jl | 78 +++++++++++++++++++++ src/chambolle.jl | 1 + src/dualtvdd.jl | 172 +++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 245 insertions(+), 6 deletions(-) diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl index 78038b7..89225bf 100644 --- a/src/DualTVDD.jl +++ b/src/DualTVDD.jl @@ -2,6 +2,84 @@ module DualTVDD include("types.jl") include("chambolle.jl") +include("dualtvdd.jl") + + +using Makie: heatmap + +function run() + g = rand(50,50) + #g = [0. 2; 1 0.] + A = diagm(ones(length(g))) + α = 0.25 + + md = DualTVDD.OpROFModel(g, A, α) + alg = DualTVDD.ChambolleAlgorithm() + ctx = DualTVDD.init(md, alg) + + scene = heatmap(ctx.s, + colorrange=(0,1), colormap=:gray, scale_plot=false) + + display(scene) + + hm = last(scene) + for i in 1:100 + step!(ctx) + hm[1] = ctx.s + yield() + #sleep(0.2) + end + ctx + #hm[1] = ctx.s + #yield() +end + +function rundd() + f = zeros(8) + f[1,:] .= 1 + #g = [0. 2; 1 0.] + A = diagm(ones(length(f))) + α = 0.25 + + md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.) + alg = DualTVDD.DualTVDDAlgorithm(M=(2,), overlap=(2,), σ=0.25) + ctx = DualTVDD.init(md, alg) + + md2 = DualTVDD.OpROFModel(f, A, α) + alg2 = DualTVDD.ChambolleAlgorithm() + ctx2 = DualTVDD.init(md2, alg2) + + + for i in 1:150 + step!(ctx) + step!(ctx2) + end + + #scene = heatmap(ctx.s, + # colorrange=(0,1), colormap=:gray, scale_plot=false) + + #display(scene) + + #hm = last(scene) + #for i in 1:100 + # step!(ctx) + # hm[1] = ctx.s + # yield() + # #sleep(0.2) + #end + #hm[1] = ctx.s + #yield() + + println("p result") + display(ctx.p) + display(ctx2.p) + + println("u result") + display(recover_u!(ctx)) + display((recover_u!(ctx2); ctx2.s)) + + ctx, ctx2 +end end # module diff --git a/src/chambolle.jl b/src/chambolle.jl index 2844098..0c41a78 100644 --- a/src/chambolle.jl +++ b/src/chambolle.jl @@ -62,6 +62,7 @@ function init(md::OpROFModel, alg::ChambolleAlgorithm) k1 = Kernel{ntuple(_->-1:1, d)}(kf1) @inline function kf2(pw, sw) + iszero(λ[pw.position]) && return zero(pw[z]) sgrad = alg.τ * gradient(sw) return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / λ[pw.position]) end diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl index 0f5a781..21cc1ee 100644 --- a/src/dualtvdd.jl +++ b/src/dualtvdd.jl @@ -1,24 +1,184 @@ struct DualTVDDAlgorithm{d} <: Algorithm "number of subdomains in each dimension" M::NTuple{d,Int} + "overlap in pixels per dimension" + overlap::NTuple{d,Int} "inertia parameter" σ::Float64 - function DualTVDDAlgorithm(; M, σ) - return new{length(M)}(M, σ) + function DualTVDDAlgorithm(; M, overlap, σ) + return new{length(M)}(M, overlap, σ) end end -struct DualTVDDContext{d,U,V,Vview,SC} +struct DualTVDDContext{M,A,G,d,U,V,VV,SAx,Vview,SC} + model::M + algorithm::A + "precomputed A'f" + g::G "global dual optimization variable" p::V + "(A'A + βI)^(-1)" + B::VV + "subdomain axes wrt global indices" + subax::SAx "local views on p per subdomain" pviews::Array{Vview,d} - "data for subproblems" - g::Array{U,d} + "subproblem data, subg[i] == subctx[i].model.g" + subg::Array{U,d} "context for subproblems" subctx::Array{SC,d} end -function solve(model::DualTVDDModel, algorithm::DualTVDDAlgorithm) +function init(md::DualTVDDModel, alg::DualTVDDAlgorithm) + d = ndims(md.f) + ax = axes(md.f) + # subdomain axes + subax = subaxes(md.f, alg.M, alg.overlap) + # data for subproblems + subg = [Array{Float64, d}(undef, length.(subax[i])) for i in CartesianIndices(subax)] + + # locally dependent tv parameter + subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i])) for i in CartesianIndices(subax)] + + g = reshape(md.A' * vec(md.f), size(md.f)) + + p = zeros(SVector{d,Float64}, size(md.f)) + + #g[i] = md.f + + # TODO: initialize g per subdomain with partition function + + B = inv(md.A' * md.A + md.β * I) + + # create models for subproblems + # TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too) + li = LinearIndices(size(md.f)) + models = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i]) + for i in CartesianIndices(subax)] + + subalg = ChambolleAlgorithm() + + subctx = [init(models[i], subalg) for i in CartesianIndices(subax)] + + return DualTVDDContext(md, alg, g, p, B, subax, subg, subg, subctx) +end + +function step!(ctx::DualTVDDContext) + d = ndims(ctx.p) + ax = axes(ctx.p) + overlap = ctx.algorithm.overlap + + @inline kfΛ(w) = @inbounds divergence_global(w) + kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ) + + println("global p") + display(ctx.p) + + # call run! on each cell (this can be threaded) + for i in eachindex(ctx.subctx) + sax = ctx.subax[i] + ci = CartesianIndices(sax) + + # g_i = (A*f - Λ(1-theta_i)p^n)|_{\Omega_i} + # subctx[i].p is used as a buffer + + tmp = (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p + #tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) + #ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...] + + tmp2 = map(kΛ, extend(tmp, StaticKernels.ExtensionNothing())) + ctx.subg[i] .= tmp2[sax...] + #map!(kΛ, ctx.subg[i], ctx.subctx[i].p) + + println("### ITERATION $i ###") + display(tmp) + #display(ctx.subctx[i].p) + display(ctx.subg[i]) + + ctx.subg[i] .+= ctx.g[sax...] + # set sensible starting value + ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p))) + + for j in 1:100 + step!(ctx.subctx[i]) + end + end + + # aggregate (not thread-safe!) + σ = ctx.algorithm.σ + ctx.p .*= 1 - σ + for i in CartesianIndices(ctx.subax) + ctx.p[ctx.subax[i]...] .+= σ .* ctx.subctx[i].p + end +end + +@generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T} + i0 = ntuple(_->0, N) + i1(k) = ntuple(i->Int(k==i), N) + + wi = (:(w[$(i0...)][$k] - + (isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k])) for k in 1:N) + return quote + Base.@_inline_meta + return @inbounds +($(wi...)) + end +end + + + #FD.GridFunction(grid, (A'*A + β*I) \ (FD.divergence_z(p).data[:] .+ A'*f.data[:])) + +function recover_u!(ctx::DualTVDDContext) + d = ndims(ctx.g) + u = similar(ctx.g) + v = similar(ctx.g) + + @inline kfΛ(w) = @inbounds divergence(w) + kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ) + + # u = div(p) + A'*f + map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing())) + v .+= ctx.g + # u = B * u + mul!(vec(u), ctx.B, vec(v)) + return u +end + +""" + theta(ax, sax, overlap, i) + +Return value of the partition function at index `i` given global axes `ax`, +subdomain axes `sax` and overlap count `overlap`. +This assumes that subdomains have size at least 2 .* overlap. +""" +theta(ax, sax, overlap, i::CartesianIndex) = prod(theta.(ax, sax, overlap, Tuple(i))) +theta(ax, sax, overlap::Int, i::Int) = + max(0., min(1., + first(ax) == first(sax) && i < first(ax) + overlap ? 1. : (i - first(sax)) / overlap, + last(ax) == last(sax) && i > last(ax) - overlap ? 1. : (last(sax) - i) / overlap)) + + +""" + subaxes(domain, pnum, overlap) + +Determine axes for all subdomains, given per dimension number of domains +`pnum` and overlap `overlap` +""" +function subaxes(domain, pnum, overlap) + overlap = 1 .+ overlap + d = ndims(domain) + tsize = size(domain) .+ (pnum .- 1) .* overlap + + psize = tsize .÷ pnum + osize = tsize .- pnum .* psize + + overhang(I, j) = I[j] == pnum[j] ? osize[j] : 0 + + indices = Array{NTuple{d, UnitRange{Int}}, d}(undef, pnum) + for I in CartesianIndices(pnum) + indices[I] = ntuple(j -> ((I[j] - 1) * psize[j] - (I[j] - 1) * overlap[j] + 1) : + ( I[j] * psize[j] - (I[j] - 1) * overlap[j] + overhang(I, j)), d) + end + @assert all(length.(sax) >= 1 .* overlap for sax in indices) + return indices end -- GitLab