From 18f638d71902136d1750f6479420d341cb076554 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Sun, 21 Jun 2020 16:51:01 +0200
Subject: [PATCH] make projgrad build upon OpROFModel

---
 src/projgrad.jl | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/src/projgrad.jl b/src/projgrad.jl
index ec42dd5..af0d14d 100644
--- a/src/projgrad.jl
+++ b/src/projgrad.jl
@@ -13,6 +13,8 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
     p::V
     "precomputed A'f"
     g::W
+    "extended model.λ"
+    λ::W
 
     "scalar temporary 1"
     rv::Wv
@@ -30,19 +32,16 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
     k2::K2
 end
 
-function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
-    # FIXME: A is assumed square
-
-    d = ndims(md.f)
-    ax = axes(md.f)
+function init(md::OpROFModel, alg::ProjGradAlgorithm)
+    d = ndims(md.g)
+    ax = axes(md.g)
 
     p = extend(zeros(SVector{d,Float64}, ax), StaticKernels.ExtensionNothing())
-    gtmp = reshape(md.A' * vec(md.f), size(md.f))
-    g = extend(gtmp, StaticKernels.ExtensionNothing())
+    g = extend(md.g, StaticKernels.ExtensionNothing())
+    λ = extend(md.λ, StaticKernels.ExtensionNothing())
 
-    rv = zeros(length(md.f))
-    sv = zeros(length(md.f))
-    B = inv(md.A' * md.A + md.β * I)
+    rv = zeros(length(md.g))
+    sv = zeros(length(md.g))
 
     r = reshape(rv, ax)
     s = extend(reshape(sv, ax), StaticKernels.ExtensionReplicate())
@@ -54,11 +53,11 @@ function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
     @inline function kf2(pw, sw)
         Base.@_inline_meta
         q = pw[z] - alg.λ * gradient(sw)
-        return q / max(norm(q) / md.α, 1)
+        return q / max(norm(q) / md.λ[sw.position], 1)
     end
     k2 = Kernel{ntuple(_->0:1, d)}(kf2)
 
-    return ProjGradContext(md, alg, p, g, rv, sv, B, r, s, k1, k2)
+    return ProjGradContext(md, alg, p, g, λ, rv, sv, md.B, r, s, k1, k2)
 end
 
 function step!(ctx::ProjGradContext)
-- 
GitLab