From 37957857460f4729a5a60134614383c48dab47c4 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Thu, 17 Dec 2020 11:54:22 +0100
Subject: [PATCH] implement global surrogate dd

---
 src/DualTVDD.jl           |   8 +-
 src/dualtvdd_surrogate.jl | 186 ++++++++++++++++++++++++++++++++++++++
 src/surrogate.jl          |   2 +-
 test/runtests.jl          |  35 ++++++++-
 4 files changed, 227 insertions(+), 4 deletions(-)
 create mode 100644 src/dualtvdd_surrogate.jl

diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index bb18d01..8ed0d67 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -1,16 +1,20 @@
 module DualTVDD
 
-export DualTVDDAlgorithm
+export DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm,
+    SurrogateAlgorithm, DualTVDDAlgorithm, DualTVDDSurrogateAlgorithm
+
 export init, step!, fetch, run
+export fetch_u, normB
 
 include("common.jl")
 include("problems.jl")
 include("nstep.jl")
 
 include("chambolle.jl")
-include("dualtvdd.jl")
 include("projgrad.jl")
 include("surrogate.jl")
+include("dualtvdd.jl")
+include("dualtvdd_surrogate.jl")
 #include("tvnewton.jl")
 
 end # module
diff --git a/src/dualtvdd_surrogate.jl b/src/dualtvdd_surrogate.jl
new file mode 100644
index 0000000..dbd2fcf
--- /dev/null
+++ b/src/dualtvdd_surrogate.jl
@@ -0,0 +1,186 @@
+using Distributed: workers
+using Outsource: Connector, outsource
+
+"""
+    DualTVDDSurrogateAlgorithm(problem; M, overlap, parallel=true,
+        σ=parallel ? 1/4 : 1., ninner=10, τ=2*normB(problem),
+        subalg=x->ChambolleAlgorithm(x), workers=workers())
+
+Domain decomposition algorithm for `problem` where `step!` computes the data
+of every subdomain problem from the global dual variable and exchanges it with
+the subworkers through connectors.
+
+A warning is logged when `parallel` is set and `σ > 1/4`.
+"""
+struct DualTVDDSurrogateAlgorithm{P,d} <: Algorithm{P}
+    problem::P
+
+    "number of subdomains in each dimension"
+    M::NTuple{d,Int}
+    "overlap in pixels per dimension"
+    overlap::NTuple{d,Int}
+    "use the non-sequential inertia update"
+    parallel::Bool
+    "inertia parameter (only when parallel)"
+    σ::Float64
+    "surrogate stepsize"
+    τ::Float64
+    "number of inner iterations"
+    ninner::Int
+    "prob -> Algorithm(::Problem, ...)"
+    subalg::Function
+    "worker ids used for distributed execution"
+    workers::Vector{Int}
+
+    function DualTVDDSurrogateAlgorithm(problem; M, overlap, parallel=true,
+            σ=parallel ? 1/4 : 1., ninner=10, τ=2*normB(problem),
+            subalg=x->ChambolleAlgorithm(x), workers=workers())
+        # the theory for the parallel variant covers σ <= 1/4 only
+        if parallel == true && σ > 1/4
+            @warn "parallel domain decomposition needs σ <= 1/4 for theoretical convergence"
+        end
+        return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, τ,
+            ninner, subalg, workers)
+    end
+end
+
+"""
+    DualTVDDSurrogateState
+
+State of a `DualTVDDSurrogateAlgorithm` as returned by `init` and advanced by
+`step!`.
+"""
+struct DualTVDDSurrogateState{A,d,V,SAx,SC}
+    algorithm::A
+
+    "global variable"
+    p::V
+    "subdomain axes wrt global indices"
+    subax::SAx
+    "connectors to subworkers"
+    cons::Array{SC,d}
+end
+
+"""
+    init(alg::DualTVDDSurrogateAlgorithm{<:DualTVL1ROFOpProblem})
+
+Set up the subdomain axes, a zero global dual variable and one outsourced
+subworker per subdomain and return them as a `DualTVDDSurrogateState`.
+"""
+function init(alg::DualTVDDSurrogateAlgorithm{<:DualTVL1ROFOpProblem})
+    g = alg.problem.g
+    d = ndims(g)
+    ax = axes(g)
+
+    # subdomain axes
+    subax = subaxes(size(g), alg.M, alg.overlap)
+    # preallocated data for subproblems
+    subg = [Array{eltype(g), d}(undef, length.(x)) for x in subax]
+    # locally dependent tv parameter
+    subλ = [alg.problem.λ[subax[i]...] .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
+        for i in CartesianIndices(subax)]
+
+    # TODO: generalize to SArray
+    p1type(T::Type{<:Real}) = SVector{d,T}
+    p1type(::Type{SVector{m,T}}) where {m, T} = SMatrix{m,d,T,m*d}
+
+    # global dual variable
+    p = zeros(p1type(eltype(g)), ax)
+
+    # create subproblem contexts
+    cids = alg.parallel ? [eachindex(subax)] :
+        chessboard_coloring(size(subax))
+    cons = Array{Connector, d}(undef, size(subax))
+    for sidxs in cids
+        # workers are assigned round-robin within each group of subdomains
+        for (i, sidx) in enumerate(sidxs)
+            subprob = DualTVL1ROFOpProblem(subg[sidx], I, subλ[sidx])
+            wf = subworker(alg, alg.subalg(subprob))
+            wid = alg.workers[mod1(i, length(alg.workers))]
+            cons[sidx] = outsource(wf, wid)
+        end
+    end
+
+    return DualTVDDSurrogateState(alg, p, subax, cons)
+end
+
+"""
+    step!(ctx::DualTVDDSurrogateState)
+
+Perform one iteration on `ctx`, overwriting `ctx.p`, and return `ctx`.
+
+For every color the data `g_current` is assembled from the global dual
+variable, its subdomain views are sent to the connectors and the results taken
+from them are accumulated, weighted by `σ`, into the new dual variable.
+"""
+function step!(ctx::DualTVDDSurrogateState)
+    alg = ctx.algorithm
+    σ = ctx.algorithm.σ
+    d = ndims(ctx.p)
+    ax = axes(ctx.p)
+    overlap = ctx.algorithm.overlap
+    g = alg.problem.g
+
+    p_rem = copy(ctx.p)
+    p_don = zeros(eltype(ctx.p), size(ctx.p))
+
+    # subdomain loop (in coloring order)
+    # NOTE: the chessboard coloring is used even if `alg.parallel` is set
+    cids = chessboard_coloring(size(ctx.subax))
+
+    for ids in cids
+        p_current = extend(zeros(eltype(ctx.p), size(ctx.p)),
+            StaticKernels.ExtensionNothing())
+        for i in ids
+            sax = ctx.subax[i]
+            p_current .+=
+                theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ax)) .* ctx.p
+        end
+        p_rem .-= p_current
+
+        @inline kfΛ(p) = @inbounds -divergence(p)
+        kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+        if alg.parallel
+            p_other = ctx.p .- p_current
+        else
+            p_other = p_don .+ p_rem
+        end
+        p_other = extend(p_other, StaticKernels.ExtensionNothing())
+
+        r_current = map(kΛ, p_current)
+        r_other = map(kΛ, p_other)
+
+        s = r_current .+ r_other .- g
+
+        # g_current aliases r_current, which is updated in place to
+        # r_current - B * s / τ
+        g_current = r_current
+        mul!(vec(g_current), alg.problem.B, vec(s), -1. / alg.τ, 1.)
+
+        # send the subdomain data of this color to the subworkers
+        for pids in Iterators.partition(ids, length(alg.workers))
+            for i in pids
+                sax = ctx.subax[i]
+                put!(ctx.cons[i], view(g_current, sax...))
+            end
+        end
+
+        p_don .+= (1 .- σ) .* p_current
+        # collect the subdomain results
+        for pids in Iterators.partition(ids, length(alg.workers))
+            for i in pids
+                sax = ctx.subax[i]
+                p_i = take!(ctx.cons[i])
+                view(p_don, sax...) .+= σ .* p_i
+            end
+        end
+    end
+
+    ctx.p .= p_don
+
+    return ctx
+end
+
+"Return the global dual variable stored in `ctx`."
+fetch(ctx::DualTVDDSurrogateState) = ctx.p
diff --git a/src/surrogate.jl b/src/surrogate.jl
index a4bcb48..b3ee405 100644
--- a/src/surrogate.jl
+++ b/src/surrogate.jl
@@ -8,7 +8,7 @@ struct SurrogateAlgorithm{P} <: Algorithm{P}
     τ::Float64
 
     function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem; ninner=1,
