diff --git a/src/function.jl b/src/function.jl
index 22f25a83618a6f1c0d4b8b375eeb49e45154b5ab..1b880d343471a8e360942b082bdd41379a00fcf1 100644
--- a/src/function.jl
+++ b/src/function.jl
@@ -307,14 +307,13 @@ end
 bind!(f::FacetDivergence, cell) = f.cell[] = cell
 
 
-prolong!(new_f, old_cell, new_cells) =
-    prolong!(new_f, old_cell, new_cells, new_f.space.element)
+prolong!(new_f, old_f, old_cell, new_cells) =
+    prolong!(new_f, old_f, old_cell, new_cells, new_f.space.element)
 
-prolong!(new_f, old_cell, new_cells, ::DP1) =
-    prolong!(new_f, old_cell, new_cells, P1())
+prolong!(new_f, old_f, old_cell, new_cells, ::DP1) =
+    prolong!(new_f, old_f, old_cell, new_cells, P1())
 
-function prolong!(new_f, old_cell, new_cells, ::P1)
-    old_f = new_f
+function prolong!(new_f, old_f, old_cell, new_cells, ::P1)
     old_cell_vs = collect(vertices(old_f.space.mesh, old_cell))
     new_cell1_vs = collect(vertices(new_f.space.mesh, new_cells[1]))
     new_cell2_vs = collect(vertices(new_f.space.mesh, new_cells[2]))
@@ -350,8 +349,7 @@ function prolong!(new_f, old_cell, new_cells, ::P1)
     new_f.data[new_gdofs] .= avg_data
 end
 
-function prolong!(new_f, old_cell, new_cells, ::DP0)
-    old_f = new_f
+function prolong!(new_f, old_f, old_cell, new_cells, ::DP0)
     # simply copy over the data
     old_gdofs = old_f.space.dofmap[:, 1, old_cell]
 
diff --git a/src/mesh.jl b/src/mesh.jl
index e3331bd15787cb8f34d36d5f4ff686f7cb1dc4b3..d882c262a772fb59091d094c1363a5d0b726e87b 100644
--- a/src/mesh.jl
+++ b/src/mesh.jl
@@ -9,6 +9,7 @@ using StaticArrays: SVector
 struct HMesh
     vertices::Vector{SVector{2, Float64}}
     cells::Vector{NTuple{3, Int}}
+    "refinement depth below each cell"
     levels::Vector{Int}
 end
 
@@ -23,7 +24,8 @@ nvertices_cell(::HMesh) = 3
 nvertices(x::HMesh) = length(x.vertices)
 ncells(x::HMesh) = length(x.cells)
 
-cells(mesh::HMesh) = setdiff(axes(mesh.cells, 1), findall(>(0), mesh.levels))
+# fetches finest level cells
+cells(mesh::HMesh) = findall(==(0), mesh.levels)
 vertices(mesh::HMesh, cell) = mesh.cells[cell]
 
 function init_hgrid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
@@ -103,6 +105,8 @@ vertices(mesh::Mesh, cell) = ntuple(i -> mesh.cells[i, cell], nvertices_cell(mes
 #    end
 #end
 
+# old grid initialization
+# produces regular grid but not suitable for newest vertex bisection
 function init_grid_old(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
     r1 = LinRange(v0[1], v1[1], m + 1)
     r2 = LinRange(v0[2], v1[2], n + 1)
@@ -123,6 +127,8 @@ function init_grid_old(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
     return Mesh(vertices, cells)
 end
 
+# new grid initialization
+# produces regular grid suitable for newest vertex bisection
 function init_grid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
     r1 = LinRange(v0[1], v1[1], m + 1)
     r2 = LinRange(v0[2], v1[2], n + 1)
@@ -286,13 +292,18 @@ function bisect!(mesh::HMesh, marked_cells::Set)
     #deleteat!(mesh.cells, sort(collect(refined_cells)))
 end
 
-# horribly implemented, please don't curse me
+# TODO: cleanup and optimize
+# FIXME: this assumes hmesh was created from f.mesh
 function refine!(hmesh::HMesh, marked_cells::Set; fs...)
     old_mesh = sub_mesh(hmesh)
-    refined_cells = bisect!(hmesh, marked_cells)
+    cell_refinements = bisect!(hmesh, marked_cells)
+    # we use a extended mesh containing all cells (both fine and
+    # coarse) since a cell might have been refinemened multiple times and the
+    # intermediate cells will not be present in the final mesh for direct
+    # prolongation of data
     extended_mesh = sub_mesh(hmesh, axes(hmesh.cells, 1))
     new_mesh = sub_mesh(hmesh)
-    removed_cells = map(x -> first(x), refined_cells)
+    refined_cells = map(x -> first(x), cell_refinements)
 
     # extended functions onto newly created cells
     extended_fs = map(NamedTuple(fs)) do f
@@ -304,9 +315,9 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...)
 	copyto!(extended_f.data, f.data)
     end
     # prolong data for refined cells
-    for (old_cell, extended_cells) in refined_cells
+    for (old_cell, extended_cells) in cell_refinements
 	for extended_f in extended_fs
-	    prolong!(extended_f, old_cell, extended_cells)
+	    prolong!(extended_f, extended_f, old_cell, extended_cells)
 	end
     end
 
@@ -315,7 +326,7 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...)
 	space = FeSpace(new_mesh, f.space.element, f.space.size)
 	return FeFunction(space; f.name)
     end
-    retained_cells = setdiff(cells(extended_mesh), removed_cells)
+    retained_cells = setdiff(cells(extended_mesh), refined_cells)
     @assert(retained_cells == cells(hmesh))
     for (new_cell, old_cell) in enumerate(retained_cells)
 	for (f, new_f) in zip(extended_fs, new_fs)