From 18f638d71902136d1750f6479420d341cb076554 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Sun, 21 Jun 2020 16:51:01 +0200
Subject: [PATCH] make projgrad build upon OpROFModel

---
 src/projgrad.jl | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/src/projgrad.jl b/src/projgrad.jl
index ec42dd5..af0d14d 100644
--- a/src/projgrad.jl
+++ b/src/projgrad.jl
@@ -13,6 +13,8 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
     p::V
     "precomputed A'f"
     g::W
+    "extended model.λ"
+    λ::W
 
     "scalar temporary 1"
     rv::Wv
@@ -30,19 +32,16 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
     k2::K2
 end
 
-function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
-    # FIXME: A is assumed square
-
-    d = ndims(md.f)
-    ax = axes(md.f)
+function init(md::OpROFModel, alg::ProjGradAlgorithm)
+    d = ndims(md.g)
+    ax = axes(md.g)
 
     p = extend(zeros(SVector{d,Float64}, ax), StaticKernels.ExtensionNothing())
-    gtmp = reshape(md.A' * vec(md.f), size(md.f))
-    g = extend(gtmp, StaticKernels.ExtensionNothing())
+    g = extend(md.g, StaticKernels.ExtensionNothing())
+    λ = extend(md.λ, StaticKernels.ExtensionNothing())
 
-    rv = zeros(length(md.f))
-    sv = zeros(length(md.f))
-    B = inv(md.A' * md.A + md.β * I)
+    rv = zeros(length(md.g))
+    sv = zeros(length(md.g))
 
     r = reshape(rv, ax)
     s = extend(reshape(sv, ax), StaticKernels.ExtensionReplicate())
@@ -54,11 +53,11 @@ function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
     @inline function kf2(pw, sw)
         Base.@_inline_meta
         q = pw[z] - alg.λ * gradient(sw)
-        return q / max(norm(q) / md.α, 1)
+        return q / max(norm(q) / md.λ[sw.position], 1)
     end
     k2 = Kernel{ntuple(_->0:1, d)}(kf2)
 
-    return ProjGradContext(md, alg, p, g, rv, sv, B, r, s, k1, k2)
+    return ProjGradContext(md, alg, p, g, λ, rv, sv, md.B, r, s, k1, k2)
 end
 
 function step!(ctx::ProjGradContext)
-- 
GitLab
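
Note (not part of the patch): the change to kf2 swaps the scalar bound md.α for a spatially varying bound md.λ[sw.position], so each dual vector is projected onto a ball whose radius depends on the pixel. Below is a minimal standalone Julia sketch of that pointwise projection; project_ball is a hypothetical helper name for illustration only and not part of the repository, and SVector is assumed to come from StaticArrays.

using LinearAlgebra, StaticArrays

# Project a single dual vector q onto the closed ball of radius λ,
# mirroring the expression q / max(norm(q) / λ, 1) used in the patched kernel.
# (project_ball is a hypothetical name, not an identifier from the repo.)
project_ball(q::SVector, λ::Real) = q / max(norm(q) / λ, 1)

q = SVector(3.0, 4.0)   # norm(q) == 5
λ = 2.0
project_ball(q, λ)      # SVector(1.2, 1.6): rescaled so its norm equals λ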