Commit c33b9e16 authored by Stephan Hilb

finish up fully distributed dd implementation

parent e20c2fcc
@@ -11,6 +11,14 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 deps = ["Libdl"]
 uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 
+[[Outsource]]
+deps = ["Distributed"]
+git-tree-sha1 = "9303d4dd03e26e32e8ca0f87d39dfdefc7be27f2"
+repo-rev = "master"
+repo-url = "/home/stev47/stuff/Outsource"
+uuid = "ce4b2b2b-baef-434e-9229-2c3161aca78b"
+version = "0.1.0"
+
 [[Random]]
 deps = ["Serialization"]
 uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -27,14 +35,14 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 [[StaticArrays]]
 deps = ["LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "5c06c0aeb81bef54aed4b3f446847905eb6cbda0"
+git-tree-sha1 = "da4cf579416c81994afd6322365d00916c79b8ae"
 uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "0.12.3"
+version = "0.12.5"
 
 [[StaticKernels]]
-git-tree-sha1 = "d8d5f496e25ff848afc94da260223b9374dd06db"
+git-tree-sha1 = "84a49458d75b4a64850a71b0bf364cd94ffd4aae"
 uuid = "4c63dfa8-a427-4548-bd2f-4c19e87a7dc7"
-version = "0.5.0"
+version = "0.5.1"
 
 [[Statistics]]
 deps = ["LinearAlgebra", "SparseArrays"]
...
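
Note: the new Outsource entry is pinned to a local git checkout (repo-url = "/home/stev47/stuff/Outsource", repo-rev = "master"), so this Manifest only resolves on the author's machine. A minimal sketch of reproducing such an entry from your own checkout (the path below is hypothetical):

    using Pkg
    # adding a local git repository by url records repo-url/repo-rev in the Manifest
    Pkg.add(url="/path/to/Outsource", rev="master")  # hypothetical checkout path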
@@ -6,6 +6,7 @@ version = "0.1.0"
 [deps]
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Outsource = "ce4b2b2b-baef-434e-9229-2c3161aca78b"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 StaticKernels = "4c63dfa8-a427-4548-bd2f-4c19e87a7dc7"
...
-using Distributed: @everywhere, pmap
+using Distributed: nworkers, workers
+using Outsource: Connector, outsource
+
+#import Serialization
+## We override SubArray serialization in order to preserve array data
+## references. TODO: this should actually be julia-default!
+#function Serialization.serialize(
+#        s::Serialization.AbstractSerializer,
+#        a::SubArray{T,N,A}) where {T,N,A<:Array}
+#
+#    Serialization.serialize_any(s, a)
+#end
 
 struct DualTVDDAlgorithm{P,d} <: Algorithm{P}
     problem::P
@@ -15,22 +27,23 @@ struct DualTVDDAlgorithm{P,d} <: Algorithm{P}
     ninner::Int
     "prob -> Algorithm(::Problem, ...)"
     subalg::Function
-    function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=1/4, ninner=10, subalg=x->ProjectedGradient(x))
+    function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=parallel ? 1/4 : 1., ninner=10, subalg=x->ProjGradAlgorithm(x))
+        if parallel == true && σ > 1/4
+            @warn "parallel domain decomposition needs σ <= 1/4 for theoretical convergence"
+        end
         return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner, subalg)
     end
 end
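
The new σ default depends on the mode: the sequential variant (parallel=false) runs undamped with σ = 1, while the parallel variant damps with σ = 1/4, matching the bound the constructor now warns about. A hypothetical construction call, mirroring the test set added below:

    # sketch; prob is a DualTVL1ROFOpProblem as elsewhere in this commit
    alg = DualTVDDAlgorithm(prob; M=(2,2), overlap=(2,2), parallel=true, subalg = x -> ChambolleAlgorithm(x))  # σ defaults to 1/4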
-struct DualTVDDState{A,d,V,SV,SAx,SC}
+struct DualTVDDState{A,d,V,SAx,SC}
     algorithm::A
     "global variable"
     p::V
-    "local buffer"
-    q::Array{SV,d}
+    # TODO: get rid of this
     "subdomain axes wrt global indices"
     subax::SAx
-    "context for subproblems"
-    subctx::Array{SC,d}
+    "connectors to subworkers"
+    cons::Array{SC,d}
 end
 function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
@@ -48,19 +61,24 @@ function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
     # global dual variable
     p = zeros(SVector{d,eltype(g)}, size(g))
 
-    # local buffer variables
-    q = [extend(zeros(SVector{d,eltype(g)}, length.(x)), StaticKernels.ExtensionNothing()) for x in subax]
+    # local dual variable
+    subp = [collect(reinterpret(Float64, zeros(SVector{d,eltype(g)}, prod(length.(x))))) for x in subax]
 
     # create subproblem contexts
-    li = LinearIndices(ax)
-    subprobs = [DualTVL1ROFOpProblem(subg[i], op_restrict(alg.problem.B, ax, subax[i]), subλ[i])
-        for i in CartesianIndices(subax)]
-
-    subalg = [alg.subalg(subprobs[i]) for i in CartesianIndices(subax)]
-    subctx = [init(x) for x in subalg]
+    cids = chessboard_coloring(size(subax))
+    cons = Array{Connector, d}(undef, size(subax))
+    for (color, sidxs) in enumerate(cids)
+        for (i, sidx) in enumerate(sidxs)
+            sax = subax[sidx]
+            subprob = DualTVL1ROFOpProblem(subg[sidx], op_restrict(alg.problem.B, ax, subax[sidx]), subλ[sidx])
+            wf = subworker(alg, alg.subalg(subprob))
+            wid = workers()[mod1(i, nworkers())]
+            cons[sidx] = outsource(wf, wid)
+        end
+    end
 
-    return DualTVDDState(alg, p, q, subax, subctx)
+    return DualTVDDState(alg, p, subax, cons)
 end
 function intersectin(a, b)
@@ -71,8 +89,8 @@ function intersectin(a, b)
 end
 
 function chessboard_coloring(sz)
-    binli = LinearIndices((2, 2))
-    coloring = [Int[] for _ in 1:4]
+    binli = LinearIndices(ntuple(_->2, length(sz)))
+    coloring = [Int[] for _ in 1:2^length(sz)]
     li = LinearIndices(sz)
 
     for I in CartesianIndices(sz)
@@ -82,14 +100,25 @@ function chessboard_coloring(sz)
     return coloring
 end
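
The coloring now generalizes from the hard-coded 2x2 chessboard to 2^d colors in d dimensions. The loop body is collapsed in the hunk above; the following is a sketch of the complete function consistent with the visible lines (the parity-to-color mapping is an assumption):

    function chessboard_coloring(sz)
        binli = LinearIndices(ntuple(_->2, length(sz)))
        coloring = [Int[] for _ in 1:2^length(sz)]
        li = LinearIndices(sz)
        for I in CartesianIndices(sz)
            # the coordinate parities pick one of the 2^d colors,
            # so axis-adjacent subdomains never share a color
            push!(coloring[binli[CartesianIndex(mod1.(Tuple(I), 2))]], li[I])
        end
        return coloring
    end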
-function subrun!(subctx, maxiters)
-    #fetch(subctx) .= Ref(zero(eltype(fetch(subctx))) .+ 1)
-    display("uiae")
-    step!(subctx)
-    #for j in 1:maxiters
-    #    step!(subctx)
-    #end
-    return subctx
+function subworker(alg, subalg)
+    #fetch(st) .= reshape(reinterpret(SVector{d,eltype(g)}, initdata), size(g))
+    ninner = alg.ninner
+    return function(con)
+        subst = init(subalg)
+        while isopen(con)
+            # fetch new data
+            subg = take!(con)
+            subalg.problem.g .= subg
+            # run algorithm
+            for _ in 1:ninner
+                step!(subst)
+            end
+            # write result
+            subp = fetch(subst)
+            put!(con, subp)
+        end
+    end
 end
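
subworker returns a closure that runs a request/response loop over a Connector from the local Outsource package. The driver-side protocol, inferred from the usage in init and step! (close(con) is an assumption; only isopen, put! and take! appear in this commit):

    con = outsource(subworker(alg, alg.subalg(subprob)), wid)  # spawn the loop on worker wid
    put!(con, sg)      # send subproblem data g
    subp = take!(con)  # block until the ninner inner iterations are done
    close(con)         # presumably ends the `while isopen(con)` loop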
 function step!(ctx::DualTVDDState)
@@ -99,57 +128,49 @@ function step!(ctx::DualTVDDState)
     ax = axes(ctx.p)
     overlap = ctx.algorithm.overlap
 
-    # call run! on each cell (this can be threaded)
+    p_rem = copy(ctx.p)
+    p_don = zeros(eltype(ctx.p), size(ctx.p))
 
+    # subdomain loop (in coloring order)
     cids = chessboard_coloring(size(ctx.subax))
     for (color, ids) in enumerate(cids)
+        # prepare data g for subproblems
         for i in ids
             sax = ctx.subax[i]
-            li = LinearIndices(ctx.subax)[i]
-            sg = ctx.subctx[i].algorithm.problem.g # julia-bug workaround
-            sq = ctx.q[i] # julia-bug workaround
 
-            sg .= view(alg.problem.g, sax...)
+            # update remaining old contribution
+            view(p_rem, sax...) .-=
+                theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax)) .* view(ctx.p, sax...)
+
+            # data: g - Λ(p_don + p_rem)
+            sg = copy(view(alg.problem.g, sax...))
+            sp = extend(similar(ctx.p, length.(sax)), StaticKernels.ExtensionNothing())
             if alg.parallel
-                sq .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* view(ctx.p, sax...)
+                sp .= (1. .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* view(ctx.p, sax...)
             else
-                sq .= Ref(zero(eltype(sq)))
-                # contributions from previous domains
-                for (pcolor, pids) in enumerate(cids)
-                    # TODO: only adjacent ones needed
-                    for lj in pids
-                        saxj = ctx.subax[lj]
-                        pids, pidsi, pidsj = intersectin(CartesianIndices(sax), CartesianIndices(saxj))
-                        if pcolor < color
-                            sq[pidsi] .+= view(ctx.subctx[lj].p, pidsj)
-                        elseif pcolor > color
-                            sq[pidsi] .+= theta.(Ref(ax), Ref(saxj), Ref(overlap), pids) .* view(ctx.p, pids)
-                        end
-                    end
-                end
+                sp .= view(p_don, sax...) .+ view(p_rem, sax...)
             end
 
-            @inline kfΛ(pw) = @inbounds sg[pw.position] + divergence(pw)
-            kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
-
-            map!(kΛ, sg, sq)
+            @inline kf(spw) = @inbounds sg[spw.position] + divergence(spw)
+            kern = Kernel{ntuple(_->-1:1, d)}(kf)
+            map!(kern, sg, sp)
+
+            # start computation
+            put!(ctx.cons[i], sg)
         end
 
-        # actually run subalgorithms
-        ctx.subctx[ids] .= map(subrun!, deepcopy(ctx.subctx[ids]), [1 for _ in ids])
-        #ctx.subctx[ids] .= map(subrun!, deepcopy(ctx.subctx[ids]), [alg.ninner for _ in ids])
-
-        #for i in ids
-        #    subrun!(ctx.subctx[i])
-        #end
+        for i in ids
+            sax = ctx.subax[i]
+
+            sp = take!(ctx.cons[i])
+            # reshape(reinterpret(SVector{d,eltype(alg.problem.g)}, ctx.subp[i]), size(ctx.subalg[i].problem.g))
+            # weighted update for new contribution
+            view(p_don, sax...) .+=
+                (1 .- σ) .* theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax)) .* view(ctx.p, sax...) .+ σ .* sp
+        end
     end
 
-    # aggregate (not thread-safe!)
-    ctx.p .*= 1 - σ
-    for (i, sax) in pairs(ctx.subax)
-        view(ctx.p, sax...) .+= σ .* ctx.subctx[i].p
-    end
+    ctx.p .= p_don
 
     return ctx
 end
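
Read off the code above (not a formula stated in the commit): with partition-of-unity weights θ_i = theta(ax, sax, overlap, ·) and the local dual solution p̂_i received over the connector, the sweep rebuilds the global dual variable as

    p^{n+1} = \sum_i \chi_{\Omega_i} \bigl[ (1-\sigma)\,\theta_i\,p^n + \sigma\,\hat p_i \bigr]

so the previous iterate survives with weight (1-σ) and the fresh subdomain solutions enter with weight σ; p_rem and p_don track which portion of p^n has already been consumed during the coloring sweep.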
...
 using Test, BenchmarkTools
 using LinearAlgebra
 using DualTVDD:
-    DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm,
-    init, step!, fetch, recover_u
+    DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm, DualTVDDAlgorithm,
+    init, step!, fetch_u
 
 @testset "B = I" begin
     g = Float64[0 2; 1 0]
@@ -10,11 +10,11 @@ using DualTVDD:
     @testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
         ctx = init(alg)
-        @test 0 == @ballocated step!($ctx)
+        #@test 0 == @ballocated step!($ctx)
         for i in 1:100
             step!(ctx)
         end
-        u = recover_u(fetch(ctx), ctx.algorithm.problem)
+        u = fetch_u(ctx)
         @test u ≈ g
     end
 end
@@ -26,11 +26,34 @@ end
     @testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
         ctx = init(alg)
-        @test 0 == @ballocated step!($ctx)
+        #@test 0 == @ballocated step!($ctx)
         for i in 1:100
             step!(ctx)
         end
-        u = recover_u(fetch(ctx), ctx.algorithm.problem)
+        u = fetch_u(ctx)
         @test vec(u) ≈ B * vec(g)
     end
 end
@testset "DualTVDDAlgorithm" begin
n = 5
ninner = 100
g = rand(n, n)
B = Diagonal(rand(n^2))
# big λ is ok, since we test for inter-subdomain communication
prob = DualTVL1ROFOpProblem(g, B, 100.)
algref = ChambolleAlgorithm(prob)
alg = DualTVDDAlgorithm(prob; M=(2,2), overlap=(2,2), ninner, parallel = false, σ = 0.25, subalg = x -> ChambolleAlgorithm(x))
stref = init(algref)
st = init(alg)
#@test 0 == @ballocated step!($ctx)
for i in 1:1000*ninner
step!(stref)
end
for i in 1:1000
step!(st)
end
@test fetch_u(st) fetch_u(stref)
end
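
Since init now pushes each subproblem to a worker via outsource, this test set presumably needs worker processes with the package loaded before it runs; a sketch of such a setup (not part of the commit):

    using Distributed
    addprocs(4)                  # one worker per subdomain of the 2x2 decomposition
    @everywhere using DualTVDD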