diff --git a/src/chambolle.jl b/src/chambolle.jl index 370e7c4b7eeab2b714554b24cc8f1a372fb47a01..2844098cc2293831edae539770950d82e2cba872 100644 --- a/src/chambolle.jl +++ b/src/chambolle.jl @@ -4,8 +4,8 @@ # https://doi.org/10.1023/B:JMIV.0000011325.36760.1e # # Implementation Notes: -# - TV-parameter α instead of λ -# - feasibility constraint |p|<=α instead of |p|<=1 +# - λ is not a scalar but a scalar field +# - pointwise feasibility constraint |p|<=λ instead of |p|<=1 # - B is introduced, the original Chambolle algorithm has B=Id using Base: @_inline_meta @@ -21,13 +21,15 @@ struct ChambolleAlgorithm <: Algorithm end end -struct ChambolleContext{M,A,G,T,R,S,Sv,K1,K2} <: Context +struct ChambolleContext{M,A,G,Λ,T,R,S,Sv,K1,K2} <: Context "model data" model::M "algorithm data" algorithm::A "matrix view on model.f" g::G + "matrix view on model.λ" + λ::Λ "matrix view on pv" p::T "matrix view on rv" @@ -40,13 +42,14 @@ struct ChambolleContext{M,A,G,T,R,S,Sv,K1,K2} <: Context sv::Sv "div(p) + g kernel" k1::K1 - "(p + τ*grad(q))/(1 + τ/α|grad(q)|) kernel" + "(p + τ*grad(q))/(1 + τ/λ|grad(q)|) kernel" k2::K2 end -function init(md::ChambolleModel, alg::ChambolleAlgorithm) +function init(md::OpROFModel, alg::ChambolleAlgorithm) d = ndims(md.g) g = extend(md.g, StaticKernels.ExtensionNothing()) + λ = extend(md.λ, StaticKernels.ExtensionNothing()) pv = zeros(d * length(md.g)) rv = zeros(length(md.g)) sv = zero(rv) @@ -60,11 +63,11 @@ function init(md::ChambolleModel, alg::ChambolleAlgorithm) @inline function kf2(pw, sw) sgrad = alg.τ * gradient(sw) - return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / md.λ) + return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / λ[pw.position]) end k2 = Kernel{ntuple(_->0:1, d)}(kf2) - return ChambolleContext(md, alg, g, p, r, s, rv, sv, k1, k2) + return ChambolleContext(md, alg, g, λ, p, r, s, rv, sv, k1, k2) end @generated function gradient(w::StaticKernels.Window{<:Any,N}) where N @@ -98,7 +101,7 @@ function step!(ctx::ChambolleContext) map!(ctx.k1, ctx.r, ctx.p, ctx.g) # s = B * r mul!(ctx.sv, ctx.model.B, ctx.rv) - # p = (p + τ*grad(s)) / (1 + τ/α|grad(s)|) + # p = (p + τ*grad(s)) / (1 + τ/λ|grad(s)|) map!(ctx.k2, ctx.p, ctx.p, ctx.s) end diff --git a/src/types.jl b/src/types.jl index 2db8925a5760b8ff24f040e121c2f23352fcad7e..f80834cdd41fcc0a017d906319ae522b6642804b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -30,11 +30,13 @@ struct DualTVDDModel{U,VV} <: Model end "min_p 1/2 * |div(p) - g|_B^2 + χ_{|p|<=λ}" -struct ChambolleModel{U,VV} <: Model +struct OpROFModel{U,VV,Λ} <: Model "given data" g::U "B norm operator" B::VV "total variation parameter" - λ::Float64 + λ::Λ end + +OpROFModel(g, B, λ::Real) = OpROFModel(g, B, fill!(similar(g, typeof(λ)), λ))