Commit e9420a5b authored by Stephan Hilb

avoid nested SArrays

gets rid of annoying broadcasting memory allocations
parent c33b9e16
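The fix flattens element types: a nested static array like SVector{d,SVector{m,T}} defeats StaticArrays' fast broadcasting paths, while the equivalent flat SMatrix{m,d,T,m*d} stays on the allocation-free path. A minimal sketch of the symptom (illustrative only; BenchmarkTools and the toy data are assumptions, not part of the commit):

    using StaticArrays, BenchmarkTools

    nested = fill(SVector(SVector(1.0, 2.0), SVector(3.0, 4.0)), 1000)  # nested eltype
    flat   = fill(SMatrix{2,2}(1.0, 2.0, 3.0, 4.0), 1000)               # same data, flat eltype

    scale(x) = 2 .* x
    @btime scale.($nested);  # extra per-element allocations on the affected versions
    @btime scale.($flat);    # just the one output-array allocation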
 module DualTVDD
+export DualTVDDAlgorithm
 export init, step!, fetch, run
 include("common.jl")
@@ -49,22 +49,27 @@ struct ChambolleState{A,T,R,S,Sv,K1,K2} <: State
 end

 projnorm(v) = projnorm(v, false)
-projnorm(v, anisotropic) = anisotropic ? abs.(v) : norm(vec(v))
+projnorm(v::SVector, anisotropic) = anisotropic ? abs.(v) : norm(vec(v))

-const SVectorNest = SArray{<:Any,<:SArray{<:Any,Float64,1},1}
+#const SVectorNest = SArray{<:Any,<:SArray{<:Any,Float64,1},1}
 # TODO: respect anisotropic
-#@inline projnorm(v::SVectorNest, _) = norm(norm.(v))
-@inline projnorm(v::SVectorNest, _) = sqrt(sum(sum(v[i] .^ 2) for i in eachindex(v)))
+@inline projnorm(v, _) = norm(v)
+#@inline projnorm(v::SVectorNest, _) = sqrt(sum(sum(v[i] .^ 2) for i in eachindex(v)))

 function init(alg::ChambolleAlgorithm{<:DualTVL1ROFOpProblem})
     g = alg.problem.g
     λ = alg.problem.λ
     d = ndims(g)
+    ax = axes(g)

+    # TODO: generalize to SArray
+    p1type(T::Type{<:Real}) = SVector{d,T}
+    p1type(::Type{SVector{m,T}}) where {m, T} = SMatrix{m,d,T,m*d}

-    pv = zeros(d * length(reinterpret(Float64, g)))
     rv = zeros(eltype(g), length(g))
     sv = zero(rv)

-    p = extend(reshape(reinterpret(SVector{d,eltype(g)}, pv), size(g)), StaticKernels.ExtensionNothing())
+    p = extend(zeros(p1type(eltype(g)), ax), StaticKernels.ExtensionNothing())
     r = reshape(rv, size(g))
     s = extend(reshape(sv, size(g)), StaticKernels.ExtensionReplicate())
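In isolation, the new p1type helper maps an image eltype to the flat per-pixel dual eltype (sketch with d fixed to 2; inside init, d comes from the problem):

    using StaticArrays

    d = 2
    p1type(T::Type{<:Real}) = SVector{d,T}
    p1type(::Type{SVector{m,T}}) where {m, T} = SMatrix{m,d,T,m*d}

    p1type(Float64)             # SVector{2,Float64}: gradient of a scalar image
    p1type(SVector{3,Float64})  # SMatrix{3,2,Float64,6}: Jacobian of a 3-channel image,
                                # flat instead of SVector{2,SVector{3,Float64}}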
@@ -94,6 +99,11 @@ end

 function step!(ctx::ChambolleState)
     alg = ctx.algorithm
+    d = ndims(ctx.r)
+    kgrad = Kernel{ntuple(_->0:1, d)}(gradient)
+    λ = alg.problem.λ
+    p = ctx.p

     sv = vec(parent(ctx.s))
     rv = vec(ctx.r)
@@ -103,7 +113,12 @@ function step!(ctx::ChambolleState)
     # s = B * r
     mul!(sv, alg.problem.B, rv)

     # p = (p + τ*grad(s)) / (1 + τ/λ|grad(s)|)
-    map!(ctx.k2, ctx.p, ctx.s)
+    sgrad = map(kgrad, ctx.s)
+    f(p, sgrad, λ) = iszero(λ) ? zero(p) :
+        (p + alg.τ * sgrad) ./ (1. .+ projnorm(alg.τ * sgrad, alg.anisotropic) ./ λ)
+    ctx.p .= f.(ctx.p, sgrad, λ)

     return ctx
 end
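Traced for a single pixel, the new update f realizes the comment above (illustrative values; isotropic case, where projnorm reduces to the Euclidean norm):

    using StaticArrays, LinearAlgebra

    τ, λ = 0.25, 0.1
    p  = SVector(0.0, 0.0)   # current dual value at one pixel
    ∇s = SVector(1.0, -2.0)  # gradient of s at that pixel
    pnew = (p + τ * ∇s) ./ (1 + norm(τ * ∇s) / λ)  # shrinks the step toward the radius-λ ball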
@@ -56,7 +56,7 @@ fetch_u(st) = recover_u(fetch(st), st.algorithm.problem)
 Base.intersect(a::CartesianIndices{d}, b::CartesianIndices{d}) where d =
     CartesianIndices(intersect.(a.indices, b.indices))

-@generated function gradient(w::StaticKernels.Window{<:Any,N}) where N
+@generated function gradient(w::StaticKernels.Window{<:Real,N}) where N
     i0 = ntuple(_->0, N)
     i1(k) = ntuple(i->Int(k==i), N)
@@ -67,6 +67,18 @@ Base.intersect(a::CartesianIndices{d}, b::CartesianIndices{d}) where d =
     end
 end

+@generated function gradient(w::StaticKernels.Window{S,N}) where {S<:SArray, N}
+    i0 = ntuple(_->0, N)
+    i1(k) = ntuple(i->Int(k==i), N)
+
+    wi = (:( w[$(i1(k)...)][$j] - w[$(i0...)][$j] ) for k in 1:N for j in 1:length(S))
+    return quote
+        Base.@_inline_meta
+        return @inbounds SArray{Tuple{$(size(S)...),N}}($(wi...))::
+            SArray{Tuple{$(size(S)...),N},$(eltype(S)),$(ndims(S)+1),$(length(S)*N)}
+    end
+end
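What the new method produces, sketched with the Kernel/extend/map calls used elsewhere in this diff (the data is made up): for an SVector{2}-valued 2-d array it stacks the forward differences of the two components into one flat static array per pixel, derivative direction last.

    using StaticArrays, StaticKernels

    a  = rand(SVector{2,Float64}, 8, 8)
    ae = extend(a, StaticKernels.ExtensionReplicate())
    kg = Kernel{(0:1, 0:1)}(gradient)  # same construction as kgrad in step! above
    J  = map(kg, ae)                   # eltype: SArray{Tuple{2,2},Float64,2,4}, no nesting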
 @generated function divergence(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
     i0 = ntuple(_->0, N)
     i1(k) = ntuple(i->Int(k==i), N)
@@ -79,6 +91,24 @@ end
     end
 end

+@generated function divergence(w::StaticKernels.Window{S,N}) where {S<:SArray,N}
+    T = eltype(S)
+    sz = size(S)
+    sz[end] == N || throw(ArgumentError("last eltype dimension does not match array dimensionality"))
+
+    i0 = ntuple(_->0, N)
+    i1(k) = ntuple(i->Int(k==i), N)
+    slice(k) = (ntuple(_->:, ndims(S)-1)..., k)
+    # zero of one slice (not zero(T)): keeps both ternary branches the same static-array type
+    Tz = SArray{Tuple{sz[1:end-1]...},T,ndims(S)-1,prod(sz[1:end-1])}
+
+    wi = (:((isnothing(w[$(i0...)]) ? zero($Tz) : w[$(i0...)][$(slice(k)...)]) -
+            (isnothing(w[$((.-i1(k))...)]) ? zero($Tz) : w[$((.-i1(k))...)][$(slice(k)...)])) for k in 1:N)
+    return quote
+        Base.@_inline_meta
+        return @inbounds +($(wi...))
+    end
+end
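And the matching divergence, sketched the same way (mirrors div_op below; made-up data): it contracts the trailing derivative dimension with backward differences, with ExtensionNothing turning out-of-bounds neighbors into zero contributions.

    using StaticArrays, StaticKernels

    p  = rand(SMatrix{2,2,Float64,4}, 8, 8)
    pe = extend(p, StaticKernels.ExtensionNothing())
    kd = Kernel{ntuple(_->-1:1, 2)}(divergence)
    dp = map(kd, pe)  # eltype: SVector{2,Float64}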
 function div_op(a::AbstractArray{<:StaticVector{N},N}) where N
     k = Kernel{ntuple(_->-1:1, ndims(a))}(divergence)
     ae = extend(a, StaticKernels.ExtensionNothing())
-using Distributed: nworkers, workers
+using Distributed: workers
 using Outsource: Connector, outsource

 #import Serialization
@@ -27,11 +27,13 @@ struct DualTVDDAlgorithm{P,d} <: Algorithm{P}
     ninner::Int
     "prob -> Algorithm(::Problem, ...)"
     subalg::Function
-    function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=parallel ? 1/4 : 1., ninner=10, subalg=x->ProjGradAlgorithm(x))
+    "worker ids used for distributed execution"
+    workers::Vector{Int}
+    function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=parallel ? 1/4 : 1., ninner=10, subalg=x->ProjGradAlgorithm(x), workers=workers())
         if parallel == true && σ > 1/4
             @warn "parallel domain decomposition needs σ >= 1/4 for theoretical convergence"
         end
-        return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner, subalg)
+        return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner, subalg, workers)
     end
 end
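With the new field, callers decide which Distributed workers host the subproblems instead of always spreading over all of them (a sketch; prob and the M/overlap values are hypothetical placeholders):

    using Distributed
    addprocs(2)
    using DualTVDD

    # prob = DualTVL1ROFOpProblem(...)   # hypothetical problem setup, arguments elided
    alg  = DualTVDDAlgorithm(prob; M=(2, 2), overlap=(2, 2))               # default: workers()
    alg2 = DualTVDDAlgorithm(prob; M=(2, 2), overlap=(2, 2), workers=[2])  # pin to worker 2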
@@ -59,10 +61,14 @@ function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
     subλ = [alg.problem.λ[subax[i]...] .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
         for i in CartesianIndices(subax)]

+    # TODO: generalize to SArray
+    p1type(T::Type{<:Real}) = SVector{d,T}
+    p1type(::Type{SVector{m,T}}) where {m, T} = SMatrix{m,d,T,m*d}

     # global dual variable
-    p = zeros(SVector{d,eltype(g)}, size(g))
+    p = zeros(p1type(eltype(g)), ax)
     # local dual variable
-    subp = [collect(reinterpret(Float64, zeros(SVector{d,eltype(g)}, prod(length.(x))))) for x in subax]
+    subp = [zeros(p1type(eltype(g)), sax) for sax in subax]

     # create subproblem contexts
     cids = chessboard_coloring(size(subax))
@@ -73,7 +79,7 @@ function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
             subprob = DualTVL1ROFOpProblem(subg[sidx], op_restrict(alg.problem.B, ax, subax[sidx]), subλ[sidx])

             wf = subworker(alg, alg.subalg(subprob))
-            wid = workers()[mod1(i, nworkers())]
+            wid = alg.workers[mod1(i, length(alg.workers))]
             cons[sidx] = outsource(wf, wid)
         end
     end
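The mod1 indexing above deals subproblems out to the configured worker ids round-robin; in isolation (hypothetical ids):

    ws = [2, 3, 4]  # hypothetical worker ids
    [ws[mod1(i, length(ws))] for i in 1:7]  # → [2, 3, 4, 2, 3, 4, 2]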
@@ -111,7 +117,7 @@ function subworker(alg, subalg)
         subg = take!(con)
         subalg.problem.g .= subg
         # run algorithm
-        for _ in 1:ninner
+        for _ in 1:1000
             step!(subst)
         end
         # write result