Skip to content
Snippets Groups Projects
Commit 8297254a authored by Stephan Hilb's avatar Stephan Hilb
Browse files

add inpaint example

parent 7976b7f2
No related branches found
No related tags found
No related merge requests found
using LinearAlgebra: I, det, dot, norm, normalize using LinearAlgebra: I, det, dot, norm, normalize
using SparseArrays: sparse using SparseArrays: sparse, ishermitian
using Statistics: mean using Statistics: mean
using Colors: Gray using Colors: Gray
...@@ -322,6 +322,8 @@ function step!(st::L1L2TVState) ...@@ -322,6 +322,8 @@ function step!(st::L1L2TVState)
print("solve ... ") print("solve ... ")
A = Au + Ap A = Au + Ap
b = bu + bp b = bu + bp
#println(norm(A - A'))
#println(ishermitian(A))
#A, b = assemble(st.du.space, du_a, du_l; #A, b = assemble(st.du.space, du_a, du_l;
# st.g, st.u, nablau = nabla(st.u), st.p1, st.p2, st.tdata) # st.g, st.u, nablau = nabla(st.u), st.p1, st.p2, st.tdata)
st.du.data .= A \ b st.du.data .= A \ b
...@@ -806,7 +808,7 @@ function denoise(ctx) ...@@ -806,7 +808,7 @@ function denoise(ctx)
project_image! = project_l2_lagrange! project_image! = project_l2_lagrange!
eps_newton = 1e-5 # cauchy criterion for inner newton loop eps_newton = 1e-5 # cauchy criterion for inner newton loop
n_refine = 6 n_refine = 5
# convert to cartesian coordinates # convert to cartesian coordinates
g_arr = from_img(ctx.params.g_arr) g_arr = from_img(ctx.params.g_arr)
...@@ -891,7 +893,7 @@ function experiment_denoise(ctx) ...@@ -891,7 +893,7 @@ function experiment_denoise(ctx)
df = DataFrame() df = DataFrame()
denoise(Util.Context(ctx; name = "test", df, denoise(Util.Context(ctx; name = "test", df,
g_arr, mesh, g_arr, mesh,
alpha1 = 0., alpha2 = 30., lambda = 1., beta = 1e-5, alpha1 = 0., alpha2 = 50., lambda = 1., beta = 1e-5,
gamma1 = 1e-4, gamma2 = 1e-4, gamma1 = 1e-4, gamma2 = 1e-4,
eps_newton = 1e-5, adaptive = true, eps_newton = 1e-5, adaptive = true,
)) ))
...@@ -1246,46 +1248,119 @@ function experiment_approximation(ctx) ...@@ -1246,46 +1248,119 @@ function experiment_approximation(ctx)
)) ))
end end
function inpaint(img, imgmask; params...) # TODO: deduplicate, cf. optflow()
size(img) == size(imgmask) || function inpaint(ctx)
throw(ArgumentError("non-matching dimensions")) # expect ctx.params.g_arr
m = 1 project_image! = project_qi_lagrange!
img = from_img(img) # coord flip n_refine = 5
imgmask = from_img(imgmask) # coord flip
mesh = init_grid(img; type=:vertex) # convert to cartesian coordinates
g_arr = from_img(ctx.params.g_arr)
mask_arr = from_img(ctx.params.mask_arr)
mesh = init_grid(g_arr, floor.(Int, size(g_arr) ./ 2^(n_refine / 2))...)
mesh_area = area(mesh)
# inpaint specific stuff # inpaint specific stuff
Vg = FeSpace(mesh, P1(), (1,)) Vg = FeSpace(mesh, P1(), (1,))
mask = FeFunction(Vg, name = "mask") _mask = FeFunction(Vg, name = "mask")
T(tdata, u) = isone(tdata[begin]) ? u : zero(u) T(tdata, u) = abs(tdata[begin] - 1.) < 1e-8 ? u : zero(u)
T(::typeof(adjoint), tdata, v) = T(tdata, v)
S(u, nablau) = u S(u, nablau) = u
st = L1L2TVState{m}(mesh; T, tdata = mask, S, params...) st = L1L2TVState{1}(mesh;
T, tdata = _mask, S,
ctx.params.alpha1, ctx.params.alpha2,
ctx.params.lambda, ctx.params.beta,
ctx.params.gamma1, ctx.params.gamma2)
function interpolate_image_data!()
println("interpolate image data ...")
project_image!(st.g, g_arr)
project_image!(st.tdata, mask_arr)
end
save_step(i) =
output(st, joinpath(ctx.outdir, "output_$(lpad(i, 5, '0')).vtu"),
st.g, st.u, st.p1, st.p2, st.est)
# FIXME: currently dual grid only pvd = paraview_collection(joinpath(ctx.outdir, "output.pvd")) do pvd
interpolate!(mask, x -> imgmask[round.(Int, x)...])
#interpolate!(mask, x -> abs(x[2] - 0.5) > 0.1)
interpolate!(st.g, x -> imgmask[round.(Int, x)...] ? img[round.(Int, x)...] : 0.)
m = (size(img) .- 1) ./ 2 .+ 1
interpolate!(st.g, x -> norm(x .- m) < norm(m) / 3)
save_inpaint(i) = interpolate_image_data!()
output(st, "output/$(st.name)_$(lpad(i, 5, '0')).vtu", pvd[0] = save_step(0)
st.g, st.u, st.p1, st.p2, st.est, mask)
pvd = paraview_collection("output/$(st.name).pvd") i = 0
pvd[0] = save_inpaint(0) k_newton = 0
for i in 1:3 k_refine = 0
while true
# interior newton
k_newton += 1
step!(st) step!(st)
estimate_pd!(st) norm_step_ = norm_step(st) / sqrt(mesh_area)
pvd[i] = save_inpaint(i) println("norm_step = $norm_step_")
println()
# interior newton stop criterion
norm_step_ > ctx.params.eps_newton && k_newton < 30 && continue
k_newton = 0
# plot
i += 1
display(plot(grayclamp.(to_img(sample(st.u)))))
estimate_res!(st)
pvd[i] = save_step(i)
#break
# refinement stop criterion
k_refine += 1
k_refine > n_refine && break
println("refine ...")
#estimate_res!(st)
marked_cells = Set(mark(st; theta = 0.5))
# manually mark all cell within inpainting domain, since the
# estimator is not reliable there
for cell in cells(mesh)
bind!(st.tdata, cell)
maskv = evaluate(st.tdata, SA[1/3, 1/3])
if abs(maskv[begin] - 1.) > 1e-8
push!(marked_cells, cell)
end end
end
#marked_cells = Set(axes(mesh.cells, 2))
mesh, fs = refine(mesh, marked_cells;
st.est, st.g, st.u, st.p1, st.p2, st.du, st.dp1, st.dp2, st.tdata)
st = L1L2TVState(st; mesh, fs.tdata,
fs.est, fs.g, fs.u, fs.p1, fs.p2, fs.du, fs.dp1, fs.dp2)
interpolate_image_data!()
end
end
#CSV.write(joinpath(ctx.outdir, "energies.csv"), df)
u_sampled = sample(st.u)
saveimg(joinpath(ctx.outdir, "g.png"), to_img(g_arr))
saveimg(joinpath(ctx.outdir, "output.png"), grayclamp.(to_img(u_sampled)))
savedata(joinpath(ctx.outdir, "data.tex");
ctx.params.eps_newton, n_refine,
st.alpha1, st.alpha2, st.lambda, st.beta, st.gamma1, st.gamma2,
width=size(u_sampled, 1), height=size(u_sampled, 2))
return st return st
end end
function experiment_inpaint(ctx)
g_arr = loadimg(joinpath(ctx.indir, "input.png"))
mask_arr = loadimg(joinpath(ctx.indir, "mask.png"))
mesh = init_grid(g_arr;)
df = DataFrame()
inpaint(Util.Context(ctx; name = "test", df,
g_arr, mask_arr, mesh,
alpha1 = 0., alpha2 = 50., lambda = 1., beta = 1e-5,
gamma1 = 1e-4, gamma2 = 1e-4,
eps_newton = 1e-4, adaptive = true,
))
end
function optflow(ctx) function optflow(ctx)
size(ctx.params.imgf0) == size(ctx.params.imgf1) || size(ctx.params.imgf0) == size(ctx.params.imgf1) ||
throw(ArgumentError("non-matching image sizes")) throw(ArgumentError("non-matching image sizes"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment