Skip to content
Snippets Groups Projects
Commit 65804157 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

get refinement working somewhat

parent 6d2a2f59
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,8 @@ using OpticalFlowUtils
using WriteVTK: paraview_collection
using SemiSmoothNewton
using SemiSmoothNewton: project_img!, project!
using SemiSmoothNewton: HMesh, ncells, refine
using SemiSmoothNewton: project_img!, project_img2!, project!
using SemiSmoothNewton: vtk_mesh, vtk_append!, vtk_save
include("util.jl")
......@@ -63,8 +64,6 @@ struct L1L2TVContext{M, Ttype, Stype}
du::FeFunction
dp1::FeFunction
dp2::FeFunction
nablau
nabladu
end
function L1L2TVContext(name, mesh, m; T, tdata, S,
......@@ -82,11 +81,9 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
u = FeFunction(Vu, name="u")
p1 = FeFunction(Vp1, name="p1")
p2 = FeFunction(Vp2, name="p2")
du = FeFunction(Vu)
dp1 = FeFunction(Vp1)
dp2 = FeFunction(Vp2)
nablau = nabla(u)
nabladu = nabla(du)
du = FeFunction(Vu; name = "du")
dp1 = FeFunction(Vp1; name = "dp1")
dp2 = FeFunction(Vp2; name = "dp2")
est.data .= 0
g.data .= 0
......@@ -99,7 +96,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
return L1L2TVContext(name, mesh, d, m, T, tdata, S,
alpha1, alpha2, beta, lambda, gamma1, gamma2,
est, g, u, p1, p2, du, dp1, dp2, nablau, nabladu)
est, g, u, p1, p2, du, dp1, dp2)
end
function p1_project!(p1, alpha1)
......@@ -166,7 +163,7 @@ function step!(ctx::L1L2TVContext)
# solve du
print("assemble ... ")
A, b = assemble(ctx.du.space, du_a, du_l;
ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata)
ctx.g, ctx.u, nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata)
print("solve ... ")
ctx.du.data .= A \ b
......@@ -191,7 +188,7 @@ function step!(ctx::L1L2TVContext)
return -p2 + lambda / m2 * (nablau + nabladu) - cond
end
interpolate!(ctx.dp2, dp2_update;
ctx.u, ctx.nablau, ctx.p2, ctx.du, ctx.nabladu)
ctx.u, nablau = nabla(ctx.u), ctx.p2, ctx.du, nabladu = nabla(ctx.du))
# newton update
theta = 1.
......@@ -388,46 +385,49 @@ function estimate!(ctx::L1L2TVContext)
end
w = FeFunction(ctx.u.space)
nablaw = nabla(w)
solve_primal!(w, ctx)
project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata)
end
function refine(ctx::L1L2TVContext, marked_cells; fs_...)
fs = NamedTuple(fs_)
hmesh = HMesh(ctx.mesh)
refined_functions = refine!(hmesh, Set(marked_cells);
ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2,
fs...)
new_mesh = refined_functions.u.space.mesh
new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S,
ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2)
fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs))
@assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap)
@assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap)
@assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap)
@assert(new_ctx.p1.space.dofmap == refined_functions.p1.space.dofmap)
@assert(new_ctx.p2.space.dofmap == refined_functions.p2.space.dofmap)
@assert(new_ctx.du.space.dofmap == refined_functions.du.space.dofmap)
@assert(new_ctx.dp1.space.dofmap == refined_functions.dp1.space.dofmap)
@assert(new_ctx.dp2.space.dofmap == refined_functions.dp2.space.dofmap)
new_ctx.est.data .= refined_functions.est.data
new_ctx.g.data .= refined_functions.g.data
new_ctx.u.data .= refined_functions.u.data
new_ctx.p1.data .= refined_functions.p1.data
new_ctx.p2.data .= refined_functions.p2.data
new_ctx.du.data .= refined_functions.du.data
new_ctx.dp1.data .= refined_functions.dp1.data
new_ctx.dp2.data .= refined_functions.dp2.data
return new_ctx, fs_new
end
project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2,
nablau = nabla(ctx.u), w, nablaw = nabla(w), ctx.tdata)
end
# TODO: deprecate in favor of refine(mesh, marked_cells; fs...)
#function refine(ctx::L1L2TVContext, marked_cells; fs_...)
# fs = NamedTuple(fs_)
#
# hmesh = HMesh(ctx.mesh)
# refined_functions = refine!(hmesh, Set(marked_cells);
# ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2,
# fs...)
# new_mesh = refined_functions.u.space.mesh
#
# # TODO: tdata needs to be recreated for refinement
# new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S,
# ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2)
#
# fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs))
#
# @assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap)
# @assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap)
# @assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap)
# @assert(new_ctx.p1.space.dofmap == refined_functions.p1.space.dofmap)
# @assert(new_ctx.p2.space.dofmap == refined_functions.p2.space.dofmap)
# @assert(new_ctx.du.space.dofmap == refined_functions.du.space.dofmap)
# @assert(new_ctx.dp1.space.dofmap == refined_functions.dp1.space.dofmap)
# @assert(new_ctx.dp2.space.dofmap == refined_functions.dp2.space.dofmap)
#
# new_ctx.est.data .= refined_functions.est.data
# new_ctx.g.data .= refined_functions.g.data
# new_ctx.u.data .= refined_functions.u.data
# new_ctx.p1.data .= refined_functions.p1.data
# new_ctx.p2.data .= refined_functions.p2.data
# new_ctx.du.data .= refined_functions.du.data
# new_ctx.dp1.data .= refined_functions.dp1.data
# new_ctx.dp2.data .= refined_functions.dp2.data
#
# return new_ctx, fs_new
#end
# minimal Dörfler marking
function mark(ctx::L1L2TVContext; theta=0.5)
n = ncells(ctx.mesh)
esttotal = sum(ctx.est.data)
......@@ -461,7 +461,8 @@ function primal_energy(ctx::L1L2TVContext)
ctx.beta / 2 * norm(ctx.S(u, nablau))^2 +
ctx.lambda * huber(norm(nablau), ctx.gamma2)
end
return integrate(ctx.mesh, integrand; ctx.g, ctx.u, ctx.nablau, ctx.tdata)
return integrate(ctx.mesh, integrand; ctx.g, ctx.u,
nablau = nabla(ctx.u), ctx.tdata)
end
norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
......@@ -482,7 +483,8 @@ function norm_residual(ctx::L1L2TVContext)
ctx.lambda * nablau
return norm(p1part)^2 + norm(p2part)^2
end
ppart2 = integrate(ctx.mesh, integrand; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata)
ppart2 = integrate(ctx.mesh, integrand; ctx.g, ctx.u,
nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata)
return sqrt(upart2 + ppart2)
end
......@@ -694,7 +696,7 @@ function optflow(ctx)
m = 2
#mesh = init_grid(imgf0; type=:vertex)
mesh = init_grid(imgf0, 20, 20)
mesh = init_grid(imgf0, 1, 1)
#mesh = init_grid(imgf0)
# optflow specific stuff
......@@ -702,12 +704,12 @@ function optflow(ctx)
f0 = FeFunction(Vg, name="f0")
f1 = FeFunction(Vg, name="f1")
fw = FeFunction(Vg, name="fw")
nablafw = nabla(fw)
T(tdata, u) = tdata * u # tdata = nablafw
S(u, nablau) = nablau
#S(u, nablau) = nablau
S(u, nablau) = u
st = L1L2TVContext("run", mesh, m; T, tdata = nablafw, S,
st = L1L2TVContext("run", mesh, m; T, tdata = nabla(fw), S,
ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta,
ctx.params.gamma1, ctx.params.gamma2)
......@@ -720,7 +722,7 @@ function optflow(ctx)
g_optflow(x; u, f0, fw, nablafw) =
nablafw * u - (fw - f0)
interpolate!(st.g, g_optflow; st.u, f0, fw, nablafw)
interpolate!(st.g, g_optflow; st.u, f0, fw, nablafw = nabla(fw))
end
reproject!()
......@@ -728,15 +730,33 @@ function optflow(ctx)
output(st, joinpath(ctx.outdir, "output_$(lpad(i, 5, '0')).vtu"),
st.g, st.u, st.p1, st.p2, st.est, f0, f1, fw)
pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd"))
pvd[0] = save_step(0)
for i in 1:10
i = 0
pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd
pvd[i] = save_step(i)
while true
for k in 1:10
i += 1
step!(st)
estimate!(st)
pvd[i] = save_step(i)
println()
end
vtk_save(pvd)
marked_cells = mark(st; theta = 0.5)
println("refining ...")
mesh, fs = refine(mesh, marked_cells;
st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2,
f0, f1, fw)
st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fs.fw), S,
st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
f0, f1, fw = (fs.f0, fs.f1, fs.fw)
println("reprojecting ...")
reproject!()
end
end
display(plot(colorflow(to_img(sample(st.u)))))
#CSV.write(joinpath(ctx.outdir, "energies.csv"), df)
......
export init_grid, init_hgrid, save, refine!, cells, vertices,
export init_grid, init_hgrid, save, refine!, refine, cells, vertices,
ndims_domain, ndims_space
using LinearAlgebra: norm
......@@ -286,7 +286,7 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...)
# extended functions onto newly created cells
extended_fs = map(NamedTuple(fs)) do f
space = FeSpace(extended_mesh, f.space.element, f.space.size)
return FeFunction(space)
return FeFunction(space; f.name)
end
# copy over previous data for unmodified cells
for (f, extended_f) in zip(NamedTuple(fs), extended_fs)
......@@ -302,7 +302,7 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...)
# retain only non-refined cells
new_fs = map(NamedTuple(extended_fs)) do f
space = FeSpace(new_mesh, f.space.element, f.space.size)
return FeFunction(space)
return FeFunction(space; f.name)
end
retained_cells = setdiff(cells(extended_mesh), removed_cells)
@assert(retained_cells == cells(hmesh))
......@@ -316,6 +316,24 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...)
return new_fs
end
"refine by creating temporary hierarchical mesh on the fly"
function refine(mesh::Mesh, marked_cells; fs...)
hmesh = HMesh(mesh)
fs_new = refine!(hmesh, Set(marked_cells); fs...)
mesh_new = sub_mesh(hmesh)
return mesh_new, fs_new
end
function geo_tolocal(A, v)
J = jacobian(x -> A * [1 - x[1] - x[2], x[1], x[2]], [0., 0.])
return J \ (v - A[:, 1])
end
function geo_contains(A, v)
J = jacobian(x -> A * [1 - x[1] - x[2], x[1], x[2]], [0., 0.])
λ = J \ (v - A[:, 1])
return all(λ .>= 0) && sum(λ) <= 1
end
#function cell_contains(mesh, cell, v)
# geo = mesh.vertices[:, mesh.cells[:, cell]]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment