diff --git a/src/chambolle.jl b/src/chambolle.jl index fd8d0b995edbbdd221ba3e87e52cd9a16a6bee5c..4cd2ae83745b78b496f82922263d20270a71e8e3 100644 --- a/src/chambolle.jl +++ b/src/chambolle.jl @@ -105,19 +105,3 @@ function step!(ctx::ChambolleState) end fetch(ctx::ChambolleState) = ctx.p - -function recover_u(p, md::DualTVL1ROFOpProblem) - d = ndims(md.g) - u = similar(md.g) - v = similar(md.g) - - @inline kfΛ(w) = @inbounds divergence(w) - kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ) - - # v = div(p) + A'*f - map!(kΛ, v, StaticKernels.extend(p, StaticKernels.ExtensionNothing())) # extension: nothing - v .+= md.g - # u = B * v - mul!(vec(u), md.B, vec(v)) - return u -end diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl index 9745045ef5280168c4e7a357823fbfb82aa8e1d7..7de8732542c285316e715df943d3dcefc1c90975 100644 --- a/src/dualtvdd.jl +++ b/src/dualtvdd.jl @@ -70,14 +70,13 @@ end function step!(ctx::DualTVDDState) alg = ctx.algorithm - # σ = 1 takes care of sequential updates - σ = alg.parallel ? ctx.algorithm.σ : 1. + σ = ctx.algorithm.σ d = ndims(ctx.p) ax = axes(ctx.p) overlap = ctx.algorithm.overlap # call run! on each cell (this can be threaded) - Threads.@threads for i in eachindex(ctx.subax) + for i in eachindex(ctx.subax) sax = ctx.subax[i] li = LinearIndices(ctx.subax)[i] sg = ctx.subctx[i].algorithm.problem.g # julia-bug workaround diff --git a/src/problems.jl b/src/problems.jl index d37dfa30e721f84aae0154a1234f0b6ee8e27615..be32f99af92fe5f9c34f23425c8ea44e812bd59e 100644 --- a/src/problems.jl +++ b/src/problems.jl @@ -26,9 +26,36 @@ function energy(p, prob::DualTVL1ROFOpProblem) # v = div(p) + g v = map(kΛ, extend(p, ExtensionNothing())) - # |v|_B^2 / 2 + # |v|_B^2 u = prob.B * vec(v) - return sum(dot.(u, vec(v))) / 2 + return sum(dot.(u, vec(v))) +end + +function recover_u(p, prob::DualTVL1ROFOpProblem) + d = ndims(prob.g) + u = similar(prob.g) + v = similar(prob.g) + + @inline kfΛ(w) = @inbounds divergence(w) + kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ) + + # v = div(p) + A'*f + map!(kΛ, v, StaticKernels.extend(p, StaticKernels.ExtensionNothing())) # extension: nothing + v .+= prob.g + # u = B * v + mul!(vec(u), prob.B, vec(v)) + return u +end + +function residual(p, prob::DualTVL1ROFOpProblem) + d = ndims(p) + grad = Kernel{ntuple(_->0:1, d)}(gradient) + + u = recover_u(p, prob) + q = map(grad, StaticKernels.extend(u, StaticKernels.ExtensionReplicate())) + + res = q .- norm.(q) .* p ./ prob.λ + return sum(dot.(res, res)) / length(p) end # operator norm of B