diff --git a/scripts/run_experiments.jl b/scripts/run_experiments.jl
index ab18a632bc34114e8c702e0a2a7e9a3d0fc60e2a..ea30e48179f8e8bf0d0396ffb5f43d1ce60abda5 100644
--- a/scripts/run_experiments.jl
+++ b/scripts/run_experiments.jl
@@ -609,21 +609,45 @@ end
 
 huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
 
-# TODO: finish!
-function refine_and_estimate_pd(st::L1L2TVState)
-    # globally refine
+function globally_refine!(hmesh::HMesh, st::L1L2TVState)
     marked_cells = Set(axes(st.mesh.cells, 2))
-    mesh_new, fs_new = refine(st.mesh, marked_cells;
-        st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
+    mesh_new, fs_new = refine!(hmesh, marked_cells;
+        st.tdata, st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2)
     st_new = L1L2TVState(st; mesh = mesh_new,
+        fs_new.tdata,
         fs_new.est, fs_new.g, fs_new.u, fs_new.p1, fs_new.p2,
         fs_new.du, fs_new.dp1, fs_new.dp2)
+    return st_new
+end
+
+function refine_and_estimate_pd!(st, ctx, i, interpolate_image_data!, f0 = nothing, f1 = nothing)
+    hmesh = HMesh(st.mesh)
+    hmesh_coarse_cells = cells(hmesh)
+
+    st_new = globally_refine!(hmesh, st)
 
     # compute primal dual error indicators
+    if isnothing(f0) && isnothing(f1)
+        interpolate_image_data!(st_new)
+    else
+        # TODO: we should maybe create a distinct state object for optflow and
+        # have f0, f1 be part of it to avoid this special casing.
+        interpolate_image_data!(f0, f1)
+    end
     estimate_pd!(st_new)
 
+    # Output optional fine data
+    if i == 1
+        # Hack to get numbering from 0 as the other vtu data.
+        output(st_new, joinpath(ctx.outdir, "output_pd_fine_$(lpad(0, 5, '0')).vtu"), st_new.g, st_new.est)
+    end
+    output(st_new, joinpath(ctx.outdir, "output_pd_fine_$(lpad(i, 5, '0')).vtu"), st_new.g, st_new.est)
+
     # transfer data to old state
-    #
+    mesh_coarse, fs_coarse = coarsen(hmesh, hmesh_coarse_cells; st_new.est)
+    @assert cells(st.mesh) == cells(mesh_coarse)
+
+    st.est.data .= fs_coarse.est.data
 end
 
 # this computes the primal-dual error indicator which is not really useful
@@ -1359,7 +1383,7 @@ function inpaint(ctx)
         ctx.params.lambda, ctx.params.beta,
         ctx.params.gamma1, ctx.params.gamma2)
 
-    function interpolate_image_data!()
+    function interpolate_image_data!(st)
         println("interpolate image data ...")
         if ctx.params.n_refine == 0
             # if we use a grid at image resolution, we wan't to avoid
@@ -1381,7 +1405,7 @@ function inpaint(ctx)
 
     pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd
 
-        interpolate_image_data!()
+        interpolate_image_data!(st)
 
         pvd[0] = save_step(0)
 
@@ -1403,7 +1427,11 @@ function inpaint(ctx)
             # plot
             i += 1
             display(plot(grayclamp.(to_img(sample(st.u)))))
-            estimate_res!(st)
+            if hasproperty(ctx.params, :estimator) && ctx.params.estimator == :primal_dual
+                refine_and_estimate_pd!(st, ctx, i, interpolate_image_data!)
+            else
+                estimate_res!(st)
+            end
             pvd[i] = save_step(i)
             #break
 
@@ -1424,11 +1452,15 @@ function inpaint(ctx)
 
             # manually mark all cell within inpainting domain, since the
             # estimator is not reliable there
-            for cell in cells(mesh)
-                bind!(st.tdata, cell)
-                maskv = evaluate(st.tdata, SA[1/3, 1/3])
-                if abs(maskv[begin] - 1.) > 1e-8
-                    push!(marked_cells, cell)
+            if !hasproperty(ctx.params, :estimator) || ctx.params.estimator != :primal_dual
+                # The primal-dual estimator seems reliable inside the
+                # inpainting domain
+                for cell in cells(mesh)
+                    bind!(st.tdata, cell)
+                    maskv = evaluate(st.tdata, SA[1/3, 1/3])
+                    if abs(maskv[begin] - 1.) > 1e-8
+                        push!(marked_cells, cell)
+                    end
                 end
             end
             #marked_cells = Set(axes(mesh.cells, 2))
@@ -1436,7 +1468,7 @@ function inpaint(ctx)
                 st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2, st.tdata)
             st = L1L2TVState(st; mesh, fs.tdata,
                 fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
-            interpolate_image_data!()
+            interpolate_image_data!(st)
         end
     end
     end # @elapsed
@@ -1462,6 +1494,7 @@ function experiment_inpaint_adaptive(ctx)
         params = (
             name = "test",
             n_refine = coarsen,
+            #estimator = :primal_dual,
             g_arr, mask_arr, mesh,
             #alpha1 = 0.2, alpha2 = 8., lambda = 1.,
             alpha1 = 0., alpha2 = 50., lambda = 1.,
@@ -1582,7 +1615,7 @@ function optflow(ctx)
         interpolate!(st.g, g_optflow; u0 = st.u, f0, fw, st.tdata)
     end
 
-    function interpolate_image_data!()
+    function interpolate_image_data!(f0, f1)
         println("interpolate image data ...")
         project_image!(f0, imgf0)
         project_image!(f1, imgf1)
@@ -1596,7 +1629,7 @@ function optflow(ctx)
 
     pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd
 
-        interpolate_image_data!()
+        interpolate_image_data!(f0, f1)
         warp!() # in the first step just to fill st.g
         pvd[0] = save_step(0)
 
@@ -1640,7 +1673,11 @@ function optflow(ctx)
             k_refine += 1
             k_refine > ctx.params.n_refine && break
             println("refine ...")
-            estimate_res!(st)
+            if hasproperty(ctx.params, :estimator) && ctx.params.estimator == :primal_dual
+                refine_and_estimate_pd!(st, ctx, i, interpolate_image_data!, f0, f1)
+            else
+                estimate_res!(st)
+            end
             marked_cells = mark(st; theta = 0.5)
             #marked_cells = Set(axes(mesh.cells, 2))
             mesh, fs = refine(mesh, marked_cells;
@@ -1649,7 +1686,7 @@ function optflow(ctx)
             st = L1L2TVState(st; mesh, fs.tdata,
                 fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
             f0, f1, fw = (fs.f0, fs.f1, fs.fw)
-            interpolate_image_data!()
+            interpolate_image_data!(f0, f1)
         end
     end
     #CSV.write(joinpath(ctx.outdir, "energies.csv"), df)
@@ -1684,7 +1721,7 @@ function experiment_optflow_middlebury(ctx)
     imgf0 = loadimg(joinpath(ctx.indir, "frame10.png"))
     imgf1 = loadimg(joinpath(ctx.indir, "frame11.png"))
     gtflow = FileIO.load(joinpath(ctx.indir, "flow10.flo"))
-    maxflow = OpticalFlowUtils._maxflow(gtflow)
+    maxflow = OpticalFlowUtils.maxflow(gtflow)
 
     ctx = Util.Context(ctx; imgf0, imgf1, maxflow, gtflow)
     saveimg(joinpath(ctx.outdir, "ground_truth.png"), colorflow(gtflow; maxflow))
@@ -1700,12 +1737,6 @@ function experiment_optflow_middlebury_all_benchmarks(ctx)
     end
 end
 
-# FIXME: legacy
-function experiment_optflow_middlebury_all(ctx)
-    experiment_optflow_middlebury_all_benchmarks(
-        Util.Context(ctx; warp = true, refine = true, n_refine = 6))
-end
-
 function experiment_optflow_middlebury_warping_comparison(ctx)
     ctx(experiment_optflow_middlebury_all_benchmarks, "vanilla";
         warp = false, refine = false, n_refine = 0)
@@ -1720,6 +1751,8 @@ function experiment_optflow_middlebury_warping_comparison_adaptive(ctx)
         warp = true, refine = false, n_refine = 0)
     ctx(experiment_optflow_middlebury_all_benchmarks, "adaptive-warping";
         warp = true, refine = true, n_refine = 6)
+    #ctx(experiment_optflow_middlebury_all_benchmarks, "adaptive-warping";
+    #    warp = true, refine = true, n_refine = 6, estimator = :residual)
 end
 
 function experiment_optflow_schoice(ctx)
diff --git a/src/function.jl b/src/function.jl
index 98d56234c9882286eb1d519f88035a75153ffa19..875eb0ebcd201a094ea39928190fdf386972e883 100644
--- a/src/function.jl
+++ b/src/function.jl
@@ -24,7 +24,7 @@ evaluate_basis(::DP1, x) = evaluate_basis(P1(), x)
 struct FeSpace{M, Fe, S}
     mesh::M
     element::Fe
-    dofmap::Array{Int, 3} # (rdim, eldof, cell) -> gdof
+    dofmap::Array{Int, 3} # (rangedim, eldof, cell) -> gdof
     ndofs::Int # = maximum(dofmap)
 end
 
@@ -64,7 +64,7 @@ function FeSpace(mesh, el::DP1, size_=(1,))
 end
 
 Base.show(io::IO, ::MIME"text/plain", x::FeSpace) =
-    print("$(nameof(typeof(x))), $(nameof(typeof(x.element))) elements, size $(x.size), $(ndofs(x)) dofs")
+    print(io, "$(nameof(typeof(x))), $(nameof(typeof(x.element))) elements, size $(x.size), $(ndofs(x)) dofs")
 
 function Base.getproperty(obj::FeSpace{<:Any, <:Any, S}, sym::Symbol) where S
     if sym === :size
@@ -104,7 +104,7 @@ function FeFunction(space; name=string(gensym("f")))
 end
 
 Base.show(io::IO, ::MIME"text/plain", f::FeFunction) =
-    print("$(nameof(typeof(f))), size $(f.space.size) with $(length(f.data)) dofs")
+    print(io, "$(nameof(typeof(f))), size $(f.space.size) with $(length(f.data)) dofs")
 
 interpolate!(dst::FeFunction, expr::Function; params...) =
     interpolate!(dst, dst.space.element, expr; params...)
@@ -256,7 +256,7 @@ struct Derivative{F, M}
 end
 
 Base.show(io::IO, ::MIME"text/plain", x::Derivative) =
-    print("$(nameof(typeof(x))) of $(typeof(x.f))")
+    print(io, "$(nameof(typeof(x))) of $(typeof(x.f))")
 
 # TODO: should be called "derivative"
 function nabla(f)
@@ -380,6 +380,9 @@ function prolong!(new_f, old_f, old_cell, new_cells, ::DP0)
 end
 
 
+# IO
+
+
 vtk_append!(vtkfile, f::FeFunction, fs::FeFunction...) =
     (vtk_append!(vtkfile, f); vtk_append!(vtkfile, fs...))
 
diff --git a/src/mesh.jl b/src/mesh.jl
index a9d4eefc9c0f810360e8a4c18541f50eec3f174b..3c093b72a4d5f4a278e6718a0af435052c1c81ef 100644
--- a/src/mesh.jl
+++ b/src/mesh.jl
@@ -1,5 +1,5 @@
-export init_grid, init_hgrid, save, refine!, refine, cells, vertices,
-    ndims_domain, ndims_space, nvertices, nvertices_cell, ncells
+export HMesh, init_grid, init_hgrid, save, refine!, refine, cells, vertices,
+    ndims_domain, ndims_space, nvertices, nvertices_cell, ncells, coarsen
 
 using LinearAlgebra: norm
 
@@ -9,12 +9,14 @@ using StaticArrays: SVector
 struct HMesh
     vertices::Vector{SVector{2, Float64}}
     cells::Vector{NTuple{3, Int}}
+    "direct descendents of a cell"
+    children::Vector{Union{NTuple{2, Int}, Nothing}}
     "refinement depth below each cell"
     levels::Vector{Int}
 end
 
 Base.show(io::IO, ::MIME"text/plain", x::HMesh) =
-    print("$(nameof(typeof(x))), $(ncells(x)) cells")
+    print(io, "$(nameof(typeof(x))), $(ncells(x)) cells")
 
 ndims_domain(::HMesh) = 2
 ndims_space(::HMesh) = 2
@@ -44,8 +46,10 @@ function init_hgrid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
 	push!(cells, (vidx[I], vidx[I + e1], vidx[I + e1 + e2]))
 	push!(cells, (vidx[I], vidx[I + e1 + e2], vidx[I + e2]))
     end
+    children = fill(nothing, axes(cells))
+    levels = zeros(Int, axes(cells))
 
-    return HMesh(vertices, cells, zeros(Int, axes(cells)))
+    return HMesh(vertices, cells, children, levels)
 end
 
 function init_hgrid(img::Array{<:Any, 2}; type=:vertex)
@@ -79,13 +83,14 @@ function HMesh(mesh::Mesh)
     for c in axes(mesh.cells, 2)
 	push!(cells, NTuple{3}(mesh.cells[:, c]))
     end
+    children = fill(nothing, axes(cells))
     levels = zeros(Int, axes(cells))
 
-    return HMesh(vertices, cells, levels)
+    return HMesh(vertices, cells, children, levels)
 end
 
 Base.show(io::IO, ::MIME"text/plain", x::Mesh) =
-    print("$(nameof(typeof(x))), $(ncells(x)) cells")
+    print(io, "$(nameof(typeof(x))), $(ncells(x)) cells")
 
 ndims_domain(::Mesh) = 2
 ndims_space(::Mesh) = 2
@@ -156,10 +161,10 @@ function init_grid(a::Array{<:Any, 2}, m::Int, n::Int = m; type=:vertex)
         init_grid((m, n)..., (0.5, 0.5), s .- (0.5, 0.5))
 end
 
-# horribly implemented, please don't curse me
+# TODO: cleanup
 function bisect!(mesh::HMesh, marked_cells::Set)
     refined_cells = Pair{Int, NTuple{2, Int}}[]
-    # assemble edge -> cells map
+    # assemble mapping: edge -> cells
     edgemap = Dict{NTuple{2, Int}, Vector{Int}}()
     for cell in cells(mesh)
 	vs = sort(SVector(vertices(mesh, cell)))
@@ -173,8 +178,6 @@ function bisect!(mesh::HMesh, marked_cells::Set)
 	edgemap[e3] = push!(get!(edgemap, e3, []), cell)
     end
 
-    #return edgemap
-
     function refine_cell(c1)
 	c2 = -1
 	# c1 -> c11 + c12
@@ -208,6 +211,7 @@ function bisect!(mesh::HMesh, marked_cells::Set)
 	# take care to produce positively oriented cells
 	push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c1)), c1_vs[1] => vbisect)))
 	push!(mesh.levels, 0)
