diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl index c501566d45bb0ddeaadcc66b9ae2b38a63276aeb..1bb325e83b682da4c50c3ebd456750ebb0eb71ff 100644 --- a/src/dualtvdd.jl +++ b/src/dualtvdd.jl @@ -143,37 +143,39 @@ function step!(ctx::DualTVDDState) cids = alg.parallel ? [eachindex(ctx.subax)] : chessboard_coloring(size(ctx.subax)) for (color, ids) in enumerate(cids) - for i in ids - sax = ctx.subax[i] - - # update remaining old contribution - view(p_rem, sax...) .-= - theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax)) .* view(ctx.p, sax...) - - # data: g - Λ(p_don + p_rem) - sg = copy(view(alg.problem.g, sax...)) - sp = extend(similar(ctx.p, length.(sax)), StaticKernels.ExtensionNothing()) - if alg.parallel - sp .= (1. .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* view(ctx.p, sax...) - else - sp .= view(p_don, sax...) .+ view(p_rem, sax...) + for pids in Iterators.partition(ids, length(alg.workers)) + for i in pids + sax = ctx.subax[i] + + # update remaining old contribution + view(p_rem, sax...) .-= + theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax)) .* view(ctx.p, sax...) + + # data: g - Λ(p_don + p_rem) + sg = copy(view(alg.problem.g, sax...)) + sp = extend(similar(ctx.p, length.(sax)), StaticKernels.ExtensionNothing()) + if alg.parallel + sp .= (1. .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* view(ctx.p, sax...) + else + sp .= view(p_don, sax...) .+ view(p_rem, sax...) + end + + @inline kf(spw) = @inbounds sg[spw.position] + divergence(spw) + kern = Kernel{ntuple(_->-1:1, d)}(kf) + map!(kern, sg, sp) + + # start computation + put!(ctx.cons[i], sg) + end + for i in pids + sax = ctx.subax[i] + + sp = take!(ctx.cons[i]) + # reshape(reinterpret(SVector{d,eltype(alg.problem.g)}, ctx.subp[i]), size(ctx.subalg[i].problem.g)) + # weighted update for new contribution + view(p_don, sax...) .+= + (1 .- σ) .* theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax)) .* view(ctx.p, sax...) .+ σ .* sp end - - @inline kf(spw) = @inbounds sg[spw.position] + divergence(spw) - kern = Kernel{ntuple(_->-1:1, d)}(kf) - map!(kern, sg, sp) - - # start computation - put!(ctx.cons[i], sg) - end - for i in ids - sax = ctx.subax[i] - - sp = take!(ctx.cons[i]) - # reshape(reinterpret(SVector{d,eltype(alg.problem.g)}, ctx.subp[i]), size(ctx.subalg[i].problem.g)) - # weighted update for new contribution - view(p_don, sax...) .+= - (1 .- σ) .* theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax)) .* view(ctx.p, sax...) .+ σ .* sp end end