diff --git a/src/chambolle.jl b/src/chambolle.jl index 0c41a784ec637a61590adc6ea9f9d5ce303d514e..0a2cfa03bc36f021bba97aac39fa18cc9b6082ac 100644 --- a/src/chambolle.jl +++ b/src/chambolle.jl @@ -38,11 +38,11 @@ struct ChambolleContext{M,A,G,Λ,T,R,S,Sv,K1,K2} <: Context s::S "dual variable as vector" rv::Sv - "scalar temporary" + "scalar temporary variable" sv::Sv "div(p) + g kernel" k1::K1 - "(p + τ*grad(q))/(1 + τ/λ|grad(q)|) kernel" + "(p + (τ/normB)*grad(q))/(1 + (τ/normB)/λ|grad(q)|) kernel" k2::K2 end @@ -61,9 +61,11 @@ function init(md::OpROFModel, alg::ChambolleAlgorithm) @inline kf1(pw, gw) = @inbounds divergence(pw) + gw[z] k1 = Kernel{ntuple(_->-1:1, d)}(kf1) + normB = norm(md.B) + @inline function kf2(pw, sw) - iszero(λ[pw.position]) && return zero(pw[z]) - sgrad = alg.τ * gradient(sw) + @inbounds iszero(λ[pw.position]) && return zero(pw[z]) + sgrad = (alg.τ / normB) * gradient(sw) return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / λ[pw.position]) end k2 = Kernel{ntuple(_->0:1, d)}(kf2) @@ -111,6 +113,8 @@ function recover_u!(ctx) map!(ctx.k1, ctx.r, ctx.p, ctx.g) # s = B * r mul!(ctx.sv, ctx.model.B, ctx.rv) + + return ctx.s end #