Skip to content
Snippets Groups Projects
Commit 37c04181 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

refactor code

parent 2e5565de
No related branches found
No related tags found
No related merge requests found
......@@ -78,7 +78,7 @@ end
# evaluate at local point
@inline function evaluate(space::FeSpace, ldofs, xloc)
bv = evaluate_basis(space.element, xloc)
ldofs_ = SArray{Tuple{prod(space.size),ndofs(space.element)}}(ldofs)
ldofs_ = SArray{Tuple{prod(space.size), ndofs(space.element)}}(ldofs)
v = ldofs_ * bv
return SArray{Tuple{space.size...}}(v)
end
......@@ -105,7 +105,8 @@ end
Base.show(io::IO, ::MIME"text/plain", f::FeFunction) =
print("$(nameof(typeof(f))), size $(f.space.size) with $(length(f.data)) dofs")
interpolate!(dst::FeFunction, expr::Function; params...) = interpolate!(dst, dst.space.element, expr; params...)
interpolate!(dst::FeFunction, expr::Function; params...) =
interpolate!(dst, dst.space.element, expr; params...)
myvec(x) = vec(x)
......@@ -121,7 +122,7 @@ function interpolate!(dst::FeFunction, ::P1, expr::Function; params...)
end
for eldof in axes(mesh.cells, 1)
xid = mesh.cells[eldof, cell]
x = mesh.vertices[:, xid]
x = SArray{Tuple{ndims_domain(mesh)}}(mesh.vertices[:, xid])
xloc = SA[0. 1. 0.; 0. 0. 1.][:, eldof]
opvalues = map(f -> evaluate(f, xloc), params)
......@@ -141,7 +142,7 @@ function interpolate!(dst::FeFunction, ::DP0, expr::Function; params...)
bind!(f, cell)
end
vertices = mesh.vertices[:, mesh.cells[:, cell]]
centroid = reshape(mean(vertices, dims = 2), 2)
centroid = SArray{Tuple{ndims_domain(mesh)}}(mean(vertices, dims = 2))
lcentroid = SA[1/3, 1/3]
opvalues = map(f -> evaluate(f, lcentroid), params)
......@@ -162,6 +163,11 @@ end
# evaluate at local point (needs bind! call before)
evaluate(f::FeFunction, x) = evaluate(f.space, f.ldata, x)
# allow any non-function to act as a constant function
bind!(c, cell) = nothing
evaluate(c, xloc) = c
# TODO: inherit from some abstract function type
struct Derivative{F}
f::F
end
......
......@@ -11,7 +11,7 @@ function interpolate_bilinear(img, x)
cornerbool = Bool.(Tuple(idx))
λ = ifelse.(cornerbool, x .- x0, x1 .- x)
corner = ifelse.(cornerbool, x1, x0)
val += prod(λ) * eval_neumann(img, CartesianIndex(corner))
val += prod(λ) * eval_neumann(img, corner)
end
return val
end
......
......@@ -35,7 +35,7 @@ function init_grid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
return Mesh(vertices, cells)
end
init_grid(img::Array{<:Any, 2}, type=:vertex) =
init_grid(img::Array{<:Any, 2}; type=:vertex) =
type == :vertex ?
init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) :
init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5))
......
export myrun
export myrun, denoise, inpaint, optflow
using LinearAlgebra: norm
struct L1L2TVContext{M,Ttype,Stype}
name::String
mesh::M
d::Int # = ndims_domain(mesh)
m::Int
T::Ttype
tdata
S::Stype
alpha1::Float64
alpha2::Float64
beta::Float64
lambda::Float64
gamma1::Float64
gamma2::Float64
g::FeFunction
u::FeFunction
p1::FeFunction
p2::FeFunction
du::FeFunction
dp1::FeFunction
dp2::FeFunction
nablau
nabladu
end
function L1L2TVContext(name, mesh, m; T, tdata, S,
alpha1, alpha2, beta, lambda, gamma1, gamma2)
d = ndims_domain(mesh)
Vg = FeSpace(mesh, P1(), (1,))
Vu = FeSpace(mesh, P1(), (m,))
Vp1 = FeSpace(mesh, DP0(), (1,))
Vp2 = FeSpace(mesh, DP1(), (m, d))
g = FeFunction(Vg, name="g")
u = FeFunction(Vu, name="u")
p1 = FeFunction(Vp1, name="p1")
p2 = FeFunction(Vp2, name="p2")
du = FeFunction(Vu)
dp1 = FeFunction(Vp1)
dp2 = FeFunction(Vp2)
nablau = nabla(u)
nabladu = nabla(du)
g.data .= 0
u.data .= 0
p1.data .= 0
p2.data .= 0
du.data .= 0
dp1.data .= 0
dp2.data .= 0
return L1L2TVContext(name, mesh, d, m, T, tdata, S,
alpha1, alpha2, beta, lambda, gamma1, gamma2,
g, u, p1, p2, du, dp1, dp2, nablau, nabladu)
end
function step!(ctx::L1L2TVContext)
T = ctx.T
S = ctx.S
alpha1 = ctx.alpha1
alpha2 = ctx.alpha2
beta = ctx.beta
lambda = ctx.lambda
gamma1 = ctx.gamma1
gamma2 = ctx.gamma2
function du_a(x, du, nabladu, phi, nablaphi; g, u, nablau, p1, p2, tdata)
m1 = max(gamma1, norm(T(tdata, u) - g))
cond1 = norm(T(tdata, u) - g) > gamma1 ?
dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 :
zeros(size(p1))
a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) -
dot(cond1, T(tdata, phi))
m2 = max(gamma2, norm(nablau))
cond2 = norm(nablau) > gamma2 ?
dot(nablau, nabladu) / norm(nablau)^2 * p2 :
zeros(size(p2))
a2 = lambda / m2 * dot(nabladu, nablaphi) -
dot(cond2, nablaphi)
aB = alpha2 * dot(T(tdata, du), T(tdata, phi)) +
beta * dot(S(du, nabladu), S(phi, nablaphi))
return a1 + a2 + aB
end
function du_l(x, phi, nablaphi; g, u, nablau, tdata)
aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) +
beta * dot(S(u, nablau), S(phi, nablaphi))
m1 = max(gamma1, norm(T(tdata, u) - g))
p1part = alpha1 / m1 * dot(T(tdata, u) - g, T(tdata, phi))
m2 = max(gamma2, norm(nablau))
p2part = lambda / m2 * dot(nablau, nablaphi)
gpart = alpha2 * dot(g, T(tdata, phi))
return -aB - p1part - p2part + gpart
end
# solve du
print("assemble ... ")
A = assemble(ctx.du.space, du_a; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata)
b = assemble_rhs(ctx.du.space, du_l; ctx.g, ctx.u, ctx.nablau, ctx.tdata)
print("solve ... ")
ctx.du.data .= A \ b
# solve dp1
function dp1_update(x; g, u, p1, du, tdata)
m1 = max(gamma1, norm(T(tdata, u) - g))
cond = norm(T(tdata, u) - g) > gamma1 ?
dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 :
zeros(size(p1))
return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond
end
interpolate!(ctx.dp1, dp1_update; ctx.g, ctx.u, ctx.p1, ctx.du, ctx.tdata)
# solve dp2
function dp2_update(x; u, nablau, p2, du, nabladu)
m2 = max(gamma2, norm(nablau))
cond = norm(nablau) > gamma2 ?
dot(nablau, nabladu) / norm(nablau)^2 * p2 :
zeros(size(p2))
return -p2 + lambda / m2 * (nablau + nabladu) - cond
end
interpolate!(ctx.dp2, dp2_update; ctx.u, ctx.nablau, ctx.p2, ctx.du, ctx.nabladu)
# newton update
ctx.u.data .+= ctx.du.data
ctx.p1.data .+= ctx.dp1.data
ctx.p2.data .+= ctx.dp2.data
# reproject p1
function p1_project!(p1, alpha1)
p1.space.element::DP0
p1.data .= clamp.(p1.data, -alpha1, alpha1)
end
p1_project!(ctx.p1, ctx.alpha1)
# reproject p2
function p2_project!(p2, lambda)
p2.space.element::DP1
p2d = reshape(p2.data, prod(p2.space.size), :) # no copy
for i in axes(p2d, 2)
p2in = norm(p2d[:, i])
if p2in > lambda
p2d[:, i] .*= lambda ./ p2in
end
end
end
p2_project!(ctx.p2, ctx.lambda)
end
function save(ctx::L1L2TVContext, filename, fs...)
print("save ... ")
vtk = vtk_mesh(filename, ctx.mesh)
vtk_append!(vtk, fs...)
vtk_save(vtk)
return vtk
end
function denoise(img; name, params...)
m = 1
mesh = init_grid(img; type=:vertex)
T(tdata, u) = u
S(u, nablau) = u
ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...)
interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
save_denoise(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2)
pvd = paraview_collection("$(ctx.name).pvd")
pvd[0] = save_denoise(0)
for i in 1:10
step!(ctx)
pvd[i] = save_denoise(i)
println()
end
end
function inpaint(img, imgmask; name, params...)
size(img) == size(imgmask) ||
throw(ArgumentError("non-matching dimensions"))
m = 1
mesh = init_grid(img; type=:vertex)
# inpaint specific stuff
Vg = FeSpace(mesh, P1(), (1,))
mask = FeFunction(Vg, name="mask")
T(tdata, u) = iszero(tdata) ? zero(u) : u
S(u, nablau) = u
ctx = L1L2TVContext(name, mesh, m; T, tdata = mask, S, params...)
# FIXME: currently dual grid only
interpolate!(mask, x -> imgmask[round.(Int, x)...])
#interpolate!(mask, x -> abs(x[2] - 0.5) > 0.1)
interpolate!(ctx.g, x -> imgmask[round.(Int, x)...] ? img[round.(Int, x)...] : 0.)
save_inpaint(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, mask)
pvd = paraview_collection("$(ctx.name).pvd")
pvd[0] = save_inpaint(0)
for i in 1:10
step!(ctx)
pvd[i] = save_inpaint(i)
println()
end
end
function optflow(imgf0, imgf1; name, params...)
size(imgf0) == size(imgf1) ||
throw(ArgumentError("non-matching dimensions"))
m = 2
mesh = init_grid(imgf0; type=:vertex)
# optflow specific stuff
Vg = FeSpace(mesh, P1(), (1,))
f0 = FeFunction(Vg, name="f0")
f1 = FeFunction(Vg, name="f1")
fw = FeFunction(Vg, name="fw")
nablafw = nabla(fw)
T(tdata, u) = tdata * u
S(u, nablau) = nablau
ctx = L1L2TVContext(name, mesh, m; T, tdata = nablafw, S, params...)
# FIXME: currently dual grid only
interpolate!(f0, x -> imgf0[round.(Int, x)...])
interpolate!(f1, x -> imgf1[round.(Int, x)...])
fw.data .= f1.data
g_optflow(x; u, f0, fw, nablafw) =
nablafw * u - (fw - f0)
interpolate!(ctx.g, g_optflow; ctx.u, f0, fw, nablafw)
save_optflow(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, f0, f1, fw)
pvd = paraview_collection("$(ctx.name).pvd")
pvd[0] = save_optflow(0)
for i in 1:10
step!(ctx)
pvd[i] = save_optflow(i)
println()
end
end
function myrun()
name = "test"
......@@ -17,11 +283,6 @@ function myrun()
# inpainting
mask = FeFunction(Vg, name="mask")
# optflow
f0 = FeFunction(Vg, name="f0")
f1 = FeFunction(Vg, name="f1")
fw = FeFunction(Vg, name="fw")
nablafw = nabla(fw)
g = FeFunction(Vg, name="g")
u = FeFunction(Vu, name="u")
......@@ -49,7 +310,6 @@ function myrun()
gamma2 = 1e-3
interpolate!(g, x -> norm(x - SA[0.5, 0.5]) < 0.3)
interpolate!(mask, x -> abs(x[2] - 0.5) > 0.1)
interpolate!(f0, x -> x[1])
interpolate!(f1, x -> x[1] - 0.01)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment