Skip to content
Snippets Groups Projects
Commit 37957857 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

implement global surrogate dd

parent 8e5453b7
No related branches found
No related tags found
No related merge requests found
module DualTVDD module DualTVDD
export DualTVDDAlgorithm export DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm,
SurrogateAlgorithm, DualTVDDAlgorithm, DualTVDDSurrogateAlgorithm
export init, step!, fetch, run export init, step!, fetch, run
export fetch_u, normB
include("common.jl") include("common.jl")
include("problems.jl") include("problems.jl")
include("nstep.jl") include("nstep.jl")
include("chambolle.jl") include("chambolle.jl")
include("dualtvdd.jl")
include("projgrad.jl") include("projgrad.jl")
include("surrogate.jl") include("surrogate.jl")
include("dualtvdd.jl")
include("dualtvdd_surrogate.jl")
#include("tvnewton.jl") #include("tvnewton.jl")
end # module end # module
using Distributed: workers
using Outsource: Connector, outsource
struct DualTVDDSurrogateAlgorithm{P,d} <: Algorithm{P}
problem::P
"number of subdomains in each dimension"
M::NTuple{d,Int}
"overlap in pixels per dimension"
overlap::NTuple{d,Int}
"use the non-sequential inertia update"
parallel::Bool
"inertia parameter (only when parallel)"
σ::Float64
"surrogate stepsize"
τ::Float64
"number of inner iterations"
ninner::Int
"prob -> Algorithm(::Problem, ...)"
subalg::Function
"worker ids used for distributed execution"
workers::Vector{Int}
function DualTVDDSurrogateAlgorithm(problem; M, overlap, parallel=true, σ=parallel ? 1/4 : 1., ninner=10, τ=2*normB(problem), subalg=x->ChambolleAlgorithm(x), workers=workers())
if parallel == true && σ > 1/4
@warn "parallel domain decomposition needs σ >= 1/4 for theoretical convergence"
end
return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, τ, ninner, subalg, workers)
end
end
struct DualTVDDSurrogateState{A,d,V,SAx,SC}
algorithm::A
"global variable"
p::V
"subdomain axes wrt global indices"
subax::SAx
"connectors to subworkers"
cons::Array{SC,d}
end
function init(alg::DualTVDDSurrogateAlgorithm{<:DualTVL1ROFOpProblem})
g = alg.problem.g
d = ndims(g)
ax = axes(g)
# subdomain axes
subax = subaxes(size(g), alg.M, alg.overlap)
# preallocated data for subproblems
subg = [Array{eltype(g), d}(undef, length.(x)) for x in subax]
# locally dependent tv parameter
subλ = [alg.problem.λ[subax[i]...] .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
for i in CartesianIndices(subax)]
# TODO: generalize to SArray
p1type(T::Type{<:Real}) = SVector{d,T}
p1type(::Type{SVector{m,T}}) where {m, T} = SMatrix{m,d,T,m*d}
# global dual variable
p = zeros(p1type(eltype(g)), ax)
# local dual variable
subp = [zeros(p1type(eltype(g)), length.(sax)) for sax in subax]
# create subproblem contexts
cids = alg.parallel ? [eachindex(subax)] :
chessboard_coloring(size(subax))
cons = Array{Connector, d}(undef, size(subax))
for (color, sidxs) in enumerate(cids)
for (i, sidx) in enumerate(sidxs)
sax = subax[sidx]
subprob = DualTVL1ROFOpProblem(subg[sidx], I, subλ[sidx])
wf = subworker(alg, alg.subalg(subprob))
wid = alg.workers[mod1(i, length(alg.workers))]
cons[sidx] = outsource(wf, wid)
end
end
return DualTVDDSurrogateState(alg, p, subax, cons)
end
function step!(ctx::DualTVDDSurrogateState)
alg = ctx.algorithm
σ = ctx.algorithm.σ
d = ndims(ctx.p)
ax = axes(ctx.p)
overlap = ctx.algorithm.overlap
g = alg.problem.g
p_rem = copy(ctx.p)
p_don = zeros(eltype(ctx.p), size(ctx.p))
# subdomain loop (in coloring order)
#cids = alg.parallel ? [eachindex(ctx.subax)] :
#chessboard_coloring(size(ctx.subax))
cids = chessboard_coloring(size(ctx.subax))
for (color, ids) in enumerate(cids)
p_current = extend(zeros(eltype(ctx.p), size(ctx.p)),
StaticKernels.ExtensionNothing())
for i in ids
sax = ctx.subax[i]
p_current .+=
theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ax)) .* ctx.p
end
p_rem .-= p_current
@inline kfΛ(p) = @inbounds -divergence(p)
= Kernel{ntuple(_->-1:1, d)}(kfΛ)
if alg.parallel
p_other = ctx.p .- p_current
else
p_other = p_don .+ p_rem
end
p_other = extend(p_other, StaticKernels.ExtensionNothing())
r_current = map(, p_current)
r_other = map(, p_other)
s = r_current .+ r_other .- g
g_current = r_current
mul!(vec(g_current), alg.problem.B, vec(s), -1. / alg.τ, 1.)
for pids in Iterators.partition(ids, length(alg.workers))
for i in pids
sax = ctx.subax[i]
put!(ctx.cons[i], view(g_current, sax...))
end
end
p_don .+= (1 .- σ) .* p_current
for pids in Iterators.partition(ids, length(alg.workers))
for i in pids
sax = ctx.subax[i]
p_i = take!(ctx.cons[i])
view(p_don, sax...) .+= σ .* p_i
end
end
end
ctx.p .= p_don
return ctx
end
fetch(ctx::DualTVDDSurrogateState) = ctx.p
...@@ -8,7 +8,7 @@ struct SurrogateAlgorithm{P} <: Algorithm{P} ...@@ -8,7 +8,7 @@ struct SurrogateAlgorithm{P} <: Algorithm{P}
τ::Float64 τ::Float64
function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem; function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem;
ninner=1, ninner=1,
τ=2*normB(problem), subalg=x->ProjGradAlgorithm(x)) τ=2*normB(problem), subalg=x->ChambolleAlgorithm(x))
return new{typeof(problem)}(problem, subalg, ninner, τ) return new{typeof(problem)}(problem, subalg, ninner, τ)
end end
end end
......
...@@ -2,7 +2,7 @@ using Test, BenchmarkTools ...@@ -2,7 +2,7 @@ using Test, BenchmarkTools
using LinearAlgebra using LinearAlgebra
using DualTVDD: using DualTVDD:
DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm, DualTVDDAlgorithm, DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm, DualTVDDAlgorithm,
init, step!, fetch_u init, step!, fetch_u, normB
@testset "B = I" begin @testset "B = I" begin
g = Float64[0 2; 1 0] g = Float64[0 2; 1 0]
...@@ -57,3 +57,36 @@ end ...@@ -57,3 +57,36 @@ end
end end
@test fetch_u(st) fetch_u(stref) @test fetch_u(st) fetch_u(stref)
end end
@testset "DualTVDDSurrogateAlgorithm" begin
d = 2
n = 5
ninner = 100
g = rand(n, n)
#B = Diagonal(rand(n^2))
B = rand(n^2, n^2) ./ n^2
B = I + 0.5 * (B + B')
M = ntuple(_->1, d)
overlap = ntuple(_->2, d)
parallel = true
σ = 1.0
# big λ is ok, since we test for inter-subdomain communication
prob = DualTVL1ROFOpProblem(g, B, 1000.)
algref = ChambolleAlgorithm(prob)
#algref = DualTVDDAlgorithm(prob; M, overlap, ninner=1, parallel, σ,
# subalg = x -> SurrogateAlgorithm(x; τ=2*normB(prob), ninner))
alg = DualTVDDSurrogateAlgorithm(prob; M, overlap, ninner, parallel, σ,
subalg = x -> ChambolleAlgorithm(x))
stref = init(algref)
st = init(alg)
#@test 0 == @ballocated step!($ctx)
for i in 1:1000
step!(stref)
end
for i in 1:1000
step!(st)
end
@test fetch_u(st) fetch_u(stref)
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment