Skip to content
Snippets Groups Projects
Commit ddc21ec7 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

implement newest vertex adaptive refinement

parent 41dc9d35
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,21 @@ using StaticArrays: SVector
struct HMesh
vertices::Vector{SVector{2, Float64}}
cells::Vector{NTuple{3, Int}}
levels::Vector{Int}
end
function HMesh(mesh::Mesh)
vertices = Vector{SVector{2, Float64}}()
for v in axes(mesh.vertices, 2)
push!(vertices, mesh.vertices[:, v])
end
cells = Vector{NTuple{3, Int}}()
for c in axes(mesh.cells, 2)
push!(cells, NTuple{3}(mesh.cells[:, c]))
end
levels = zeros(Int, axes(cells))
return HMesh(vertices, cells, levels)
end
Base.show(io::IO, ::MIME"text/plain", x::HMesh) =
......@@ -17,7 +32,7 @@ nvertices_cell(::HMesh) = 3
nvertices(x::HMesh) = length(x.vertices)
ncells(x::HMesh) = length(x.cells)
cells(mesh::HMesh) = axes(mesh.cells, 1)
cells(mesh::HMesh) = setdiff(axes(mesh.cells, 1), findall(>(0), mesh.levels))
vertices(mesh::HMesh, cell) = mesh.cells[cell]
function init_hgrid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
......@@ -37,7 +52,7 @@ function init_hgrid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
push!(cells, (vidx[I], vidx[I + e1 + e2], vidx[I + e2]))
end
return HMesh(vertices, cells)
return HMesh(vertices, cells, zeros(Int, axes(cells)))
end
init_hgrid(img::Array{<:Any, 2}; type=:vertex) =
......@@ -45,6 +60,13 @@ init_hgrid(img::Array{<:Any, 2}; type=:vertex) =
init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) :
init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5))
function sub_mesh(hmesh::HMesh, subcells = cells(hmesh))
cells = collect(reshape(reinterpret(Int, hmesh.cells[subcells]), 3, :))
# TODO: restrict vertices too
vertices = collect(reshape(reinterpret(Float64, hmesh.vertices), 2, :))
return Mesh(vertices, cells)
end
# 2d, simplex grid
struct Mesh
......@@ -140,8 +162,9 @@ init_grid(img::Array{<:Any, 2}; type=:vertex) =
init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) :
init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5))
function refine!(mesh::HMesh, marked_cells::Set)
refined_cells = Set{Int}()
# horribly implemented, please don't curse me
function bisect!(mesh::HMesh, marked_cells::Set)
refined_cells = Pair{Int, NTuple{2, Int}}[]
# assemble edge -> cells map
edgemap = Dict{NTuple{2, Int}, Vector{Int}}()
for cell in cells(mesh)
......@@ -156,6 +179,8 @@ function refine!(mesh::HMesh, marked_cells::Set)
edgemap[e3] = push!(get!(edgemap, e3, []), cell)
end
#return edgemap
function refine_cell(c1)
c2 = -1
# c1 -> c11 + c12
......@@ -168,7 +193,7 @@ function refine!(mesh::HMesh, marked_cells::Set)
c2 = c2_arr[begin]
c2_vs = sort(SVector(vertices(mesh, c2)))
if c1_vs[end] != c2_vs[end]
if c1_vs[1:2] != c2_vs[1:2]
# cannot refine `cellop` compatibly, recurse
refine_cell(c2)
# refetch c2 because topology has changed
......@@ -186,11 +211,15 @@ function refine!(mesh::HMesh, marked_cells::Set)
push!(mesh.vertices, xbisect)
vbisect = lastindex(mesh.vertices)
push!(mesh.cells, Tuple(sort(replace(c1_vs, c1_vs[1] => vbisect))))
# take care to produce positively oriented cells
push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c1)), c1_vs[1] => vbisect)))
push!(mesh.levels, 0)
c3 = lastindex(mesh.cells)
replace!(edgemap[NTuple{2}(setdiff(c1_vs, c1_vs[1]))], c1 => c3)
push!(mesh.cells, Tuple(sort(replace(c1_vs, c1_vs[2] => vbisect))))
# take care to produce positively oriented cells
push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c1)), c1_vs[2] => vbisect)))
push!(mesh.levels, 0)
c4 = lastindex(mesh.cells)
replace!(edgemap[NTuple{2}(setdiff(c1_vs, c1_vs[2]))], c1 => c4)
......@@ -199,15 +228,20 @@ function refine!(mesh::HMesh, marked_cells::Set)
edgemap[(c1_vs[2], vbisect)] = [c3]
edgemap[(c1_vs[3], vbisect)] = [c3, c4]
mesh.levels[c1] += 1
delete!(marked_cells, c1)
push!(refined_cells, c1)
push!(refined_cells, c1 => (c3, c4))
if c2 > 0
push!(mesh.cells, Tuple(sort(replace(c2_vs, c1_vs[1] => vbisect)))) # c1_vs is correct
# take care to produce positively oriented cells
push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c2)), c1_vs[1] => vbisect))) # c1_vs is correct
push!(mesh.levels, 0)
c5 = lastindex(mesh.cells)
replace!(edgemap[NTuple{2}(setdiff(c2_vs, c1_vs[1]))], c1 => c5)
push!(mesh.cells, Tuple(sort(replace(c2_vs, c1_vs[2] => vbisect)))) # c1_vs is correct
# take care to produce positively oriented cells
push!(mesh.cells, Tuple(replace(SVector(vertices(mesh, c2)), c1_vs[2] => vbisect))) # c1_vs is correct
push!(mesh.levels, 0)
c6 = lastindex(mesh.cells)
replace!(edgemap[NTuple{2}(setdiff(c2_vs, c1_vs[2]))], c1 => c6)
......@@ -216,8 +250,9 @@ function refine!(mesh::HMesh, marked_cells::Set)
push!(edgemap[(c1_vs[2], vbisect)], c5)
edgemap[(c1_vs[3], vbisect)] = [c5, c6]
mesh.levels[c2] += 1
delete!(marked_cells, c2)
push!(refined_cells, c1)
push!(refined_cells, c2 => (c5, c6))
end
end
......@@ -225,10 +260,103 @@ function refine!(mesh::HMesh, marked_cells::Set)
refine_cell(first(marked_cells))
end
deleteat!(mesh.cells, sort(collect(refined_cells)))
return mesh
return refined_cells
#deleteat!(mesh.cells, sort(collect(refined_cells)))
end
# horribly implemented, please don't curse me
function refine!(hmesh::HMesh, marked_cells::Set; fs...)
old_mesh = sub_mesh(hmesh)
refined_cells = bisect!(hmesh, marked_cells)
extended_mesh = sub_mesh(hmesh, axes(hmesh.cells, 1))
new_mesh = sub_mesh(hmesh)
removed_cells = map(x -> first(x), refined_cells)
# extended functions onto newly created cells
extended_fs = map(NamedTuple(fs)) do f
space = FeSpace(extended_mesh, f.space.element, f.space.size)
return FeFunction(space)
end
# copy over previous data for unmodified cells
for (f, extended_f) in zip(NamedTuple(fs), extended_fs)
copyto!(extended_f.data, f.data)
end
# prolong data for refined cells
for (old_cell, extended_cells) in refined_cells
for (f, extended_f) in zip(NamedTuple(fs), extended_fs)
prolong!(extended_f, old_cell, extended_cells)
end
end
# retain only non-refined cells
new_fs = map(NamedTuple(extended_fs)) do f
space = FeSpace(new_mesh, f.space.element, f.space.size)
return FeFunction(space)
end
retained_cells = setdiff(cells(extended_mesh), removed_cells)
@assert(retained_cells == cells(hmesh))
for (new_cell, old_cell) in enumerate(retained_cells)
for (f, new_f) in zip(extended_fs, new_fs)
gdofs = f.space.dofmap[:, :, old_cell]
new_gdofs = new_f.space.dofmap[:, :, new_cell]
new_f.data[new_gdofs] .= f.data[gdofs]
end
end
return new_fs
end
prolong!(new_f, old_cell, new_cells) =
prolong!(new_f, old_cell, new_cells, new_f.space.element)
prolong!(new_f, old_cell, new_cells, ::DP1) =
prolong!(new_f, old_cell, new_cells, P1())
function prolong!(new_f, old_cell, new_cells, ::P1)
old_f = new_f
old_cell_vs = collect(vertices(old_f.space.mesh, old_cell))
new_cell1_vs = collect(vertices(new_f.space.mesh, new_cells[1]))
new_cell2_vs = collect(vertices(new_f.space.mesh, new_cells[2]))
# copy over data for common vertices
common_vs = intersect(old_cell_vs, new_cell1_vs)
old_ldofs = indexin(common_vs, old_cell_vs)
old_gdofs = old_f.space.dofmap[:, old_ldofs, old_cell]
new_ldofs = indexin(common_vs, new_cell1_vs)
new_gdofs = new_f.space.dofmap[:, new_ldofs, new_cells[1]]
new_f.data[new_gdofs] .= old_f.data[old_gdofs]
common_vs = intersect(old_cell_vs, new_cell2_vs)
old_ldofs = indexin(common_vs, old_cell_vs)
old_gdofs = old_f.space.dofmap[:, old_ldofs, old_cell]
new_ldofs = indexin(common_vs, new_cell2_vs)
new_gdofs = new_f.space.dofmap[:, new_ldofs, new_cells[2]]
new_f.data[new_gdofs] .= old_f.data[old_gdofs]
# vertices of bisection edge
avg_vs = symdiff(new_cell1_vs, new_cell2_vs)
old_ldofs = indexin(avg_vs, old_cell_vs)
old_gdofs = old_f.space.dofmap[:, old_ldofs, old_cell]
avg_data = (old_f.data[old_gdofs[:, 1]] .+ old_f.data[old_gdofs[:, 2]]) ./ 2
new_gdofs = new_f.space.dofmap[:, 3, new_cells[1]]
new_f.data[new_gdofs] .= avg_data
new_gdofs = new_f.space.dofmap[:, 3, new_cells[2]]
new_f.data[new_gdofs] .= avg_data
end
function prolong!(new_f, old_cell, new_cells, ::DP0)
old_f = new_f
# simply copy over the data
old_gdofs = old_f.space.dofmap[:, 1, old_cell]
new_gdofs = new_f.space.dofmap[:, 1, new_cells[1]]
new_f.data[new_gdofs] .= old_f.data[old_gdofs]
new_gdofs = new_f.space.dofmap[:, 1, new_cells[2]]
new_f.data[new_gdofs] .= old_f.data[old_gdofs]
end
#function cell_contains(mesh, cell, v)
# geo = mesh.vertices[:, mesh.cells[:, cell]]
# J = jacobian(x -> geo * [1 - x[1] - x[2], x[1], x[2]], [0., 0.])
......@@ -254,3 +382,33 @@ function save(filename::String, mesh::Mesh, fs...)
end
vtk_save(vtk)
end
function elmap(mesh, cell)
A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}(
view(mesh.vertices, :, view(mesh.cells, :, cell)))
return x -> A * SA[1 - x[1] - x[2], x[1], x[2]]
end
function test_mesh(mesh)
# assemble edge -> cells map
edgemap = Dict{NTuple{2, Int}, Vector{Int}}()
for cell in cells(mesh)
vs = sort(SVector(vertices(mesh, cell)))
e1 = (vs[1], vs[2])
e2 = (vs[1], vs[3])
e3 = (vs[2], vs[3])
edgemap[e1] = push!(get!(edgemap, e1, []), cell)
edgemap[e2] = push!(get!(edgemap, e2, []), cell)
edgemap[e3] = push!(get!(edgemap, e3, []), cell)
end
for (edge, cells) in edgemap
@assert(length(cells) <= 2)
end
# are cells positively oriented?
for cell in cells(mesh)
delmap = jacobian(elmap(mesh, cell), SA[0., 0.])
@assert(det(delmap) > 0)
end
end
......@@ -203,6 +203,37 @@ function estimate!(ctx::L1L2TVContext)
interpolate!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata)
end
function refine(ctx::L1L2TVContext, marked_cells)
hmesh = HMesh(ctx.mesh)
refined_functions = refine!(hmesh, Set(marked_cells);
ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2)
new_mesh = refined_functions.u.space.mesh
new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S,
ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2)
@assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap)
@assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap)
@assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap)
@assert(new_ctx.p1.space.dofmap == refined_functions.p1.space.dofmap)
@assert(new_ctx.p2.space.dofmap == refined_functions.p2.space.dofmap)
@assert(new_ctx.du.space.dofmap == refined_functions.du.space.dofmap)
@assert(new_ctx.dp1.space.dofmap == refined_functions.dp1.space.dofmap)
@assert(new_ctx.dp2.space.dofmap == refined_functions.dp2.space.dofmap)
new_ctx.est.data .= refined_functions.est.data
new_ctx.g.data .= refined_functions.g.data
new_ctx.u.data .= refined_functions.u.data
new_ctx.p1.data .= refined_functions.p1.data
new_ctx.p2.data .= refined_functions.p2.data
new_ctx.du.data .= refined_functions.du.data
new_ctx.dp1.data .= refined_functions.dp1.data
new_ctx.dp2.data .= refined_functions.dp2.data
return new_ctx
end
function save(ctx::L1L2TVContext, filename, fs...)
print("save ... ")
vtk = vtk_mesh(filename, ctx.mesh)
......@@ -211,6 +242,69 @@ function save(ctx::L1L2TVContext, filename, fs...)
return vtk
end
function test_refine(mesh = nothing)
m = 1
if isnothing(mesh)
mesh = init_grid(2, 2)
end
T(tdata, u) = u
S(u, nablau) = u
ctx = L1L2TVContext("test", mesh, m; T, tdata = nothing, S,
alpha1=0., alpha2=5., lambda=0.1, beta=1e-3, gamma1=1e-3, gamma2=1e-3)
interpolate!(ctx.g, x -> x[1] + x[2] > 1.01)
mysave(ctx, i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
pvd = paraview_collection("$(ctx.name).pvd")
for i in 1:10
step!(ctx)
end
pvd[0] = mysave(ctx, 0)
ctx2 = refine(ctx, [1, 2, 3, 4, 5, 6, 7, 8])
#interpolate!(ctx2.g, x -> x[1] + x[2] < 2.01)
#interpolate!(ctx2.g, x -> norm(x .- m) < norm(m) / 3)
#ctx2.est.data .= 0
#ctx2.g.data .= 0
#ctx2.u.data .= 0
#ctx2.p1.data .= 0
#ctx2.p2.data .= 0
#ctx2.du.data .= 0
#ctx2.dp1.data .= 0
#ctx2.dp2.data .= 0
for i in 1:10
step!(ctx2)
end
pvd[1] = mysave(ctx2, 1)
test_mesh(ctx2.mesh)
return ctx, ctx2
for i in 1:5
step!(ctx)
estimate!(ctx)
pvd[i] = mysave(i)
println()
end
estmean = mean(ctx.est.data)
marked_cells = findall(>(2*estmean), ctx.est.data)
println(marked_cells)
ctx = refine(ctx, marked_cells)
interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
for i in 6:11
step!(ctx)
estimate!(ctx)
pvd[i] = mysave(i)
println()
end
return ctx
end
function denoise(img; name, params...)
m = 1
mesh = init_grid(img; type=:vertex)
......@@ -236,6 +330,20 @@ function denoise(img; name, params...)
pvd[i] = save_denoise(i)
println()
end
estmean = mean(ctx.est.data)
marked_cells = findall(>(2*estmean), ctx.est.data)
test_mesh(ctx.mesh)
println(marked_cells)
ctx = refine(ctx, marked_cells)
test_mesh(ctx.mesh)
interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
for i in 6:20
step!(ctx)
estimate!(ctx)
pvd[i] = save_denoise(i)
println()
end
return ctx
end
......@@ -266,7 +374,7 @@ function inpaint(img, imgmask; name, params...)
pvd = paraview_collection("$(ctx.name).pvd")
pvd[0] = save_inpaint(0)
for i in 1:20
for i in 1:3
step!(ctx)
estimate!(ctx)
pvd[i] = save_inpaint(i)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment