Skip to content
Snippets Groups Projects
Commit 18f638d7 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

make projgrad build upon OpROFModel

parent de855f9a
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,8 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
p::V
"precomputed A'f"
g::W
"extended model.λ"
λ::W
"scalar temporary 1"
rv::Wv
......@@ -30,19 +32,16 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
k2::K2
end
function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
# FIXME: A is assumed square
d = ndims(md.f)
ax = axes(md.f)
function init(md::OpROFModel, alg::ProjGradAlgorithm)
d = ndims(md.g)
ax = axes(md.g)
p = extend(zeros(SVector{d,Float64}, ax), StaticKernels.ExtensionNothing())
gtmp = reshape(md.A' * vec(md.f), size(md.f))
g = extend(gtmp, StaticKernels.ExtensionNothing())
g = extend(md.g, StaticKernels.ExtensionNothing())
λ = extend(md.λ, StaticKernels.ExtensionNothing())
rv = zeros(length(md.f))
sv = zeros(length(md.f))
B = inv(md.A' * md.A + md.β * I)
rv = zeros(length(md.g))
sv = zeros(length(md.g))
r = reshape(rv, ax)
s = extend(reshape(sv, ax), StaticKernels.ExtensionReplicate())
......@@ -54,11 +53,11 @@ function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
@inline function kf2(pw, sw)
Base.@_inline_meta
q = pw[z] - alg.λ * gradient(sw)
return q / max(norm(q) / md.α, 1)
return q / max(norm(q) / md.λ[sw.position], 1)
end
k2 = Kernel{ntuple(_->0:1, d)}(kf2)
return ProjGradContext(md, alg, p, g, rv, sv, B, r, s, k1, k2)
return ProjGradContext(md, alg, p, g, λ, rv, sv, md.B, r, s, k1, k2)
end
function step!(ctx::ProjGradContext)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment