diff --git a/src/surrogate.jl b/src/surrogate.jl
index dcdacf911f7c7a1ccd95f75b7abaa4e72abe0181..9914b9995940e0edea5bf01bb6a797c870343e13 100644
--- a/src/surrogate.jl
+++ b/src/surrogate.jl
@@ -3,16 +3,18 @@ using LinearAlgebra: I
 struct SurrogateAlgorithm{P} <: Algorithm{P}
     problem::P
     subalg::Function
+    "number of inner iterations"
+    ninner::Int
     τ::Float64
     function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem;
-                                τ=2*normB(problem), subalg=x->ProjectedGradient(x))
-        return new{typeof(problem)}(problem, subalg, τ)
+                                ninner=1,
+                                τ=2*normB(problem), subalg=x->ProjGradAlgorithm(x))
+        return new{typeof(problem)}(problem, subalg, ninner, τ)
     end
 end
 
-struct SurrogateState{A,SP,SS,Wv,R,S,K1} <: State
+struct SurrogateState{A,SS,Wv,R,S,K1} <: State
     algorithm::A
-    subproblem::SP
     substate::SS
 
     "scalar temporary 1"
@@ -43,33 +45,45 @@ function init(alg::SurrogateAlgorithm)
     @inline kf1(pw) = @inbounds -divergence(pw)
     k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
 
-    subprob = DualTVL1ROFOpProblem(copy(prob.g), I, prob.λ)
-    substate = init(alg.subalg(subprob))
+    subg = copy(g)
+    subprob = DualTVL1ROFOpProblem(subg, I, prob.λ)
+    subalg = alg.subalg(subprob)
+    substate = init(subalg)
 
-    return SurrogateState(alg, subprob, substate, rv, sv, r, s, k1)
+    # we don't want to modify ourselves if alg.subalg gives us a bad reference
+    @assert subalg.problem.g !== g
+
+    return SurrogateState(alg, substate, rv, sv, r, s, k1)
 end
 
 function update_g!(subg, st::SurrogateState)
     alg = st.algorithm
     g = alg.problem.g
-    p = extend(fetch(st.substate), StaticKernels.ExtensionReplicate())
+
+    p = extend(fetch(st.substate), StaticKernels.ExtensionNothing())
+
+    sv = vec(parent(st.s))
+    rv = vec(st.r)
 
     # r = Λ*p
     map!(st.k1, st.r, p)
     # s = r - g
-    st.s .= st.r - g
-    # r = r + 1/τ * B * s
-    mul!(st.rv, alg.problem.B, st.sv, -1. / alg.τ, 1.)
+    st.s .= st.r .- g
+    # r = r - 1/τ * B * s
+    mul!(rv, alg.problem.B, sv, -1. / alg.τ, 1.)
 
     subg .= st.r
 end
 
 
 function step!(st::SurrogateState)
-    update_g!(st.subproblem.g, st)
+    subst = st.substate
+
+    # update data g of the subproblem
+    update_g!(subst.algorithm.problem.g, st)
 
-    for i in 1:10
-        step!(st.substate)
+    for i in 1:st.algorithm.ninner
+        step!(subst)
     end
 
     return st