diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index 4a7cf2461706e148e7f4d439e0216ef4a3b566f6..450991b1105ff96d81960d2441b4bce72169b2ff 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -1,5 +1,7 @@
 module DualTVDD
 
+
+
 include("types.jl")
 include("chambolle.jl")
 include("dualtvdd.jl")
@@ -8,44 +10,56 @@ include("dualtvdd.jl")
 using Makie: heatmap
 
 function run()
-    g = rand(50,50)
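+    # synthetic test image: constant square on a zero background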
+    g = zeros(20,20)
+    g[4:17,4:17] .= 1
+    #g[1:size(g, 1)÷2, :] .= 1
     #g = [0. 2; 1 0.]
-    A = diagm(ones(length(g)))
-    α = 0.25
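+    # B stands in for the precomputed (A'A + βI)⁻¹ that OpROFModel expects; here just a scaled identity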
+    B = diagm(fill(100, length(g)))
+    α = 0.1
 
-    md = DualTVDD.OpROFModel(g, A, α)
+    md = DualTVDD.OpROFModel(g, B, α)
     alg = DualTVDD.ChambolleAlgorithm()
     ctx = DualTVDD.init(md, alg)
 
-    scene = heatmap(ctx.s,
-        colorrange=(0,1), colormap=:gray, scale_plot=false)
-
-    display(scene)
+    #scene = vbox(
+    #    heatmap(ctx.s, colorrange=(0,1), colormap=:gray, scale_plot=false, show_axis=false),
+    #    heatmap(ctx.s, colorrange=(0,1), colormap=:gray, scale_plot=false, show_axis=false),
+    #   )
+    #display(scene)
 
-    hm = last(scene)
-    for i in 1:100
+    #hm = last(scene)
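+    # iterate the Chambolle algorithm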
+    for i in 1:10000
         step!(ctx)
-        hm[1] = ctx.s
-        yield()
+        #hm[1] = ctx.s
+        #yield()
         #sleep(0.2)
     end
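+    # show the final dual variable and the recovered primal solution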
+    display(ctx.p)
+    display(recover_u!(ctx))
+
     ctx
     #hm[1] = ctx.s
     #yield()
 end
 
 function rundd()
+    β = 0
     f = zeros(8,8)
-    f[1,:] .= 1
+    f[1:4,:] .= 1
     #g = [0. 2; 1 0.]
-    A = diagm(ones(length(f)))
-    α = 0.25
+    A = diagm(vcat(fill(1/2, length(f)÷2), fill(2, length(f)÷2)))
+    B = inv(A'*A + β*I)
+
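+    # precompute g = A'f, the data for the equivalent OpROF problem below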
+    g = similar(f)
+    vec(g) .= A' * vec(f)
+
+    α = 0.25
 
     md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
     alg = DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25)
     ctx = DualTVDD.init(md, alg)
 
-    md2 = DualTVDD.OpROFModel(f, A, α)
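+    # single-domain Chambolle on the equivalent OpROF problem, for comparison with the DD solver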
+    md2 = DualTVDD.OpROFModel(g, B, α)
     alg2 = DualTVDD.ChambolleAlgorithm()
     ctx2 = DualTVDD.init(md2, alg2)
 
@@ -81,5 +95,46 @@ function rundd()
     ctx, ctx2
 end
 
+function energy(ctx::DualTVDDContext)
+    d = ndims(ctx.p)
+
+    @inline kfΛ(w) = @inbounds divergence(w)
+    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+    v = similar(ctx.g)
+
+    # v = div(p) + A'*f
+    map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
+    v .+= ctx.g
+
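+    # dual energy ⟨B v, v⟩/2 with v = div(p) + A'f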
+    u = ctx.B * vec(v)
+
+    return sum(u .* vec(v)) / 2
+end
+
+function energy(ctx::ChambolleContext)
+    d = ndims(ctx.p)
+
+    @inline kfΛ(w) = @inbounds divergence(w)
+    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+    v = similar(ctx.model.g)
+
+    # v = div(p) + g
+    map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
+    v .+= ctx.model.g
+
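+    # dual energy ⟨B v, v⟩/2 with v = div(p) + g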
+    u = ctx.model.B * vec(v)
+
+    return sum(u .* vec(v)) / 2
+end
+
+
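+# primal ROF energy: ∑ ((u - f)²/2 + α·|∇u|) with forward differences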
+function energy(md::DualTVDDModel, u::AbstractMatrix)
+    @inline kf(w) = @inbounds 1/2 * (w[0,0] - md.f[w.position])^2 +
+        md.α * sqrt((w[1,0] - w[0,0])^2 + (w[0,1] - w[0,0])^2)
+    k = Kernel{(0:1, 0:1)}(kf, StaticKernels.ExtensionReplicate())
+    return sum(k, u)
+end
 
 end # module
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index df4a0dc23f78d5a604ff910dd6244c8b171219f8..bc18ff20802d83b1b133c9b33d2f70f9c8ffc363 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -10,19 +10,19 @@ struct DualTVDDAlgorithm{d} <: Algorithm
     end
 end
 
-struct DualTVDDContext{M,A,G,d,U,V,VV,SAx,Vview,SC}
+struct DualTVDDContext{M,A,G,d,U,V,Vtmp,VV,SAx,SC}
     model::M
     algorithm::A
     "precomputed A'f"
     g::G
     "global dual optimization variable"
     p::V
-    "(A'A + βI)^(-1)"
+    "global dual temporary variable"
+    ptmp::Vtmp
+    "precomputed (A'A + βI)^(-1)"
     B::VV
     "subdomain axes wrt global indices"
     subax::SAx
-    "local views on p per subdomain"
-    pviews::Array{Vview,d}
     "subproblem data, subg[i] == subctx[i].model.g"
     subg::Array{U,d}
     "context for subproblems"
@@ -32,36 +32,34 @@ end
 function init(md::DualTVDDModel, alg::DualTVDDAlgorithm)
     d = ndims(md.f)
     ax = axes(md.f)
+
     # subdomain axes
     subax = subaxes(md.f, alg.M, alg.overlap)
-
-    # data for subproblems
+    # preallocated data for subproblems
     subg = [Array{Float64, d}(undef, length.(subax[i])) for i in CartesianIndices(subax)]
-
     # locally dependent tv parameter
-    subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i])) for i in CartesianIndices(subax)]
+    subα = [md.α .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
+        for i in CartesianIndices(subax)]
 
+    # global g = A'f; the per-subdomain g's are initialized in step!()
     g = reshape(md.A' * vec(md.f), size(md.f))
 
+    # global dual variables
     p = zeros(SVector{d,Float64}, size(md.f))
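+    # ptmp is pre-extended so the divergence kernel in step!() can be applied over the whole domain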
+    ptmp = extend(zeros(SVector{d,Float64}, size(md.f)), StaticKernels.ExtensionNothing())
 
-    #g[i] = md.f
-
-    # TODO: initialize g per subdomain with partition function
-
+    # precomputed global B
     B = inv(md.A' * md.A + md.β * I)
 
-    # create models for subproblems
+    # create subproblem contexts
     # TODO: extraction of B subparts only makes sense for blockdiagonal B (i.e. A too)
     li = LinearIndices(size(md.f))
-    models = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i])
+    submds = [OpROFModel(subg[i], B[vec(li[subax[i]...]), vec(li[subax[i]...])], subα[i])
         for i in CartesianIndices(subax)]
-
     subalg = ChambolleAlgorithm()
+    subctx = [init(submds[i], subalg) for i in CartesianIndices(subax)]
 
-    subctx = [init(models[i], subalg) for i in CartesianIndices(subax)]
-
-    return DualTVDDContext(md, alg, g, p, B, subax, subg, subg, subctx)
+    return DualTVDDContext(md, alg, g, p, ptmp, B, subax, subg, subctx)
 end
 
 function step!(ctx::DualTVDDContext)
@@ -72,30 +70,23 @@ function step!(ctx::DualTVDDContext)
     @inline kfΛ(w) = @inbounds -divergence_global(w)
     kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
-    println("global p")
-    display(ctx.p)
-
     # call run! on each cell (this can be threaded)
     for i in eachindex(ctx.subctx)
         sax = ctx.subax[i]
         ci = CartesianIndices(sax)
 
+        # TODO: make p computation local!
         # g_i = (A*f - Λ(1-theta_i)p^n)|_{\Omega_i}
         # subctx[i].p is used as a buffer
 
-        tmp = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
+        ctx.ptmp .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
         #tmp3 = .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)))
         #ctx.subctx[i].p .= .-(1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), ci)) .* ctx.p[ctx.subax[i]...]
 
-        tmp2 = map(kΛ, extend(tmp, StaticKernels.ExtensionNothing()))
+        tmp2 = map(kΛ, ctx.ptmp)
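+        # restrict the divergence term to the subdomain Ω_i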
         ctx.subg[i] .= tmp2[sax...]
         #map!(kΛ, ctx.subg[i], ctx.subctx[i].p)
 
-        println("### ITERATION $i ###")
-        display(tmp)
-        #display(ctx.subctx[i].p)
-        display(ctx.subg[i])
-
         ctx.subg[i] .+= ctx.g[sax...]
         # set sensible starting value
         ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
@@ -136,10 +127,10 @@ function recover_u!(ctx::DualTVDDContext)
     @inline kfΛ(w) = @inbounds divergence(w)
     kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
-    # u = div(p) + A'*f
+    # v = div(p) + A'*f
     map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
     v .+= ctx.g
-    # u = B * u
+    # u = B * v
     mul!(vec(u), ctx.B, vec(v))
     return u
 end