diff --git a/src/projgrad.jl b/src/projgrad.jl
index ec42dd565e20ed3a68f8f386e97d54cef584e1f0..af0d14d92dbdf7945e90108d990a8d7d8ee06f2b 100644
--- a/src/projgrad.jl
+++ b/src/projgrad.jl
@@ -13,6 +13,8 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
     p::V
     "precomputed A'f"
     g::W
+    "extended model.λ"
+    λ::W
 
     "scalar temporary 1"
     rv::Wv
@@ -30,19 +32,16 @@ struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
     k2::K2
 end
 
-function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
-    # FIXME: A is assumed square
-
-    d = ndims(md.f)
-    ax = axes(md.f)
+function init(md::OpROFModel, alg::ProjGradAlgorithm)
+    d = ndims(md.g)
+    ax = axes(md.g)
 
     p = extend(zeros(SVector{d,Float64}, ax), StaticKernels.ExtensionNothing())
-    gtmp = reshape(md.A' * vec(md.f), size(md.f))
-    g = extend(gtmp, StaticKernels.ExtensionNothing())
+    g = extend(md.g, StaticKernels.ExtensionNothing())
+    λ = extend(md.λ, StaticKernels.ExtensionNothing())
 
-    rv = zeros(length(md.f))
-    sv = zeros(length(md.f))
-    B = inv(md.A' * md.A + md.β * I)
+    rv = zeros(length(md.g))
+    sv = zeros(length(md.g))
 
     r = reshape(rv, ax)
     s = extend(reshape(sv, ax), StaticKernels.ExtensionReplicate())
@@ -54,11 +53,11 @@ function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
     @inline function kf2(pw, sw)
         Base.@_inline_meta
         q = pw[z] - alg.λ * gradient(sw)
-        return q / max(norm(q) / md.α, 1)
+        return q / max(norm(q) / md.λ[sw.position], 1)
     end
     k2 = Kernel{ntuple(_->0:1, d)}(kf2)
 
-    return ProjGradContext(md, alg, p, g, rv, sv, B, r, s, k1, k2)
+    return ProjGradContext(md, alg, p, g, λ, rv, sv, md.B, r, s, k1, k2)
 end
 
 function step!(ctx::ProjGradContext)