+	push!(mesh.children, nothing)
 	c3 = lastindex(mesh.cells)
 	@assert(length(setdiff(c1_vs, c1_vs[1])) == 2)
 	replace!(edgemap[NTuple{2}(setdiff(c1_vs, c1_vs[1]))], c1 => c3)
@@ -215,6 +219,7 @@ function bisect!(mesh::HMesh, marked_cells::Set)
 	# take care to produce positively oriented cells
 	push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c1)), c1_vs[2] => vbisect)))
 	push!(mesh.levels, 0)
+	push!(mesh.children, nothing)
 	c4 = lastindex(mesh.cells)
 	@assert(length(setdiff(c1_vs, c1_vs[2])) == 2)
 	replace!(edgemap[NTuple{2}(setdiff(c1_vs, c1_vs[2]))], c1 => c4)
@@ -227,11 +232,13 @@ function bisect!(mesh::HMesh, marked_cells::Set)
 	mesh.levels[c1] += 1
 	delete!(marked_cells, c1)
 	push!(refined_cells, c1 => (c3, c4))
+	mesh.children[c1] = (c3, c4)
 
 	if c2 > 0
 	    # take care to produce positively oriented cells
 	    push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c2)), c1_vs[1] => vbisect))) # c1_vs is correct
 	    push!(mesh.levels, 0)
+	    push!(mesh.children, nothing)
 	    c5 = lastindex(mesh.cells)
 	    @assert(length(setdiff(c2_vs, c1_vs[1])) == 2)
 	    replace!(edgemap[NTuple{2}(setdiff(c2_vs, c1_vs[1]))], c2 => c5)
@@ -239,6 +246,7 @@ function bisect!(mesh::HMesh, marked_cells::Set)
 	    # take care to produce positively oriented cells
 	    push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c2)), c1_vs[2] => vbisect))) # c1_vs is correct
 	    push!(mesh.levels, 0)
+	    push!(mesh.children, nothing)
 	    c6 = lastindex(mesh.cells)
 	    @assert(length(setdiff(c2_vs, c1_vs[2])) == 2)
 	    replace!(edgemap[NTuple{2}(setdiff(c2_vs, c1_vs[2]))], c2 => c6)
@@ -251,6 +259,7 @@ function bisect!(mesh::HMesh, marked_cells::Set)
 	    mesh.levels[c2] += 1
 	    delete!(marked_cells, c2)
 	    push!(refined_cells, c2 => (c5, c6))
+	    mesh.children[c2] = (c5, c6)
 	end
     end
 
@@ -265,46 +274,64 @@ end
 # TODO: cleanup and optimize
 # FIXME: this assumes hmesh was created from f.mesh
 function refine!(hmesh::HMesh, marked_cells::Set; fs...)
+    fs = NamedTuple(fs)
+
     old_mesh = sub_mesh(hmesh)
+    hmesh_old_leaf_cells = cells(hmesh)
+    @assert allequal(x -> x.space.mesh, fs) "Functions living on different meshes."
+    for f in fs
+        # Note this is only a superficial check. We don't spend the effort to
+        # check that meshes match structurally as well.
+        @assert length(cells(f.space.mesh)) == length(hmesh_old_leaf_cells) "Mismatching number of cells for mesh function and lowest layer of hierarchical mesh: $(length(cells(f.space.mesh))) vs $(length(hmesh_leaf_cells))."
+    end
+
     cell_refinements = bisect!(hmesh, marked_cells)
-    # we use a extended mesh containing all cells (both fine and
-    # coarse) since a cell might have been refinemened multiple times and the
-    # intermediate cells will not be present in the final mesh for direct
-    # prolongation of data
+    hmesh_new_leaf_cells = cells(hmesh)
+
+    # We use an extended mesh containing all cells and build functions on top
+    # of that since prolongation across multiple levels becomes more convenient
+    # when having a single set of cells to index with `cell_refinements`.
     extended_mesh = sub_mesh(hmesh, axes(hmesh.cells, 1))
+    # Final leaf mesh, reindexed.
     new_mesh = sub_mesh(hmesh)
+
     refined_cells = map(x -> first(x), cell_refinements)
 
-    # extended functions onto newly created cells
-    extended_fs = map(NamedTuple(fs)) do f
+    # create extended functions on all cells of extended mesh
+    extended_fs = map(fs) do f
 	space = FeSpace(extended_mesh, f.space.element, f.space.size)
 	return FeFunction(space; f.name)
     end
-    # copy over previous data for unmodified cells
-    for (f, extended_f) in zip(NamedTuple(fs), extended_fs)
-	copyto!(extended_f.data, f.data)
+    # copy over data for unmodified cells
+    for (cell, h_cell) in enumerate(hmesh_old_leaf_cells)
+	for (ext_f, f) in zip(extended_fs, fs)
+	    gdofs = f.space.dofmap[:, :, cell]
+	    ext_gdofs = ext_f.space.dofmap[:, :, h_cell]
+	    ext_f.data[ext_gdofs] .= f.data[gdofs]
+	end
     end
     # prolong data for refined cells
+    # TODO: use hmesh.children instead of `cell_refinements`.
     for (old_cell, extended_cells) in cell_refinements
 	for extended_f in extended_fs
 	    prolong!(extended_f, extended_f, old_cell, extended_cells)
 	end
     end
 
-    # retain only non-refined cells
-    new_fs = map(NamedTuple(extended_fs)) do f
+    # create final functions on the new leaf mesh
+    new_fs = map(extended_fs) do f
 	space = FeSpace(new_mesh, f.space.element, f.space.size)
 	return FeFunction(space; f.name)
     end
-    retained_cells = setdiff(cells(extended_mesh), refined_cells)
-    @assert(retained_cells == cells(hmesh))
-    for (new_cell, old_cell) in enumerate(retained_cells)
+    # copy over all leaf cell data into newly indexed mesh
+    for (new_cell, old_cell) in enumerate(hmesh_new_leaf_cells)
 	for (f, new_f) in zip(extended_fs, new_fs)
 	    gdofs = f.space.dofmap[:, :, old_cell]
 	    new_gdofs = new_f.space.dofmap[:, :, new_cell]
 	    new_f.data[new_gdofs] .= f.data[gdofs]
 	end
     end
