diff --git a/src/projgrad.jl b/src/projgrad.jl index ec42dd565e20ed3a68f8f386e97d54cef584e1f0..af0d14d92dbdf7945e90108d990a8d7d8ee06f2b 100644 --- a/src/projgrad.jl +++ b/src/projgrad.jl @@ -13,6 +13,8 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2} p::V "precomputed A'f" g::W + "extended model.λ" + λ::W "scalar temporary 1" rv::Wv @@ -30,19 +32,16 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2} k2::K2 end -function init(md::DualTVDDModel, alg::ProjGradAlgorithm) - # FIXME: A is assumed square - - d = ndims(md.f) - ax = axes(md.f) +function init(md::OpROFModel, alg::ProjGradAlgorithm) + d = ndims(md.g) + ax = axes(md.g) p = extend(zeros(SVector{d,Float64}, ax), StaticKernels.ExtensionNothing()) - gtmp = reshape(md.A' * vec(md.f), size(md.f)) - g = extend(gtmp, StaticKernels.ExtensionNothing()) + g = extend(md.g, StaticKernels.ExtensionNothing()) + λ = extend(md.λ, StaticKernels.ExtensionNothing()) - rv = zeros(length(md.f)) - sv = zeros(length(md.f)) - B = inv(md.A' * md.A + md.β * I) + rv = zeros(length(md.g)) + sv = zeros(length(md.g)) r = reshape(rv, ax) s = extend(reshape(sv, ax), StaticKernels.ExtensionReplicate()) @@ -54,11 +53,11 @@ function init(md::DualTVDDModel, alg::ProjGradAlgorithm) @inline function kf2(pw, sw) Base.@_inline_meta q = pw[z] - alg.λ * gradient(sw) - return q / max(norm(q) / md.α, 1) + return q / max(norm(q) / md.λ[sw.position], 1) end k2 = Kernel{ntuple(_->0:1, d)}(kf2) - return ProjGradContext(md, alg, p, g, rv, sv, B, r, s, k1, k2) + return ProjGradContext(md, alg, p, g, λ, rv, sv, md.B, r, s, k1, k2) end function step!(ctx::ProjGradContext)