diff --git a/src/mesh.jl b/src/mesh.jl
index eb3a1688118873b8d88084ac8657d4452c396674..9983de37dafd9e928b4d0f30390de75e52b21328 100644
--- a/src/mesh.jl
+++ b/src/mesh.jl
@@ -6,6 +6,21 @@ using StaticArrays: SVector
 struct HMesh
     vertices::Vector{SVector{2, Float64}}
     cells::Vector{NTuple{3, Int}}
+    levels::Vector{Int}
+end
+
+function HMesh(mesh::Mesh)
+    vertices = Vector{SVector{2, Float64}}()
+    for v in axes(mesh.vertices, 2)
+	push!(vertices, mesh.vertices[:, v])
+    end
+    cells = Vector{NTuple{3, Int}}()
+    for c in axes(mesh.cells, 2)
+	push!(cells, NTuple{3}(mesh.cells[:, c]))
+    end
+    levels = zeros(Int, axes(cells))
+
+    return HMesh(vertices, cells, levels)
 end
 
 Base.show(io::IO, ::MIME"text/plain", x::HMesh) =
@@ -17,7 +32,7 @@ nvertices_cell(::HMesh) = 3
 nvertices(x::HMesh) = length(x.vertices)
 ncells(x::HMesh) = length(x.cells)
 
-cells(mesh::HMesh) = axes(mesh.cells, 1)
+cells(mesh::HMesh) = setdiff(axes(mesh.cells, 1), findall(>(0), mesh.levels))
 vertices(mesh::HMesh, cell) = mesh.cells[cell]
 
 function init_hgrid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
@@ -37,7 +52,7 @@ function init_hgrid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
 	push!(cells, (vidx[I], vidx[I + e1 + e2], vidx[I + e2]))
     end
 
-    return HMesh(vertices, cells)
+    return HMesh(vertices, cells, zeros(Int, axes(cells)))
 end
 
 init_hgrid(img::Array{<:Any, 2}; type=:vertex) =
@@ -45,6 +60,13 @@ init_hgrid(img::Array{<:Any, 2}; type=:vertex) =
 	init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) :
 	init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5))
 
+function sub_mesh(hmesh::HMesh, subcells = cells(hmesh))
+    cells = collect(reshape(reinterpret(Int, hmesh.cells[subcells]), 3, :))
+    # TODO: restrict vertices too
+    vertices = collect(reshape(reinterpret(Float64, hmesh.vertices), 2, :))
+    return Mesh(vertices, cells)
+end
+
 
 # 2d, simplex grid
 struct Mesh
@@ -140,8 +162,9 @@ init_grid(img::Array{<:Any, 2}; type=:vertex) =
 	init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) :
 	init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5))
 
-function refine!(mesh::HMesh, marked_cells::Set)
-    refined_cells = Set{Int}()
+# horribly implemented, please don't curse me
+function bisect!(mesh::HMesh, marked_cells::Set)
+    refined_cells = Pair{Int, NTuple{2, Int}}[]
     # assemble edge -> cells map
     edgemap = Dict{NTuple{2, Int}, Vector{Int}}()
     for cell in cells(mesh)
@@ -156,6 +179,8 @@ function refine!(mesh::HMesh, marked_cells::Set)
 	edgemap[e3] = push!(get!(edgemap, e3, []), cell)
     end
 
+    #return edgemap
+
     function refine_cell(c1)
 	c2 = -1
 	# c1 -> c11 + c12
@@ -168,7 +193,7 @@ function refine!(mesh::HMesh, marked_cells::Set)
 	    c2 = c2_arr[begin]
 	    c2_vs = sort(SVector(vertices(mesh, c2)))
 
-	    if c1_vs[end] != c2_vs[end]
+	    if c1_vs[1:2] != c2_vs[1:2]
 		# cannot refine `cellop` compatibly, recurse
 		refine_cell(c2)
 		# refetch c2 because topology has changed
@@ -186,11 +211,15 @@ function refine!(mesh::HMesh, marked_cells::Set)
 	push!(mesh.vertices, xbisect)
 	vbisect = lastindex(mesh.vertices)
 
-	push!(mesh.cells, Tuple(sort(replace(c1_vs, c1_vs[1] => vbisect))))
+	# take care to produce positively oriented cells
+	push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c1)), c1_vs[1] => vbisect)))
+	push!(mesh.levels, 0)
 	c3 = lastindex(mesh.cells)
 	replace!(edgemap[NTuple{2}(setdiff(c1_vs, c1_vs[1]))], c1 => c3)
 
-	push!(mesh.cells, Tuple(sort(replace(c1_vs, c1_vs[2] => vbisect))))
+	# take care to produce positively oriented cells
+	push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c1)), c1_vs[2] => vbisect)))
+	push!(mesh.levels, 0)
 	c4 = lastindex(mesh.cells)
 	replace!(edgemap[NTuple{2}(setdiff(c1_vs, c1_vs[2]))], c1 => c4)
 
@@ -199,15 +228,20 @@ function refine!(mesh::HMesh, marked_cells::Set)
 	edgemap[(c1_vs[2], vbisect)] = [c3]
 	edgemap[(c1_vs[3], vbisect)] = [c3, c4]
 
+	mesh.levels[c1] += 1
 	delete!(marked_cells, c1)
-	push!(refined_cells, c1)
+	push!(refined_cells, c1 => (c3, c4))
 
 	if c2 > 0
-	    push!(mesh.cells, Tuple(sort(replace(c2_vs, c1_vs[1] => vbisect)))) # c1_vs is correct
+	    # take care to produce positively oriented cells
+	    push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c2)), c1_vs[1] => vbisect))) # c1_vs is correct
+	    push!(mesh.levels, 0)
 	    c5 = lastindex(mesh.cells)
 	    replace!(edgemap[NTuple{2}(setdiff(c2_vs, c1_vs[1]))], c1 => c5)
 
-	    push!(mesh.cells, Tuple(sort(replace(c2_vs, c1_vs[2] => vbisect)))) # c1_vs is correct
+	    # take care to produce positively oriented cells
+	    push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c2)), c1_vs[2] => vbisect))) # c1_vs is correct
+	    push!(mesh.levels, 0)
 	    c6 = lastindex(mesh.cells)
 	    replace!(edgemap[NTuple{2}(setdiff(c2_vs, c1_vs[2]))], c1 => c6)
 
@@ -216,8 +250,9 @@ function refine!(mesh::HMesh, marked_cells::Set)
 	    push!(edgemap[(c1_vs[2], vbisect)], c5)
 	    edgemap[(c1_vs[3], vbisect)] = [c5, c6]
 
+	    mesh.levels[c2] += 1
 	    delete!(marked_cells, c2)
-	    push!(refined_cells, c1)
+	    push!(refined_cells, c2 => (c5, c6))
 	end
     end
 
@@ -225,10 +260,103 @@ function refine!(mesh::HMesh, marked_cells::Set)
 	refine_cell(first(marked_cells))
     end
 
-    deleteat!(mesh.cells, sort(collect(refined_cells)))
-    return mesh
+    return refined_cells
+    #deleteat!(mesh.cells, sort(collect(refined_cells)))
+end
+
+# horribly implemented, please don't curse me
+function refine!(hmesh::HMesh, marked_cells::Set; fs...)
+    old_mesh = sub_mesh(hmesh)
+    refined_cells = bisect!(hmesh, marked_cells)
+    extended_mesh = sub_mesh(hmesh, axes(hmesh.cells, 1))
+    new_mesh = sub_mesh(hmesh)
+    removed_cells = map(x -> first(x), refined_cells)
+
+    # extended functions onto newly created cells
+    extended_fs = map(NamedTuple(fs)) do f
+	space = FeSpace(extended_mesh, f.space.element, f.space.size)
+	return FeFunction(space)
+    end
+    # copy over previous data for unmodified cells
+    for (f, extended_f) in zip(NamedTuple(fs), extended_fs)
+	copyto!(extended_f.data, f.data)
+    end
+    # prolong data for refined cells
+    for (old_cell, extended_cells) in refined_cells
+	for (f, extended_f) in zip(NamedTuple(fs), extended_fs)
+	    prolong!(extended_f, old_cell, extended_cells)
+	end
+    end
+
+    # retain only non-refined cells
+    new_fs = map(NamedTuple(extended_fs)) do f
+	space = FeSpace(new_mesh, f.space.element, f.space.size)
+	return FeFunction(space)
+    end
+    retained_cells = setdiff(cells(extended_mesh), removed_cells)
+    @assert(retained_cells == cells(hmesh))
+    for (new_cell, old_cell) in enumerate(retained_cells)
+	for (f, new_f) in zip(extended_fs, new_fs)
+	    gdofs = f.space.dofmap[:, :, old_cell]
+	    new_gdofs = new_f.space.dofmap[:, :, new_cell]
+	    new_f.data[new_gdofs] .= f.data[gdofs]
+	end
+    end
+    return new_fs
+end
+
+prolong!(new_f, old_cell, new_cells) =
+    prolong!(new_f, old_cell, new_cells, new_f.space.element)
+
+prolong!(new_f, old_cell, new_cells, ::DP1) =
+    prolong!(new_f, old_cell, new_cells, P1())
+
+function prolong!(new_f, old_cell, new_cells, ::P1)
+    old_f = new_f
+    old_cell_vs = collect(vertices(old_f.space.mesh, old_cell))
+    new_cell1_vs = collect(vertices(new_f.space.mesh, new_cells[1]))
+    new_cell2_vs = collect(vertices(new_f.space.mesh, new_cells[2]))
+
+    # copy over data for common vertices
+    common_vs = intersect(old_cell_vs, new_cell1_vs)
+    old_ldofs = indexin(common_vs, old_cell_vs)
+    old_gdofs = old_f.space.dofmap[:, old_ldofs, old_cell]
+    new_ldofs = indexin(common_vs, new_cell1_vs)
+    new_gdofs = new_f.space.dofmap[:, new_ldofs, new_cells[1]]
+    new_f.data[new_gdofs] .= old_f.data[old_gdofs]
+
+    common_vs = intersect(old_cell_vs, new_cell2_vs)
+    old_ldofs = indexin(common_vs, old_cell_vs)
+    old_gdofs = old_f.space.dofmap[:, old_ldofs, old_cell]
+    new_ldofs = indexin(common_vs, new_cell2_vs)
+    new_gdofs = new_f.space.dofmap[:, new_ldofs, new_cells[2]]
+    new_f.data[new_gdofs] .= old_f.data[old_gdofs]
+
+    # vertices of bisection edge
+    avg_vs = symdiff(new_cell1_vs, new_cell2_vs)
+    old_ldofs = indexin(avg_vs, old_cell_vs)
+    old_gdofs = old_f.space.dofmap[:, old_ldofs, old_cell]
+
+    avg_data = (old_f.data[old_gdofs[:, 1]] .+ old_f.data[old_gdofs[:, 2]]) ./ 2
+
+    new_gdofs = new_f.space.dofmap[:, 3, new_cells[1]]
+    new_f.data[new_gdofs] .= avg_data
+    new_gdofs = new_f.space.dofmap[:, 3, new_cells[2]]
+    new_f.data[new_gdofs] .= avg_data
+end
+
+function prolong!(new_f, old_cell, new_cells, ::DP0)
+    old_f = new_f
+    # simply copy over the data
+    old_gdofs = old_f.space.dofmap[:, 1, old_cell]
+
+    new_gdofs = new_f.space.dofmap[:, 1, new_cells[1]]
+    new_f.data[new_gdofs] .= old_f.data[old_gdofs]
+    new_gdofs = new_f.space.dofmap[:, 1, new_cells[2]]
+    new_f.data[new_gdofs] .= old_f.data[old_gdofs]
 end
 
+
 #function cell_contains(mesh, cell, v)
 #    geo = mesh.vertices[:, mesh.cells[:, cell]]
 #    J = jacobian(x -> geo * [1 - x[1] - x[2], x[1], x[2]], [0., 0.])
@@ -254,3 +382,33 @@ function save(filename::String, mesh::Mesh, fs...)
     end
     vtk_save(vtk)
 end
+
+function elmap(mesh, cell)
+    A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}(
+	view(mesh.vertices, :, view(mesh.cells, :, cell)))
+    return x -> A * SA[1 - x[1] - x[2], x[1], x[2]]
+end
+
+function test_mesh(mesh)
+    # assemble edge -> cells map
+    edgemap = Dict{NTuple{2, Int}, Vector{Int}}()
+    for cell in cells(mesh)
+	vs = sort(SVector(vertices(mesh, cell)))
+
+	e1 = (vs[1], vs[2])
+	e2 = (vs[1], vs[3])
+	e3 = (vs[2], vs[3])
+
+	edgemap[e1] = push!(get!(edgemap, e1, []), cell)
+	edgemap[e2] = push!(get!(edgemap, e2, []), cell)
+	edgemap[e3] = push!(get!(edgemap, e3, []), cell)
+    end
+    for (edge, cells) in edgemap
+	@assert(length(cells) <= 2)
+    end
+    # are cells positively oriented?
+    for cell in cells(mesh)
+	delmap = jacobian(elmap(mesh, cell), SA[0., 0.])
+	@assert(det(delmap) > 0)
+    end
+end
diff --git a/src/run.jl b/src/run.jl
index fd9e32ddf79d8480fdb7310fd27cfab7a259f430..cc2bbfd816acaac73faf214fbb65aee3b16cb27c 100644
--- a/src/run.jl
+++ b/src/run.jl
@@ -203,6 +203,37 @@ function estimate!(ctx::L1L2TVContext)
     interpolate!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata)
 end
 
