Skip to content
Snippets Groups Projects
Commit 65804157 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

get refinement working somewhat

parent 6d2a2f59
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,8 @@ using OpticalFlowUtils ...@@ -9,7 +9,8 @@ using OpticalFlowUtils
using WriteVTK: paraview_collection using WriteVTK: paraview_collection
using SemiSmoothNewton using SemiSmoothNewton
using SemiSmoothNewton: project_img!, project! using SemiSmoothNewton: HMesh, ncells, refine
using SemiSmoothNewton: project_img!, project_img2!, project!
using SemiSmoothNewton: vtk_mesh, vtk_append!, vtk_save using SemiSmoothNewton: vtk_mesh, vtk_append!, vtk_save
include("util.jl") include("util.jl")
...@@ -63,8 +64,6 @@ struct L1L2TVContext{M, Ttype, Stype} ...@@ -63,8 +64,6 @@ struct L1L2TVContext{M, Ttype, Stype}
du::FeFunction du::FeFunction
dp1::FeFunction dp1::FeFunction
dp2::FeFunction dp2::FeFunction
nablau
nabladu
end end
function L1L2TVContext(name, mesh, m; T, tdata, S, function L1L2TVContext(name, mesh, m; T, tdata, S,
...@@ -82,11 +81,9 @@ function L1L2TVContext(name, mesh, m; T, tdata, S, ...@@ -82,11 +81,9 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
u = FeFunction(Vu, name="u") u = FeFunction(Vu, name="u")
p1 = FeFunction(Vp1, name="p1") p1 = FeFunction(Vp1, name="p1")
p2 = FeFunction(Vp2, name="p2") p2 = FeFunction(Vp2, name="p2")
du = FeFunction(Vu) du = FeFunction(Vu; name = "du")
dp1 = FeFunction(Vp1) dp1 = FeFunction(Vp1; name = "dp1")
dp2 = FeFunction(Vp2) dp2 = FeFunction(Vp2; name = "dp2")
nablau = nabla(u)
nabladu = nabla(du)
est.data .= 0 est.data .= 0
g.data .= 0 g.data .= 0
...@@ -99,7 +96,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S, ...@@ -99,7 +96,7 @@ function L1L2TVContext(name, mesh, m; T, tdata, S,
return L1L2TVContext(name, mesh, d, m, T, tdata, S, return L1L2TVContext(name, mesh, d, m, T, tdata, S,
alpha1, alpha2, beta, lambda, gamma1, gamma2, alpha1, alpha2, beta, lambda, gamma1, gamma2,
est, g, u, p1, p2, du, dp1, dp2, nablau, nabladu) est, g, u, p1, p2, du, dp1, dp2)
end end
function p1_project!(p1, alpha1) function p1_project!(p1, alpha1)
...@@ -166,7 +163,7 @@ function step!(ctx::L1L2TVContext) ...@@ -166,7 +163,7 @@ function step!(ctx::L1L2TVContext)
# solve du # solve du
print("assemble ... ") print("assemble ... ")
A, b = assemble(ctx.du.space, du_a, du_l; A, b = assemble(ctx.du.space, du_a, du_l;
ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata) ctx.g, ctx.u, nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata)
print("solve ... ") print("solve ... ")
ctx.du.data .= A \ b ctx.du.data .= A \ b
...@@ -191,7 +188,7 @@ function step!(ctx::L1L2TVContext) ...@@ -191,7 +188,7 @@ function step!(ctx::L1L2TVContext)
return -p2 + lambda / m2 * (nablau + nabladu) - cond return -p2 + lambda / m2 * (nablau + nabladu) - cond
end end
interpolate!(ctx.dp2, dp2_update; interpolate!(ctx.dp2, dp2_update;
ctx.u, ctx.nablau, ctx.p2, ctx.du, ctx.nabladu) ctx.u, nablau = nabla(ctx.u), ctx.p2, ctx.du, nabladu = nabla(ctx.du))
# newton update # newton update
theta = 1. theta = 1.
...@@ -388,46 +385,49 @@ function estimate!(ctx::L1L2TVContext) ...@@ -388,46 +385,49 @@ function estimate!(ctx::L1L2TVContext)
end end
w = FeFunction(ctx.u.space) w = FeFunction(ctx.u.space)
nablaw = nabla(w)
solve_primal!(w, ctx) solve_primal!(w, ctx)
project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.nablau, w, nablaw, ctx.tdata) project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2,
end nablau = nabla(ctx.u), w, nablaw = nabla(w), ctx.tdata)
end
function refine(ctx::L1L2TVContext, marked_cells; fs_...)
fs = NamedTuple(fs_) # TODO: deprecate in favor of refine(mesh, marked_cells; fs...)
#function refine(ctx::L1L2TVContext, marked_cells; fs_...)
hmesh = HMesh(ctx.mesh) # fs = NamedTuple(fs_)
refined_functions = refine!(hmesh, Set(marked_cells); #
ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2, # hmesh = HMesh(ctx.mesh)
fs...) # refined_functions = refine!(hmesh, Set(marked_cells);
new_mesh = refined_functions.u.space.mesh # ctx.est, ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.du, ctx.dp1, ctx.dp2,
# fs...)
new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S, # new_mesh = refined_functions.u.space.mesh
ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2) #
# # TODO: tdata needs to be recreated for refinement
fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs)) # new_ctx = L1L2TVContext(ctx.name, new_mesh, ctx.m; ctx.T, ctx.tdata, ctx.S,
# ctx.alpha1, ctx.alpha2, ctx.beta, ctx.lambda, ctx.gamma1, ctx.gamma2)
@assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap) #
@assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap) # fs_new = NamedTuple(x[1] => refined_functions[x[1]] for x in pairs(fs))
@assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap) #
@assert(new_ctx.p1.space.dofmap == refined_functions.p1.space.dofmap) # @assert(new_ctx.est.space.dofmap == refined_functions.est.space.dofmap)
@assert(new_ctx.p2.space.dofmap == refined_functions.p2.space.dofmap) # @assert(new_ctx.g.space.dofmap == refined_functions.g.space.dofmap)
@assert(new_ctx.du.space.dofmap == refined_functions.du.space.dofmap) # @assert(new_ctx.u.space.dofmap == refined_functions.u.space.dofmap)
@assert(new_ctx.dp1.space.dofmap == refined_functions.dp1.space.dofmap) # @assert(new_ctx.p1.space.dofmap == refined_functions.p1.space.dofmap)
@assert(new_ctx.dp2.space.dofmap == refined_functions.dp2.space.dofmap) # @assert(new_ctx.p2.space.dofmap == refined_functions.p2.space.dofmap)
# @assert(new_ctx.du.space.dofmap == refined_functions.du.space.dofmap)
new_ctx.est.data .= refined_functions.est.data # @assert(new_ctx.dp1.space.dofmap == refined_functions.dp1.space.dofmap)
new_ctx.g.data .= refined_functions.g.data # @assert(new_ctx.dp2.space.dofmap == refined_functions.dp2.space.dofmap)
new_ctx.u.data .= refined_functions.u.data #
new_ctx.p1.data .= refined_functions.p1.data # new_ctx.est.data .= refined_functions.est.data
new_ctx.p2.data .= refined_functions.p2.data # new_ctx.g.data .= refined_functions.g.data
new_ctx.du.data .= refined_functions.du.data # new_ctx.u.data .= refined_functions.u.data
new_ctx.dp1.data .= refined_functions.dp1.data # new_ctx.p1.data .= refined_functions.p1.data
new_ctx.dp2.data .= refined_functions.dp2.data # new_ctx.p2.data .= refined_functions.p2.data
# new_ctx.du.data .= refined_functions.du.data
return new_ctx, fs_new # new_ctx.dp1.data .= refined_functions.dp1.data
end # new_ctx.dp2.data .= refined_functions.dp2.data
#
# return new_ctx, fs_new
#end
# minimal Dörfler marking
function mark(ctx::L1L2TVContext; theta=0.5) function mark(ctx::L1L2TVContext; theta=0.5)
n = ncells(ctx.mesh) n = ncells(ctx.mesh)
esttotal = sum(ctx.est.data) esttotal = sum(ctx.est.data)
...@@ -461,7 +461,8 @@ function primal_energy(ctx::L1L2TVContext) ...@@ -461,7 +461,8 @@ function primal_energy(ctx::L1L2TVContext)
ctx.beta / 2 * norm(ctx.S(u, nablau))^2 + ctx.beta / 2 * norm(ctx.S(u, nablau))^2 +
ctx.lambda * huber(norm(nablau), ctx.gamma2) ctx.lambda * huber(norm(nablau), ctx.gamma2)
end end
return integrate(ctx.mesh, integrand; ctx.g, ctx.u, ctx.nablau, ctx.tdata) return integrate(ctx.mesh, integrand; ctx.g, ctx.u,
nablau = nabla(ctx.u), ctx.tdata)
end end
norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f)) norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
...@@ -482,7 +483,8 @@ function norm_residual(ctx::L1L2TVContext) ...@@ -482,7 +483,8 @@ function norm_residual(ctx::L1L2TVContext)
ctx.lambda * nablau ctx.lambda * nablau
return norm(p1part)^2 + norm(p2part)^2 return norm(p1part)^2 + norm(p2part)^2
end end
ppart2 = integrate(ctx.mesh, integrand; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata) ppart2 = integrate(ctx.mesh, integrand; ctx.g, ctx.u,
nablau = nabla(ctx.u), ctx.p1, ctx.p2, ctx.tdata)
return sqrt(upart2 + ppart2) return sqrt(upart2 + ppart2)
end end
...@@ -694,7 +696,7 @@ function optflow(ctx) ...@@ -694,7 +696,7 @@ function optflow(ctx)
m = 2 m = 2
#mesh = init_grid(imgf0; type=:vertex) #mesh = init_grid(imgf0; type=:vertex)
mesh = init_grid(imgf0, 20, 20) mesh = init_grid(imgf0, 1, 1)
#mesh = init_grid(imgf0) #mesh = init_grid(imgf0)
# optflow specific stuff # optflow specific stuff
...@@ -702,12 +704,12 @@ function optflow(ctx) ...@@ -702,12 +704,12 @@ function optflow(ctx)
f0 = FeFunction(Vg, name="f0") f0 = FeFunction(Vg, name="f0")
f1 = FeFunction(Vg, name="f1") f1 = FeFunction(Vg, name="f1")
fw = FeFunction(Vg, name="fw") fw = FeFunction(Vg, name="fw")
nablafw = nabla(fw)
T(tdata, u) = tdata * u # tdata = nablafw T(tdata, u) = tdata * u # tdata = nablafw
S(u, nablau) = nablau #S(u, nablau) = nablau
S(u, nablau) = u
st = L1L2TVContext("run", mesh, m; T, tdata = nablafw, S, st = L1L2TVContext("run", mesh, m; T, tdata = nabla(fw), S,
ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta, ctx.params.alpha1, ctx.params.alpha2, ctx.params.lambda, ctx.params.beta,
ctx.params.gamma1, ctx.params.gamma2) ctx.params.gamma1, ctx.params.gamma2)
...@@ -720,7 +722,7 @@ function optflow(ctx) ...@@ -720,7 +722,7 @@ function optflow(ctx)
g_optflow(x; u, f0, fw, nablafw) = g_optflow(x; u, f0, fw, nablafw) =
nablafw * u - (fw - f0) nablafw * u - (fw - f0)
interpolate!(st.g, g_optflow; st.u, f0, fw, nablafw) interpolate!(st.g, g_optflow; st.u, f0, fw, nablafw = nabla(fw))
end end
reproject!() reproject!()
...@@ -728,15 +730,33 @@ function optflow(ctx) ...@@ -728,15 +730,33 @@ function optflow(ctx)
output(st, joinpath(ctx.outdir, "output_$(lpad(i, 5, '0')).vtu"), output(st, joinpath(ctx.outdir, "output_$(lpad(i, 5, '0')).vtu"),
st.g, st.u, st.p1, st.p2, st.est, f0, f1, fw) st.g, st.u, st.p1, st.p2, st.est, f0, f1, fw)
pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd")) i = 0
pvd[0] = save_step(0) pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd
for i in 1:10 pvd[i] = save_step(i)
while true
for k in 1:10
i += 1
step!(st) step!(st)
estimate!(st) estimate!(st)
pvd[i] = save_step(i) pvd[i] = save_step(i)
println() println()
end end
vtk_save(pvd) marked_cells = mark(st; theta = 0.5)
println("refining ...")
mesh, fs = refine(mesh, marked_cells;
st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2,
f0, f1, fw)
st = L1L2TVContext("run", mesh, st.d, st.m, T, nabla(fs.fw), S,
st.alpha1, st.alpha2, st.beta, st.lambda, st.gamma1, st.gamma2,
fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
f0, f1, fw = (fs.f0, fs.f1, fs.fw)
println("reprojecting ...")
reproject!()
end
end
display(plot(colorflow(to_img(sample(st.u))))) display(plot(colorflow(to_img(sample(st.u)))))
#CSV.write(joinpath(ctx.outdir, "energies.csv"), df) #CSV.write(joinpath(ctx.outdir, "energies.csv"), df)
......
export init_grid, init_hgrid, save, refine!, cells, vertices, export init_grid, init_hgrid, save, refine!, refine, cells, vertices,
ndims_domain, ndims_space ndims_domain, ndims_space
using LinearAlgebra: norm using LinearAlgebra: norm
...@@ -286,7 +286,7 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...) ...@@ -286,7 +286,7 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...)
# extended functions onto newly created cells # extended functions onto newly created cells
extended_fs = map(NamedTuple(fs)) do f extended_fs = map(NamedTuple(fs)) do f
space = FeSpace(extended_mesh, f.space.element, f.space.size) space = FeSpace(extended_mesh, f.space.element, f.space.size)
return FeFunction(space) return FeFunction(space; f.name)
end end
# copy over previous data for unmodified cells # copy over previous data for unmodified cells
for (f, extended_f) in zip(NamedTuple(fs), extended_fs) for (f, extended_f) in zip(NamedTuple(fs), extended_fs)
...@@ -302,7 +302,7 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...) ...@@ -302,7 +302,7 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...)
# retain only non-refined cells # retain only non-refined cells
new_fs = map(NamedTuple(extended_fs)) do f new_fs = map(NamedTuple(extended_fs)) do f
space = FeSpace(new_mesh, f.space.element, f.space.size) space = FeSpace(new_mesh, f.space.element, f.space.size)
return FeFunction(space) return FeFunction(space; f.name)
end end
retained_cells = setdiff(cells(extended_mesh), removed_cells) retained_cells = setdiff(cells(extended_mesh), removed_cells)
@assert(retained_cells == cells(hmesh)) @assert(retained_cells == cells(hmesh))
...@@ -316,6 +316,24 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...) ...@@ -316,6 +316,24 @@ function refine!(hmesh::HMesh, marked_cells::Set; fs...)
return new_fs return new_fs
end end
"refine by creating temporary hierarchical mesh on the fly"
function refine(mesh::Mesh, marked_cells; fs...)
hmesh = HMesh(mesh)
fs_new = refine!(hmesh, Set(marked_cells); fs...)
mesh_new = sub_mesh(hmesh)
return mesh_new, fs_new
end
function geo_tolocal(A, v)
J = jacobian(x -> A * [1 - x[1] - x[2], x[1], x[2]], [0., 0.])
return J \ (v - A[:, 1])
end
function geo_contains(A, v)
J = jacobian(x -> A * [1 - x[1] - x[2], x[1], x[2]], [0., 0.])
λ = J \ (v - A[:, 1])
return all(λ .>= 0) && sum(λ) <= 1
end
#function cell_contains(mesh, cell, v) #function cell_contains(mesh, cell, v)
# geo = mesh.vertices[:, mesh.cells[:, cell]] # geo = mesh.vertices[:, mesh.cells[:, cell]]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment