diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index af27db1891d5deb99efbb610f5be9bc69e3e05d3..4fcafbd60c766079504ab15c871ce4fed0cd361d 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -71,13 +71,13 @@ end
 function alg_error(alg, pmin, niter, ninner=1)
     print("run ...")
     res = Float64[]
-    (p, ctx) = iterate(alg)
-    push!(res, error(ctx.p, pmin, ctx.algorithm.problem))
+    ctx = init(alg)
+    push!(res, error(fetch(ctx), pmin, ctx.algorithm.problem))
     for i in 1:niter
         for j in 1:ninner
-            (p, ctx) = iterate(alg, ctx)
+            step!(ctx)
         end
-        push!(res, error(ctx.p, pmin, ctx.algorithm.problem))
+        push!(res, error(fetch(ctx), pmin, ctx.algorithm.problem))
     end
     println(" finished")
     return res
@@ -87,9 +87,9 @@ function calc_energy(prob, niter)
     alg_ref = DualTVDD.ChambolleAlgorithm(prob)
     ctx = init(alg_ref)
     for i in 1:niter
-        (p, ctx) = iterate(alg_ref, ctx)
+        ctx = step!(ctx)
     end
-    return energy(ctx), p
+    return energy(ctx), fetch(ctx)
 end
 
 function rundd()
@@ -132,8 +132,8 @@ function rundd()
          lognan.(alg_error(alg_dd2, pmin, n)),
        ]
 
-    plot(y, xaxis=:log, yaxis=:log)
-
+    plt = plot(y, xaxis=:log, yaxis=:log)
+    display(plt)
 
     #display(energy(ctx))
     #display(ctx.p)
diff --git a/src/chambolle.jl b/src/chambolle.jl
index 0a183cd803de5f663db841001fdb7c40c0ece8b9..148937967f4d2b48efa56b6eb5872561470cef3c 100644
--- a/src/chambolle.jl
+++ b/src/chambolle.jl
@@ -90,8 +90,7 @@ function step!(ctx::ChambolleState)
     return ctx
 end
 
-fetch(ctx::ChambolleState) =
-    recover_u(ctx.p, ctx.algorithm.problem)
+fetch(ctx::ChambolleState) = ctx.p
 
 function recover_u(p, md::DualTVL1ROFOpProblem)
     d = ndims(md.g)
@@ -102,7 +101,7 @@ function recover_u(p, md::DualTVL1ROFOpProblem)
     kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
     # v = div(p) + A'*f
-    map!(kΛ, v, extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
+    map!(kΛ, v, StaticKernels.extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
     v .+= md.g
     # u = B * v
     mul!(vec(u), md.B, vec(v))
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index b54df22cb5bd48975da28f4d564730ad32137783..f3925fff20f64b6949777a759208b2eebc6640a5 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -29,7 +29,7 @@ struct DualTVDDState{A,d,V,SV,SAx,SC}
     subctx::Array{SC,d}
 end
 
-function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
+function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
     g = alg.problem.g
     d = ndims(g)
     ax = axes(g)
@@ -56,9 +56,9 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
     #subalg = [ProjGradAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
     subalg = [ChambolleAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
 
-    subctx = [iterate(x)[2] for x in subalg]
+    subctx = [init(x) for x in subalg]
 
-    return p, DualTVDDState(alg, p, q, subax, subctx)
+    return DualTVDDState(alg, p, q, subax, subctx)
 end
 
 function intersectin(a, b)
@@ -68,7 +68,8 @@ function intersectin(a, b)
     return (c, c .- az, c .- bz)
 end
 
-function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
+function step!(ctx::DualTVDDState)
+    alg = ctx.algorithm
     # σ = 1 takes care of sequential updates
     σ = alg.parallel ? ctx.algorithm.σ : 1.
     d = ndims(ctx.p)
@@ -103,7 +104,7 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
         map!(kΛ, sg, sq)
 
         for j in 1:alg.ninner
-            (_, ctx.subctx[i]) = iterate(ctx.subctx[i].algorithm, ctx.subctx[i])
+            step!(ctx.subctx[i])
         end
     end
 
@@ -113,9 +114,11 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
         ctx.p[sax...] .+= σ .* ctx.subctx[i].p
     end
 
-    return ctx.p, ctx
+    return ctx
 end
 
+fetch(ctx::DualTVDDState) = ctx.p
+
 @generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
     i0 = ntuple(_->0, N)
     i1(k) = ntuple(i->Int(k==i), N)
diff --git a/src/projgrad.jl b/src/projgrad.jl
index 6bbc2ed80a07e34894c37c88c8813852798516e2..e45080d956906ebc8bbe8db57fa5af6f608f9321 100644
--- a/src/projgrad.jl
+++ b/src/projgrad.jl
@@ -70,4 +70,4 @@ function step!(ctx::ProjGradState)
     return ctx
 end
 
-fetch(ctx::ProjGradState) = recover_u(ctx.p, ctx.algorithm.problem)
+fetch(ctx::ProjGradState) = ctx.p
diff --git a/test/runtests.jl b/test/runtests.jl
index fa69c23dd3914e7c776f63fd9ff2e1f315db00fa..e2debb324629c3ee87f20ee63d8d7e694bf11aa9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,17 +2,35 @@ using Test, BenchmarkTools
 using LinearAlgebra
 using DualTVDD:
     DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm,
-    init, step!, fetch
+    init, step!, fetch, recover_u
 
-g = Float64[0 2; 1 0]
-prob = DualTVL1ROFOpProblem(g, I, 1e-10)
+@testset "B = I" begin
+    g = Float64[0 2; 1 0]
+    prob = DualTVL1ROFOpProblem(g, I, 1e-10)
 
-@testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
-    ctx = init(alg)
-    @test 0 == @ballocated step!($ctx)
-    for i in 1:100
-        step!(ctx)
+    @testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
+        ctx = init(alg)
+        @test 0 == @ballocated step!($ctx)
+        for i in 1:100
+            step!(ctx)
+        end
+        u = recover_u(fetch(ctx), ctx.algorithm.problem)
+        @test u ≈ g
+    end
+end
+
+@testset "B = rand(...)" begin
+    g = Float64[0 2; 1 0]
+    B = rand(length(g), length(g))
+    prob = DualTVL1ROFOpProblem(g, B, 1e-10)
+
+    @testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
+        ctx = init(alg)
+        @test 0 == @ballocated step!($ctx)
+        for i in 1:100
+            step!(ctx)
+        end
+        u = recover_u(fetch(ctx), ctx.algorithm.problem)
+        @test vec(u) ≈ B * vec(g)
     end
-    u = fetch(ctx)
-    @test u ≈ g
 end