+function refine(ctx::L1L2TVContext, marked_cells)
+    hmesh = HMesh(ctx.mesh)
+    refined_functions = refine!(hmesh, Set(marked_cells);
+	ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2)
+    new_mesh = refined_functions.u.space.mesh
+
+    new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S,
+	ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2)
+
+    @assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap)
+    @assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap)
+    @assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap)
+    @assert(new_ctx.p1.space.dofmap == refined_functions.p1.space.dofmap)
+    @assert(new_ctx.p2.space.dofmap == refined_functions.p2.space.dofmap)
+    @assert(new_ctx.du.space.dofmap == refined_functions.du.space.dofmap)
+    @assert(new_ctx.dp1.space.dofmap == refined_functions.dp1.space.dofmap)
+    @assert(new_ctx.dp2.space.dofmap == refined_functions.dp2.space.dofmap)
+
+    new_ctx.est.data .= refined_functions.est.data
+    new_ctx.g.data .= refined_functions.g.data
+    new_ctx.u.data .= refined_functions.u.data
+    new_ctx.p1.data .= refined_functions.p1.data
+    new_ctx.p2.data .= refined_functions.p2.data
+    new_ctx.du.data .= refined_functions.du.data
+    new_ctx.dp1.data .= refined_functions.dp1.data
+    new_ctx.dp2.data .= refined_functions.dp2.data
+
+    return new_ctx
+end
+
+
 function save(ctx::L1L2TVContext, filename, fs...)
     print("save ... ")
     vtk = vtk_mesh(filename, ctx.mesh)
@@ -211,6 +242,69 @@ function save(ctx::L1L2TVContext, filename, fs...)
     return vtk
 end
 
+function test_refine(mesh = nothing)
+    m = 1
+
+    if isnothing(mesh)
+	mesh = init_grid(2, 2)
+    end
+
+    T(tdata, u) = u
+    S(u, nablau) = u
+    ctx = L1L2TVContext("test", mesh, m; T, tdata = nothing, S,
+	alpha1=0., alpha2=5., lambda=0.1, beta=1e-3, gamma1=1e-3, gamma2=1e-3)
+
+    interpolate!(ctx.g, x -> x[1] + x[2] > 1.01)
+
+    mysave(ctx, i) =
+	save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
+	    ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
+
+    pvd = paraview_collection("$(ctx.name).pvd")
+    for i in 1:10
+	step!(ctx)
+    end
+    pvd[0] = mysave(ctx, 0)
+    ctx2 = refine(ctx, [1, 2, 3, 4, 5, 6, 7, 8])
+    #interpolate!(ctx2.g, x -> x[1] + x[2] < 2.01)
+    #interpolate!(ctx2.g, x -> norm(x .- m) < norm(m) / 3)
+    #ctx2.est.data .= 0
+    #ctx2.g.data .= 0
+    #ctx2.u.data .= 0
+    #ctx2.p1.data .= 0
+    #ctx2.p2.data .= 0
+    #ctx2.du.data .= 0
+    #ctx2.dp1.data .= 0
+    #ctx2.dp2.data .= 0
+    for i in 1:10
+	step!(ctx2)
+    end
+    pvd[1] = mysave(ctx2, 1)
+
+    test_mesh(ctx2.mesh)
+
+    return ctx, ctx2
+    for i in 1:5
+	step!(ctx)
+	estimate!(ctx)
+	pvd[i] = mysave(i)
+	println()
+    end
+    estmean = mean(ctx.est.data)
+    marked_cells = findall(>(2*estmean), ctx.est.data)
+    println(marked_cells)
+    ctx = refine(ctx, marked_cells)
+    interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
+    for i in 6:11
+	step!(ctx)
+	estimate!(ctx)
+	pvd[i] = mysave(i)
+	println()
+    end
+
+    return ctx
+end
+
 function denoise(img; name, params...)
     m = 1
     mesh = init_grid(img; type=:vertex)
@@ -236,6 +330,20 @@ function denoise(img; name, params...)
 	pvd[i] = save_denoise(i)
 	println()
     end
+    estmean = mean(ctx.est.data)
+    marked_cells = findall(>(2*estmean), ctx.est.data)
+    test_mesh(ctx.mesh)
+    println(marked_cells)
+    ctx = refine(ctx, marked_cells)
+    test_mesh(ctx.mesh)
+    interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
+    for i in 6:20
+        step!(ctx)
+        estimate!(ctx)
+        pvd[i] = save_denoise(i)
+        println()
+    end
+
     return ctx
 end
 
@@ -266,7 +374,7 @@ function inpaint(img, imgmask; name, params...)
 
     pvd = paraview_collection("$(ctx.name).pvd")
     pvd[0] = save_inpaint(0)
-    for i in 1:20
+    for i in 1:3
 	step!(ctx)
 	estimate!(ctx)
 	pvd[i] = save_inpaint(i)