+
     return new_mesh, new_fs
 end
 
@@ -314,6 +341,47 @@ function refine(mesh::Mesh, marked_cells; fs...)
     return refine!(hmesh, Set(marked_cells); fs...)
 end
 
+"""
+Coarsen functions to a new coarse mesh.
+
+Assumptions:
+  - `hmesh` contains the coarse mesh with the given cell indices,
+  - The meshes of `fs...` are exactly the finest level of `hmesh`.
+"""
+function coarsen(hmesh::HMesh, hmesh_coarse_cells; fs...)
+    fs = NamedTuple(fs)
+    coarse_mesh = sub_mesh(hmesh, hmesh_coarse_cells)
+    hmesh_fine_cells = cells(hmesh)
+
+    coarse_fs = map(fs) do f
+        @assert length(cells(f.space.mesh)) == count(==(0), hmesh.levels)
+	space = FeSpace(coarse_mesh, f.space.element, f.space.size)
+	return FeFunction(space; f.name)
+    end
+
+    function coarse_dofs(hcell, f)
+        if hmesh.levels[hcell] == 0
+            # reverse cell index map
+            fine_cell = findfirst(==(hcell), hmesh_fine_cells)
+            @assert f.space.element isa DP0 "only DP0 implemented"
+            gdofs = f.space.dofmap[:, 1, fine_cell]
+            return f.data[gdofs]
+        else
+            return sum(coarse_dofs(subhcell, f) for subhcell in hmesh.children[hcell]) / length(hmesh.children[hcell])
+        end
+    end
+
+    for coarse_cell in cells(coarse_mesh)
+        for (f, coarse_f) in zip(fs, coarse_fs)
+            # need to map to `hmesh` indexing
+            coarse_gdofs = coarse_f.space.dofmap[:, 1, coarse_cell]
+            coarse_f.data[coarse_gdofs] .= coarse_dofs(hmesh_coarse_cells[coarse_cell], f)
+        end
+    end
+
+    return coarse_mesh, coarse_fs
+end
+
 function geo_tolocal(A, v)
     J = jacobian(x -> A * [1 - x[1] - x[2], x[1], x[2]], [0., 0.])
     return J \ (v - A[:, 1])
diff --git a/test/runtests.jl b/test/runtests.jl
index d3c8789b5a7dd84fc2dc9111417370292cd14909..d108a7eb86469fd4f5a02eced407392f2265100d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -53,23 +53,38 @@ end
     @test img == img_sampled
 end
 
-@testset "mesh refinement" begin
+@testset "hierarchical mesh" begin
     # simple hashing
     val(f) = integrate(f.space.mesh, (x; f) -> sum(f .* f); f)
 
     mesh = init_grid(zeros(5, 5))
+    hmesh = HMesh(mesh)
 
-    f = FeFunction(FeSpace(mesh, P1(), (2, 3)))
+    f = FeFunction(FeSpace(mesh, DP0(), (2, 3)))
     f.data .= rand(size(f.data)...)
 
-    for i = 1:2
-        # refine roughly half of all cells
-        cells_ref = Set(findall(_ -> rand(Bool), cells(mesh)))
+    hmesh_cells_per_refinement = Vector{Int}[]
+    n = 5
 
-        mesh_new, (f_new,) = refine(mesh, cells_ref; f)
+    @testset "refinement $i" for i in 1:n
+        push!(hmesh_cells_per_refinement, cells(hmesh))
+        # Random set of cells on the finest level.
+        cells_ref = Set(filter(_ -> rand(Bool), cells(hmesh)))
+        mesh_new, (f_new,) = refine!(hmesh, cells_ref; f)
 
-        @test isapprox(val(f), val(f_new))
+        @test all(isnothing, hmesh.children[iszero.(hmesh.levels)])
+        @test all(!isnothing, hmesh.children[hmesh.levels .> 0])
+
+        @test isapprox(val(f_new), val(f))
 
         (mesh, f) = (mesh_new, f_new)
     end
+
+    @testset "coarsening" for i in n:-1:1
+        (mesh_coarse, (f_coarse,)) = coarsen(hmesh, hmesh_cells_per_refinement[i]; f)
+
+        @test isapprox(val(f_coarse), val(f))
+
+        #(mesh, f) = (mesh_coarse, f_coarse)
+    end
 end