diff --git a/src/function.jl b/src/function.jl index 31ffbacccbc54a480420aafa6f98ea18209b0ce7..ed4dddf71433172adf6e817b223b3660d7f0590d 100644 --- a/src/function.jl +++ b/src/function.jl @@ -1,6 +1,6 @@ using Statistics: mean using StaticArrays: SA, SArray, MVector, MMatrix -export FeSpace, Mapper, FeFunction, P1, DP0, DP1 +export FeSpace, Mapper, FeFunction, ImageFunction, P1, DP0, DP1 export interpolate!, sample, bind!, evaluate, integrate, nabla # Finite Elements @@ -252,6 +252,20 @@ function evaluate(df::Derivative, x) return SArray{Tuple{df.f.space.size..., length(x)}}(jac) end +# TODO: inherit from some abstract function type +struct ImageFunction{Img} + mesh::Mesh + img::Img + cell::Base.RefValue{Int} +end + +ImageFunction(mesh, img) = + ImageFunction(mesh, img, Ref(1)) + +bind!(f::ImageFunction, cell) = f.cell[] = cell +evaluate(f::ImageFunction, xloc) = + interpolate_bilinear(f.img, elmap(f.mesh, f.cell[])(xloc)) + function sample(f::FeFunction) mesh = f.mapper.mesh diff --git a/src/mesh.jl b/src/mesh.jl index 98a96a6788a695ee68dd8861b975a312431b0dcc..10bcd1d5530ce7de0ad8895576fff134f0e8674d 100644 --- a/src/mesh.jl +++ b/src/mesh.jl @@ -163,6 +163,9 @@ init_grid(img::Array{<:Any, 2}; type=:vertex) = init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) : init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5)) +init_grid(img::Array{<:Any, 2}, m::Int, n::Int = m) = + init_grid(m, n, (0.5, 0.5), size(img) .- (0.5, 0.5)) + # horribly implemented, please don't curse me function bisect!(mesh::HMesh, marked_cells::Set) refined_cells = Pair{Int, NTuple{2, Int}}[] @@ -337,7 +340,17 @@ function save(filename::String, mesh::Mesh, fs...) vtk_save(vtk) end +function diam(mesh, cell) + A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}( + view(mesh.vertices, :, view(mesh.cells, :, cell))) + return max( + norm(A[:, 1] - A[:, 2]), + norm(A[:, 2] - A[:, 3]), + norm(A[:, 3] - A[:, 1])) +end + function elmap(mesh, cell) + # TODO: can be improved A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}( view(mesh.vertices, :, view(mesh.cells, :, cell))) return x -> A * SA[1 - x[1] - x[2], x[1], x[2]] diff --git a/src/operator.jl b/src/operator.jl index 65a481a09801ccd6b70de1465ab0181620bbf02d..e0ae11c548d0839d86d2d0d60e517c5299b59e48 100644 --- a/src/operator.jl +++ b/src/operator.jl @@ -3,7 +3,7 @@ using LinearAlgebra: det, dot using StaticArrays: SA, SArray, MArray using ForwardDiff: jacobian -export Poisson, L2Projection, init_point!, assemble +export Poisson, L2Projection, init_point!, assemble, project_img abstract type Operator end @@ -115,3 +115,95 @@ function assemble(space::FeSpace, a, l; params...) A = sparse(I, J, V, ngdofs, ngdofs) return A, b end + + +function project_img(space::FeSpace, img) + d = 2 # domain dimension + mesh = space.mesh + f = ImageFunction(mesh, img) + opparams = (; f) + + nrdims = prod(space.size) + nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions) + + a(xloc, u, du, v, dv; f) = dot(u, v) + l(xloc, v, dv; f) = dot(f, v) + + # composite midpoint quadrature on lagrange point lattice + function quadrature(p) + k = Iterators.filter(x -> sum(x) == p, + Iterators.product((0:p for _ in 1:d+1)...)) |> collect + + weights = [1 / length(k) for _ in axes(k, 1)] + points = [x[i] / p for i in 1:2, x in k] + + return weights::Vector{Float64}, points::Matrix{Float64} + end + + I = Float64[] + J = Float64[] + V = Float64[] + b = zeros(ndofs(space)) + gdof = LinearIndices((nrdims, ndofs(space))) + + # mesh cells + for cell in cells(mesh) + foreach(f -> bind!(f, cell), opparams) + # cell map is assumed to be constant per cell + delmap = jacobian(elmap(mesh, cell), SA[0., 0.]) + delmapinv = inv(delmap) + intel = abs(det(delmap)) + + p = ceil(Int, diam(mesh, cell)) + qw, qx = quadrature(p) + nqpts = length(qw) # number of quadrature points + + qphi = zeros(nrdims, nrdims, nldofs, nqpts) + dqphi = zeros(nrdims, d, nrdims, nldofs, nqpts) + for r in 1:nrdims + for k in axes(qx, 2) + qphi[r, r, :, k] .= evaluate_basis(space.element, qx[:, k]) + dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(qx[:, k]))) + end + end + + # quadrature points + for k in axes(qx, 2) + xhat = SVector{d}(qx[:, k]) + x = elmap(mesh, cell)(xhat) + opvalues = map(f -> evaluate(f, xhat), opparams) + + # local test-function dofs + for jdim in 1:nrdims, ldofj in 1:nldofs + gdofj = space.dofmap[jdim, ldofj, cell] + + phij = SArray{Tuple{space.size...}}(qphi[:, jdim, ldofj, k]) + dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, jdim, ldofj, k] * delmapinv) + + lv = qw[k] * l(x, phij, dphij; opvalues...) * intel + b[gdofj] += lv + + # local trial-function dofs + for idim in 1:nrdims, ldofi in 1:nldofs + gdofi = space.dofmap[idim, ldofi, cell] + + phii = SArray{Tuple{space.size...}}(qphi[:, idim, ldofi, k]) + dphii = SArray{Tuple{space.size..., d}}(dqphi[:, :, idim, ldofi, k] * delmapinv) + + av = qw[k] * a(x, phii, dphii, phij, dphij; opvalues...) * intel + push!(I, gdofi) + push!(J, gdofj) + push!(V, av) + end + end + end + end + + ngdofs = ndofs(space) + A = sparse(I, J, V, ngdofs, ngdofs) + + u = FeFunction(space) + u.data .= A \ b + + return u +end diff --git a/src/run.jl b/src/run.jl index a6ec678a3c34d122cdc17168860ebcd75383db95..39a4a51a774bb6d902dc49defcede6641bbd036e 100644 --- a/src/run.jl +++ b/src/run.jl @@ -282,7 +282,8 @@ norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f)) function denoise(img; name, params...) m = 1 - mesh = init_grid(img; type=:vertex) + #mesh = init_grid(img; type=:vertex) + mesh = init_grid(img, 5, 5) T(tdata, u) = u S(u, nablau) = u @@ -313,15 +314,20 @@ function denoise(img; name, params...) println("primal energy: $(primal_energy(ctx))") norm_step = sqrt(norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)) - norm_step <= 1e-3 && break + norm_step = sqrt(norm_l2(ctx.du)^2) + norm_step <= 1e-2 && break end marked_cells = mark(ctx; theta = 0.5) #println(marked_cells) println("refining ...") ctx = refine(ctx, marked_cells) test_mesh(ctx.mesh) + + gnew = project_img(ctx.g.space, img) + ctx.g.data .= gnew.data #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) - k >= 200 && break + + k >= 50 && break end vtk_save(pvd) return ctx