Commit e9420a5b authored by Stephan Hilb

avoid nested SArrays

gets rid of annoying broadcasting memory allocations
parent c33b9e16
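The change replaces nested element types like SVector{d,SVector{m,T}} with a single flat SMatrix{m,d,T}. A minimal sketch of the two layouts (sizes and values are illustrative, assuming StaticArrays is loaded); per the commit message, broadcasting over the nested layout was the source of the extra allocations:

using StaticArrays

# nested layout (old): each entry is an SVector of SVectors
nested = [SVector(SVector(1.0, 2.0), SVector(3.0, 4.0)) for _ in 1:100]
# flat layout (new): the same data stored as one SMatrix per entry
flat = [SMatrix{2,2}(1.0, 2.0, 3.0, 4.0) for _ in 1:100]

# both support elementwise scaling; the flat layout avoids the nested broadcast
2 .* nested
2 .* flat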
module DualTVDD
export DualTVDDAlgorithm
export init, step!, fetch, run
include("common.jl")
@@ -49,22 +49,27 @@ struct ChambolleState{A,T,R,S,Sv,K1,K2} <: State
end
projnorm(v) = projnorm(v, false)
projnorm(v, anisotropic) = anisotropic ? abs.(v) : norm(vec(v))
const SVectorNest = SArray{<:Any,<:SArray{<:Any,Float64,1},1}
projnorm(v::SVector, anisotropic) = anisotropic ? abs.(v) : norm(vec(v))
#const SVectorNest = SArray{<:Any,<:SArray{<:Any,Float64,1},1}
# TODO: respect anisotropic
#@inline projnorm(v::SVectorNest, _) = norm(norm.(v))
@inline projnorm(v::SVectorNest, _) = sqrt(sum(sum(v[i] .^ 2) for i in eachindex(v)))
@inline projnorm(v, _) = norm(v)
#@inline projnorm(v::SVectorNest, _) = sqrt(sum(sum(v[i] .^ 2) for i in eachindex(v)))
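With the nested element type gone, the vector-valued case falls through to the plain norm fallback, which agrees with the old SVectorNest formula (both are the Frobenius norm over all components). A quick check with made-up values:

using LinearAlgebra, StaticArrays

v_nested = SVector(SVector(1.0, 2.0), SVector(3.0, 4.0))   # old nested element type
v_flat   = SMatrix{2,2}(1.0, 2.0, 3.0, 4.0)                # new flat element type

sqrt(sum(sum(v_nested[i] .^ 2) for i in eachindex(v_nested)))  # old formula, ≈ 5.4772
norm(v_flat)                                                    # new fallback, same value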
function init(alg::ChambolleAlgorithm{<:DualTVL1ROFOpProblem})
g = alg.problem.g
λ = alg.problem.λ
d = ndims(g)
ax = axes(g)
# TODO: generalize to SArray
p1type(T::Type{<:Real}) = SVector{d,T}
p1type(::Type{SVector{m,T}}) where {m, T} = SMatrix{m,d,T,m*d}
pv = zeros(d * length(reinterpret(Float64, g)))
rv = zeros(eltype(g), length(g))
sv = zero(rv)
p = extend(reshape(reinterpret(SVector{d,eltype(g)}, pv), size(g)), StaticKernels.ExtensionNothing())
p = extend(zeros(p1type(eltype(g)), ax), StaticKernels.ExtensionNothing())
r = reshape(rv, size(g))
s = extend(reshape(sv, size(g)), StaticKernels.ExtensionReplicate())
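The p1type helper picks the per-pixel type of the dual variable from eltype(g): a scalar image gets an SVector{d} gradient per pixel, a vector-valued image a flat SMatrix{m,d} instead of the old nested SVector of SVectors. Illustrative check, assuming d = 2:

using StaticArrays

d = 2
p1type(T::Type{<:Real}) = SVector{d,T}
p1type(::Type{SVector{m,T}}) where {m,T} = SMatrix{m,d,T,m*d}

p1type(Float64)             # SVector{2, Float64}
p1type(SVector{3,Float64})  # SMatrix{3, 2, Float64, 6}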
@@ -94,6 +99,11 @@ end
function step!(ctx::ChambolleState)
alg = ctx.algorithm
d = ndims(ctx.r)
kgrad = Kernel{ntuple(_->0:1, d)}(gradient)
λ = alg.problem.λ
p = ctx.p
sv = vec(parent(ctx.s))
rv = vec(ctx.r)
@@ -103,7 +113,12 @@ function step!(ctx::ChambolleState)
# s = B * r
mul!(sv, alg.problem.B, rv)
# p = (p + τ*grad(s)) / (1 + τ/λ|grad(s)|)
map!(ctx.k2, ctx.p, ctx.s)
sgrad = map(kgrad, ctx.s)
f(p, sgrad, λ) = iszero(λ) ? zero(p) :
(p + alg.τ * sgrad) ./ (1. .+ projnorm(alg.τ * sgrad, alg.anisotropic) ./ λ)
ctx.p .= f.(ctx.p, sgrad, λ)
return ctx
end
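At a single pixel the update computed by f is Chambolle's fixed-point step p ← (p + τ∇s) / (1 + τ|∇s|/λ). A standalone sketch with made-up values for a scalar image in the isotropic case:

using LinearAlgebra, StaticArrays

τ, λ = 0.25, 0.1
p     = SVector(0.0, 0.0)    # dual variable at one pixel
sgrad = SVector(0.3, -0.4)   # discrete gradient of s at that pixel

p_new = iszero(λ) ? zero(p) : (p + τ * sgrad) / (1 + norm(τ * sgrad) / λ)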
@@ -56,7 +56,7 @@ fetch_u(st) = recover_u(fetch(st), st.algorithm.problem)
Base.intersect(a::CartesianIndices{d}, b::CartesianIndices{d}) where d =
CartesianIndices(intersect.(a.indices, b.indices))
@generated function gradient(w::StaticKernels.Window{<:Any,N}) where N
@generated function gradient(w::StaticKernels.Window{<:Real,N}) where N
i0 = ntuple(_->0, N)
i1(k) = ntuple(i->Int(k==i), N)
@@ -67,6 +67,18 @@ Base.intersect(a::CartesianIndices{d}, b::CartesianIndices{d}) where d =
end
end
@generated function gradient(w::StaticKernels.Window{S,N}) where {S<:SArray, N}
i0 = ntuple(_->0, N)
i1(k) = ntuple(i->Int(k==i), N)
wi = (:( w[$(i1(k)...)][$j] - w[$(i0...)][$j] ) for k in 1:N for j in 1:length(S))
return quote
Base.@_inline_meta
return @inbounds SArray{Tuple{$(size(S)...),N}}($(wi...))::
SArray{Tuple{$(size(S)...),N},$(eltype(S)),$(ndims(S)+1),$(length(S)*N)}
end
end
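For an SArray-valued window the generated body unrolls forward differences of every component along every axis into one flat SArray with a trailing dimension of size N. A hand-unrolled sketch of what it computes for S = SVector{2,Float64}, N = 2 (hypothetical helper, not the generated code):

using StaticArrays

# w00 is the window centre, w10/w01 its +1 neighbours along axes 1 and 2
function gradient_sketch(w00::SVector{2,Float64}, w10::SVector{2,Float64}, w01::SVector{2,Float64})
    return SMatrix{2,2}(
        w10[1] - w00[1],  # ∂₁ of component 1
        w10[2] - w00[2],  # ∂₁ of component 2
        w01[1] - w00[1],  # ∂₂ of component 1
        w01[2] - w00[2],  # ∂₂ of component 2
    )
end

gradient_sketch(SVector(1.0, 2.0), SVector(2.0, 4.0), SVector(0.5, 1.0))  # 2×2 SMatrix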
@generated function divergence(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
i0 = ntuple(_->0, N)
i1(k) = ntuple(i->Int(k==i), N)
@@ -79,6 +91,24 @@ end
end
end
@generated function divergence(w::StaticKernels.Window{S,N}) where {S<:SArray,N}
T = eltype(S)
sz = size(S)
sz[end] == N || throw(ArgumentError("last eltype dimension does not match array dimensionality"))
i0 = ntuple(_->0, N)
i1(k) = ntuple(i->Int(k==i), N)
slice(k) = (ntuple(_->:, ndims(S)-1)..., k)
wi = (:((isnothing(w[$(i0...)]) ? zero($T) : w[$(i0...)][$(slice(k)...)]) -
(isnothing(w[$((.-i1(k))...)]) ? zero($T) : w[$((.-i1(k))...)][$(slice(k)...)])) for k in 1:N)
return quote
Base.@_inline_meta
return @inbounds +($(wi...))
end
end
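The divergence counterpart sums backward differences of the k-th column along axis k, with out-of-bounds neighbours (ExtensionNothing) treated as a zero column. Hand-unrolled sketch for m = 2, N = 2 (hypothetical helper, not the generated code):

using StaticArrays

# w0 is the window centre, wleft/wdown its -1 neighbours along axes 1 and 2;
# `nothing` stands for a missing neighbour under ExtensionNothing
function divergence_sketch(w0, wleft, wdown)
    col(x, k) = isnothing(x) ? zero(SVector{2,Float64}) : x[:, k]
    return (col(w0, 1) - col(wleft, 1)) + (col(w0, 2) - col(wdown, 2))
end

divergence_sketch(SMatrix{2,2}(1.0, 2.0, 3.0, 4.0), nothing, SMatrix{2,2}(0.5, 0.5, 1.0, 1.0))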
function div_op(a::AbstractArray{<:StaticVector{N},N}) where N
k = Kernel{ntuple(_->-1:1, ndims(a))}(divergence)
ae = extend(a, StaticKernels.ExtensionNothing())
......
using Distributed: nworkers, workers
using Distributed: workers
using Outsource: Connector, outsource
#import Serialization
@@ -27,11 +27,13 @@ struct DualTVDDAlgorithm{P,d} <: Algorithm{P}
ninner::Int
"prob -> Algorithm(::Problem, ...)"
subalg::Function
function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=parallel ? 1/4 : 1., ninner=10, subalg=x->ProjGradAlgorithm(x))
"worker ids used for distributed execution"
workers::Vector{Int}
function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=parallel ? 1/4 : 1., ninner=10, subalg=x->ProjGradAlgorithm(x), workers=workers())
if parallel == true && σ > 1/4
@warn "parallel domain decomposition needs σ <= 1/4 for theoretical convergence"
end
return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner, subalg)
return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner, subalg, workers)
end
end
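The new workers field fixes the set of Distributed worker ids at construction time (defaulting to Distributed.workers()), so init no longer queries nworkers() itself. Hypothetical usage, assuming prob is an existing DualTVL1ROFOpProblem and the keyword values are purely illustrative:

using Distributed
addprocs(4)

alg = DualTVDDAlgorithm(prob; M=(2, 2), overlap=(4, 4), parallel=true,
                        workers=workers()[1:2])   # pin all subproblems to the first two workers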
@@ -59,10 +61,14 @@ function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
subλ = [alg.problem.λ[subax[i]...] .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
for i in CartesianIndices(subax)]
# TODO: generalize to SArray
p1type(T::Type{<:Real}) = SVector{d,T}
p1type(::Type{SVector{m,T}}) where {m, T} = SMatrix{m,d,T,m*d}
# global dual variable
p = zeros(SVector{d,eltype(g)}, size(g))
p = zeros(p1type(eltype(g)), ax)
# local dual variable
subp = [collect(reinterpret(Float64, zeros(SVector{d,eltype(g)}, prod(length.(x))))) for x in subax]
subp = [zeros(p1type(eltype(g)), sax) for sax in subax]
# create subproblem contexts
cids = chessboard_coloring(size(subax))
@@ -73,7 +79,7 @@ function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
subprob = DualTVL1ROFOpProblem(subg[sidx], op_restrict(alg.problem.B, ax, subax[sidx]), subλ[sidx])
wf = subworker(alg, alg.subalg(subprob))
wid = workers()[mod1(i, nworkers())]
wid = alg.workers[mod1(i, length(alg.workers))]
cons[sidx] = outsource(wf, wid)
end
end
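Subproblem i is assigned a worker id round-robin via mod1, so the subdomains simply wrap around however many workers were configured. A tiny standalone check of that mapping (illustrative ids):

worker_ids = [2, 3]
[worker_ids[mod1(i, length(worker_ids))] for i in 1:5]   # == [2, 3, 2, 3, 2]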
@@ -111,7 +117,7 @@ function subworker(alg, subalg)
subg = take!(con)
subalg.problem.g .= subg
# run algorithm
for _ in 1:ninner
for _ in 1:1000
step!(subst)
end
# write result
......