diff --git a/src/function.jl b/src/function.jl index 192f9b098fc0084d8089e7cd79afdc21a66d2787..554792b35740e40699c4cd2696d8b1d2017420f3 100644 --- a/src/function.jl +++ b/src/function.jl @@ -78,7 +78,7 @@ end # evaluate at local point @inline function evaluate(space::FeSpace, ldofs, xloc) bv = evaluate_basis(space.element, xloc) - ldofs_ = SArray{Tuple{prod(space.size),ndofs(space.element)}}(ldofs) + ldofs_ = SArray{Tuple{prod(space.size), ndofs(space.element)}}(ldofs) v = ldofs_ * bv return SArray{Tuple{space.size...}}(v) end @@ -105,7 +105,8 @@ end Base.show(io::IO, ::MIME"text/plain", f::FeFunction) = print("$(nameof(typeof(f))), size $(f.space.size) with $(length(f.data)) dofs") -interpolate!(dst::FeFunction, expr::Function; params...) = interpolate!(dst, dst.space.element, expr; params...) +interpolate!(dst::FeFunction, expr::Function; params...) = + interpolate!(dst, dst.space.element, expr; params...) myvec(x) = vec(x) @@ -121,7 +122,7 @@ function interpolate!(dst::FeFunction, ::P1, expr::Function; params...) end for eldof in axes(mesh.cells, 1) xid = mesh.cells[eldof, cell] - x = mesh.vertices[:, xid] + x = SArray{Tuple{ndims_domain(mesh)}}(mesh.vertices[:, xid]) xloc = SA[0. 1. 0.; 0. 0. 1.][:, eldof] opvalues = map(f -> evaluate(f, xloc), params) @@ -141,7 +142,7 @@ function interpolate!(dst::FeFunction, ::DP0, expr::Function; params...) bind!(f, cell) end vertices = mesh.vertices[:, mesh.cells[:, cell]] - centroid = reshape(mean(vertices, dims = 2), 2) + centroid = SArray{Tuple{ndims_domain(mesh)}}(mean(vertices, dims = 2)) lcentroid = SA[1/3, 1/3] opvalues = map(f -> evaluate(f, lcentroid), params) @@ -162,6 +163,11 @@ end # evaluate at local point (needs bind! call before) evaluate(f::FeFunction, x) = evaluate(f.space, f.ldata, x) +# allow any non-function to act as a constant function +bind!(c, cell) = nothing +evaluate(c, xloc) = c + +# TODO: inherit from some abstract function type struct Derivative{F} f::F end diff --git a/src/image.jl b/src/image.jl index 969f555a7ec62b5c528f922c2952a985dd483222..ab14b3745eb8bc4de25dac19857e39f4c0c8380f 100644 --- a/src/image.jl +++ b/src/image.jl @@ -11,7 +11,7 @@ function interpolate_bilinear(img, x) cornerbool = Bool.(Tuple(idx)) λ = ifelse.(cornerbool, x .- x0, x1 .- x) corner = ifelse.(cornerbool, x1, x0) - val += prod(λ) * eval_neumann(img, CartesianIndex(corner)) + val += prod(λ) * eval_neumann(img, corner) end return val end diff --git a/src/mesh.jl b/src/mesh.jl index 8a039803fb1f10493d50c17cec69539fe89faecf..b5da4610cc5217d69355f0c91235c2fe6978dd1e 100644 --- a/src/mesh.jl +++ b/src/mesh.jl @@ -35,7 +35,7 @@ function init_grid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.)) return Mesh(vertices, cells) end -init_grid(img::Array{<:Any, 2}, type=:vertex) = +init_grid(img::Array{<:Any, 2}; type=:vertex) = type == :vertex ? init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) : init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5)) diff --git a/src/run.jl b/src/run.jl index dc9a9118aea6403e037aafb65036062a9f1d6684..54d220149ce0ed1f216fcfc786ac55dc17a7119c 100644 --- a/src/run.jl +++ b/src/run.jl @@ -1,8 +1,274 @@ -export myrun +export myrun, denoise, inpaint, optflow using LinearAlgebra: norm +struct L1L2TVContext{M,Ttype,Stype} + name::String + mesh::M + d::Int # = ndims_domain(mesh) + m::Int + + T::Ttype + tdata + S::Stype + + alpha1::Float64 + alpha2::Float64 + beta::Float64 + lambda::Float64 + gamma1::Float64 + gamma2::Float64 + + g::FeFunction + u::FeFunction + p1::FeFunction + p2::FeFunction + du::FeFunction + dp1::FeFunction + dp2::FeFunction + nablau + nabladu +end + +function L1L2TVContext(name, mesh, m; T, tdata, S, + alpha1, alpha2, beta, lambda, gamma1, gamma2) + d = ndims_domain(mesh) + + Vg = FeSpace(mesh, P1(), (1,)) + Vu = FeSpace(mesh, P1(), (m,)) + Vp1 = FeSpace(mesh, DP0(), (1,)) + Vp2 = FeSpace(mesh, DP1(), (m, d)) + + g = FeFunction(Vg, name="g") + u = FeFunction(Vu, name="u") + p1 = FeFunction(Vp1, name="p1") + p2 = FeFunction(Vp2, name="p2") + du = FeFunction(Vu) + dp1 = FeFunction(Vp1) + dp2 = FeFunction(Vp2) + nablau = nabla(u) + nabladu = nabla(du) + + g.data .= 0 + u.data .= 0 + p1.data .= 0 + p2.data .= 0 + du.data .= 0 + dp1.data .= 0 + dp2.data .= 0 + + return L1L2TVContext(name, mesh, d, m, T, tdata, S, + alpha1, alpha2, beta, lambda, gamma1, gamma2, + g, u, p1, p2, du, dp1, dp2, nablau, nabladu) +end + +function step!(ctx::L1L2TVContext) + T = ctx.T + S = ctx.S + alpha1 = ctx.alpha1 + alpha2 = ctx.alpha2 + beta = ctx.beta + lambda = ctx.lambda + gamma1 = ctx.gamma1 + gamma2 = ctx.gamma2 + + function du_a(x, du, nabladu, phi, nablaphi; g, u, nablau, p1, p2, tdata) + m1 = max(gamma1, norm(T(tdata, u) - g)) + cond1 = norm(T(tdata, u) - g) > gamma1 ? + dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : + zeros(size(p1)) + a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) - + dot(cond1, T(tdata, phi)) + + m2 = max(gamma2, norm(nablau)) + cond2 = norm(nablau) > gamma2 ? + dot(nablau, nabladu) / norm(nablau)^2 * p2 : + zeros(size(p2)) + a2 = lambda / m2 * dot(nabladu, nablaphi) - + dot(cond2, nablaphi) + + aB = alpha2 * dot(T(tdata, du), T(tdata, phi)) + + beta * dot(S(du, nabladu), S(phi, nablaphi)) + + return a1 + a2 + aB + end + + function du_l(x, phi, nablaphi; g, u, nablau, tdata) + aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) + + beta * dot(S(u, nablau), S(phi, nablaphi)) + m1 = max(gamma1, norm(T(tdata, u) - g)) + p1part = alpha1 / m1 * dot(T(tdata, u) - g, T(tdata, phi)) + m2 = max(gamma2, norm(nablau)) + p2part = lambda / m2 * dot(nablau, nablaphi) + gpart = alpha2 * dot(g, T(tdata, phi)) + + return -aB - p1part - p2part + gpart + end + + # solve du + print("assemble ... ") + A = assemble(ctx.du.space, du_a; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata) + b = assemble_rhs(ctx.du.space, du_l; ctx.g, ctx.u, ctx.nablau, ctx.tdata) + print("solve ... ") + ctx.du.data .= A \ b + + # solve dp1 + function dp1_update(x; g, u, p1, du, tdata) + m1 = max(gamma1, norm(T(tdata, u) - g)) + cond = norm(T(tdata, u) - g) > gamma1 ? + dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : + zeros(size(p1)) + return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond + end + interpolate!(ctx.dp1, dp1_update; ctx.g, ctx.u, ctx.p1, ctx.du, ctx.tdata) + + # solve dp2 + function dp2_update(x; u, nablau, p2, du, nabladu) + m2 = max(gamma2, norm(nablau)) + cond = norm(nablau) > gamma2 ? + dot(nablau, nabladu) / norm(nablau)^2 * p2 : + zeros(size(p2)) + return -p2 + lambda / m2 * (nablau + nabladu) - cond + end + interpolate!(ctx.dp2, dp2_update; ctx.u, ctx.nablau, ctx.p2, ctx.du, ctx.nabladu) + + # newton update + ctx.u.data .+= ctx.du.data + ctx.p1.data .+= ctx.dp1.data + ctx.p2.data .+= ctx.dp2.data + + # reproject p1 + function p1_project!(p1, alpha1) + p1.space.element::DP0 + p1.data .= clamp.(p1.data, -alpha1, alpha1) + end + p1_project!(ctx.p1, ctx.alpha1) + # reproject p2 + function p2_project!(p2, lambda) + p2.space.element::DP1 + p2d = reshape(p2.data, prod(p2.space.size), :) # no copy + for i in axes(p2d, 2) + p2in = norm(p2d[:, i]) + if p2in > lambda + p2d[:, i] .*= lambda ./ p2in + end + end + end + p2_project!(ctx.p2, ctx.lambda) +end + +function save(ctx::L1L2TVContext, filename, fs...) + print("save ... ") + vtk = vtk_mesh(filename, ctx.mesh) + vtk_append!(vtk, fs...) + vtk_save(vtk) + return vtk +end + +function denoise(img; name, params...) + m = 1 + mesh = init_grid(img; type=:vertex) + + T(tdata, u) = u + S(u, nablau) = u + + ctx = L1L2TVContext(name, mesh, m; T, tdata = nothing, S, params...) + + interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) + + save_denoise(i) = + save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", + ctx.g, ctx.u, ctx.p1, ctx.p2) + + pvd = paraview_collection("$(ctx.name).pvd") + pvd[0] = save_denoise(0) + for i in 1:10 + step!(ctx) + pvd[i] = save_denoise(i) + println() + end +end + +function inpaint(img, imgmask; name, params...) + size(img) == size(imgmask) || + throw(ArgumentError("non-matching dimensions")) + + m = 1 + mesh = init_grid(img; type=:vertex) + + # inpaint specific stuff + Vg = FeSpace(mesh, P1(), (1,)) + mask = FeFunction(Vg, name="mask") + + T(tdata, u) = iszero(tdata) ? zero(u) : u + S(u, nablau) = u + + ctx = L1L2TVContext(name, mesh, m; T, tdata = mask, S, params...) + + # FIXME: currently dual grid only + interpolate!(mask, x -> imgmask[round.(Int, x)...]) + #interpolate!(mask, x -> abs(x[2] - 0.5) > 0.1) + interpolate!(ctx.g, x -> imgmask[round.(Int, x)...] ? img[round.(Int, x)...] : 0.) + + save_inpaint(i) = + save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", + ctx.g, ctx.u, ctx.p1, ctx.p2, mask) + + pvd = paraview_collection("$(ctx.name).pvd") + pvd[0] = save_inpaint(0) + for i in 1:10 + step!(ctx) + pvd[i] = save_inpaint(i) + println() + end +end + +function optflow(imgf0, imgf1; name, params...) + size(imgf0) == size(imgf1) || + throw(ArgumentError("non-matching dimensions")) + + m = 2 + mesh = init_grid(imgf0; type=:vertex) + + # optflow specific stuff + Vg = FeSpace(mesh, P1(), (1,)) + f0 = FeFunction(Vg, name="f0") + f1 = FeFunction(Vg, name="f1") + fw = FeFunction(Vg, name="fw") + nablafw = nabla(fw) + + T(tdata, u) = tdata * u + S(u, nablau) = nablau + + ctx = L1L2TVContext(name, mesh, m; T, tdata = nablafw, S, params...) + + # FIXME: currently dual grid only + interpolate!(f0, x -> imgf0[round.(Int, x)...]) + interpolate!(f1, x -> imgf1[round.(Int, x)...]) + fw.data .= f1.data + + g_optflow(x; u, f0, fw, nablafw) = + nablafw * u - (fw - f0) + interpolate!(ctx.g, g_optflow; ctx.u, f0, fw, nablafw) + + save_optflow(i) = + save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", + ctx.g, ctx.u, ctx.p1, ctx.p2, f0, f1, fw) + + pvd = paraview_collection("$(ctx.name).pvd") + pvd[0] = save_optflow(0) + for i in 1:10 + step!(ctx) + pvd[i] = save_optflow(i) + println() + end +end + + + + + function myrun() name = "test" @@ -17,11 +283,6 @@ function myrun() # inpainting mask = FeFunction(Vg, name="mask") - # optflow - f0 = FeFunction(Vg, name="f0") - f1 = FeFunction(Vg, name="f1") - fw = FeFunction(Vg, name="fw") - nablafw = nabla(fw) g = FeFunction(Vg, name="g") u = FeFunction(Vu, name="u") @@ -49,7 +310,6 @@ function myrun() gamma2 = 1e-3 interpolate!(g, x -> norm(x - SA[0.5, 0.5]) < 0.3) - interpolate!(mask, x -> abs(x[2] - 0.5) > 0.1) interpolate!(f0, x -> x[1]) interpolate!(f1, x -> x[1] - 0.01)