-            τ=2*normB(problem), subalg=x->ProjGradAlgorithm(x))
+            τ=2*normB(problem), subalg=x->ChambolleAlgorithm(x))
         return new{typeof(problem)}(problem, subalg, ninner, τ)
     end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 3a0b97e..ff1eb51 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,7 +2,7 @@ using Test, BenchmarkTools
 using LinearAlgebra
 using DualTVDD:
     DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm, DualTVDDAlgorithm,
-    init, step!, fetch_u
+    init, step!, fetch_u, normB
 
 @testset "B = I" begin
     g = Float64[0 2; 1 0]
@@ -57,3 +57,36 @@ end
     end
     @test fetch_u(st) ≈ fetch_u(stref)
 end
+
+@testset "DualTVDDSurrogateAlgorithm" begin
+    d = 2
+    n = 5
+    ninner = 100
+    g = rand(n, n)
+    #B = Diagonal(rand(n^2))
+    B = rand(n^2, n^2) ./ n^2
+    B = I + 0.5 * (B + B')
+    M = ntuple(_->1, d)
+    overlap = ntuple(_->2, d)
+    parallel = true
+    σ = 1.0
+    # big λ is ok, since we test for inter-subdomain communication
+    prob = DualTVL1ROFOpProblem(g, B, 1000.)
+
+    algref = ChambolleAlgorithm(prob)
+    #algref = DualTVDDAlgorithm(prob; M, overlap, ninner=1, parallel, σ,
+    #    subalg = x -> SurrogateAlgorithm(x; τ=2*normB(prob), ninner))
+    alg = DualTVDDSurrogateAlgorithm(prob; M, overlap, ninner, parallel, σ,
+        subalg = x -> ChambolleAlgorithm(x))
+
+    stref = init(algref)
+    st = init(alg)
+    #@test 0 == @ballocated step!($ctx)
+    for i in 1:1000
+        step!(stref)
+    end
+    for i in 1:1000
+        step!(st)
+    end
+    @test fetch_u(st) ≈ fetch_u(stref)
+end
-- 
GitLab