diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl index bea9bf33d6f87767efae582b5c7326284b2a3c8b..c501566d45bb0ddeaadcc66b9ae2b38a63276aeb 100644 --- a/src/dualtvdd.jl +++ b/src/dualtvdd.jl @@ -68,10 +68,11 @@ function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}) # global dual variable p = zeros(p1type(eltype(g)), ax) # local dual variable - subp = [zeros(p1type(eltype(g)), sax) for sax in subax] + subp = [zeros(p1type(eltype(g)), length.(sax)) for sax in subax] # create subproblem contexts - cids = chessboard_coloring(size(subax)) + cids = alg.parallel ? [eachindex(subax)] : + chessboard_coloring(size(subax)) cons = Array{Connector, d}(undef, size(subax)) for (color, sidxs) in enumerate(cids) for (i, sidx) in enumerate(sidxs) @@ -117,7 +118,7 @@ function subworker(alg, subalg) subg = take!(con) subalg.problem.g .= subg # run algorithm - for _ in 1:1000 + for _ in 1:ninner step!(subst) end # write result @@ -139,7 +140,8 @@ function step!(ctx::DualTVDDState) p_don = zeros(eltype(ctx.p), size(ctx.p)) # subdomain loop (in coloring order) - cids = chessboard_coloring(size(ctx.subax)) + cids = alg.parallel ? [eachindex(ctx.subax)] : + chessboard_coloring(size(ctx.subax)) for (color, ids) in enumerate(cids) for i in ids sax = ctx.subax[i]