From e398949b8036f120d31fb376e75818e7c8b35e34 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Thu, 22 Jul 2021 21:17:34 +0200
Subject: [PATCH] implement image projection

---
 src/function.jl | 16 ++++++++-
 src/mesh.jl     | 13 +++++++
 src/operator.jl | 94 ++++++++++++++++++++++++++++++++++++++++++++++++-
 src/run.jl      | 12 +++++--
 4 files changed, 130 insertions(+), 5 deletions(-)

diff --git a/src/function.jl b/src/function.jl
index 31ffbac..ed4dddf 100644
--- a/src/function.jl
+++ b/src/function.jl
@@ -1,6 +1,6 @@
 using Statistics: mean
 using StaticArrays: SA, SArray, MVector, MMatrix
-export FeSpace, Mapper, FeFunction, P1, DP0, DP1
+export FeSpace, Mapper, FeFunction, ImageFunction, P1, DP0, DP1
 export interpolate!, sample, bind!, evaluate, integrate, nabla
 
 # Finite Elements
@@ -252,6 +252,20 @@ function evaluate(df::Derivative, x)
     return SArray{Tuple{df.f.space.size..., length(x)}}(jac)
 end
 
+# TODO: inherit from some abstract function type
+struct ImageFunction{Img}
+    mesh::Mesh
+    img::Img
+    cell::Base.RefValue{Int}
+end
+
+ImageFunction(mesh, img) =
+    ImageFunction(mesh, img, Ref(1))
+
+bind!(f::ImageFunction, cell) = f.cell[] = cell
+evaluate(f::ImageFunction, xloc) =
+    interpolate_bilinear(f.img, elmap(f.mesh, f.cell[])(xloc))
+
 
 function sample(f::FeFunction)
     mesh = f.mapper.mesh
diff --git a/src/mesh.jl b/src/mesh.jl
index 98a96a6..10bcd1d 100644
--- a/src/mesh.jl
+++ b/src/mesh.jl
@@ -163,6 +163,9 @@ init_grid(img::Array{<:Any, 2}; type=:vertex) =
 	init_grid(size(img, 1) - 1, size(img, 2) - 1, (1.0, 1.0), size(img)) :
 	init_grid(size(img, 1), size(img, 2), (0.5, 0.5), size(img) .- (0.5, 0.5))
 
+init_grid(img::Array{<:Any, 2}, m::Int, n::Int = m) =
+    init_grid(m, n, (0.5, 0.5), size(img) .- (0.5, 0.5))
+
 # horribly implemented, please don't curse me
 function bisect!(mesh::HMesh, marked_cells::Set)
     refined_cells = Pair{Int, NTuple{2, Int}}[]
@@ -337,7 +340,17 @@ function save(filename::String, mesh::Mesh, fs...)
     vtk_save(vtk)
 end
 
+function diam(mesh, cell)
+    A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}(
+	view(mesh.vertices, :, view(mesh.cells, :, cell)))
+    return max(
+	norm(A[:, 1] - A[:, 2]),
+	norm(A[:, 2] - A[:, 3]),
+	norm(A[:, 3] - A[:, 1]))
+end
+
 function elmap(mesh, cell)
+    # TODO: can be improved
     A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}(
 	view(mesh.vertices, :, view(mesh.cells, :, cell)))
     return x -> A * SA[1 - x[1] - x[2], x[1], x[2]]
diff --git a/src/operator.jl b/src/operator.jl
index 65a481a..e0ae11c 100644
--- a/src/operator.jl
+++ b/src/operator.jl
@@ -3,7 +3,7 @@ using LinearAlgebra: det, dot
 using StaticArrays: SA, SArray, MArray
 using ForwardDiff: jacobian
 
-export Poisson, L2Projection, init_point!, assemble
+export Poisson, L2Projection, init_point!, assemble, project_img
 
 abstract type Operator end
 
@@ -115,3 +115,95 @@ function assemble(space::FeSpace, a, l; params...)
     A = sparse(I, J, V, ngdofs, ngdofs)
     return A, b
 end
+
+
+function project_img(space::FeSpace, img)
+    d = 2 # domain dimension
+    mesh = space.mesh
+    f = ImageFunction(mesh, img)
+    opparams = (; f)
+
+    nrdims = prod(space.size)
+    nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions)
+
+    a(xloc, u, du, v, dv; f) = dot(u, v)
+    l(xloc, v, dv; f) = dot(f, v)
+
+    # composite midpoint quadrature on lagrange point lattice
+    function quadrature(p)
+	k = Iterators.filter(x -> sum(x) == p,
+	    Iterators.product((0:p for _ in 1:d+1)...)) |> collect
+
+	weights = [1 / length(k) for _ in axes(k, 1)]
+	points = [x[i] / p for i in 1:2, x in k]
+
+	return weights::Vector{Float64}, points::Matrix{Float64}
+    end
+
+    I = Float64[]
+    J = Float64[]
+    V = Float64[]
+    b = zeros(ndofs(space))
+    gdof = LinearIndices((nrdims, ndofs(space)))
+
+    # mesh cells
+    for cell in cells(mesh)
+	foreach(f -> bind!(f, cell), opparams)
+	# cell map is assumed to be constant per cell
+	delmap = jacobian(elmap(mesh, cell), SA[0., 0.])
+	delmapinv = inv(delmap)
+	intel = abs(det(delmap))
+
+	p = ceil(Int, diam(mesh, cell))
+	qw, qx = quadrature(p)
+	nqpts = length(qw) # number of quadrature points
+
+	qphi = zeros(nrdims, nrdims, nldofs, nqpts)
+	dqphi = zeros(nrdims, d, nrdims, nldofs, nqpts)
+	for r in 1:nrdims
+	    for k in axes(qx, 2)
+		qphi[r, r, :, k] .= evaluate_basis(space.element, qx[:, k])
+		dqphi[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), SVector{d}(qx[:, k])))
+	    end
+	end
+
+	# quadrature points
+	for k in axes(qx, 2)
+	    xhat = SVector{d}(qx[:, k])
+	    x = elmap(mesh, cell)(xhat)
+	    opvalues = map(f -> evaluate(f, xhat), opparams)
+
+	    # local test-function dofs
+	    for jdim in 1:nrdims, ldofj in 1:nldofs
+		gdofj = space.dofmap[jdim, ldofj, cell]
+
+		phij = SArray{Tuple{space.size...}}(qphi[:, jdim, ldofj, k])
+		dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, jdim, ldofj, k] * delmapinv)
+
+		lv = qw[k] * l(x, phij, dphij; opvalues...) * intel
+		b[gdofj] += lv
+
+		# local trial-function dofs
+		for idim in 1:nrdims, ldofi in 1:nldofs
+		    gdofi = space.dofmap[idim, ldofi, cell]
+
+		    phii = SArray{Tuple{space.size...}}(qphi[:, idim, ldofi, k])
+		    dphii = SArray{Tuple{space.size..., d}}(dqphi[:, :, idim, ldofi, k] * delmapinv)
+
+		    av = qw[k] * a(x, phii, dphii, phij, dphij; opvalues...) * intel
+		    push!(I, gdofi)
+		    push!(J, gdofj)
+		    push!(V, av)
+		end
+	    end
+	end
+    end
+
+    ngdofs = ndofs(space)
+    A = sparse(I, J, V, ngdofs, ngdofs)
+
+    u = FeFunction(space)
+    u.data .= A \ b
+
+    return u
+end
diff --git a/src/run.jl b/src/run.jl
index a6ec678..39a4a51 100644
--- a/src/run.jl
+++ b/src/run.jl
@@ -282,7 +282,8 @@ norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
 
 function denoise(img; name, params...)
     m = 1
-    mesh = init_grid(img; type=:vertex)
+    #mesh = init_grid(img; type=:vertex)
+    mesh = init_grid(img, 5, 5)
 
     T(tdata, u) = u
     S(u, nablau) = u
@@ -313,15 +314,20 @@ function denoise(img; name, params...)
 	    println("primal energy: $(primal_energy(ctx))")
 
 	    norm_step = sqrt(norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2))
-            norm_step <= 1e-3 && break
+	    norm_step = sqrt(norm_l2(ctx.du)^2)
+            norm_step <= 1e-2 && break
 	end
 	marked_cells = mark(ctx; theta = 0.5)
 	#println(marked_cells)
 	println("refining ...")
 	ctx = refine(ctx, marked_cells)
 	test_mesh(ctx.mesh)
+
+	gnew = project_img(ctx.g.space, img)
+	ctx.g.data .= gnew.data
 	#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
-	k >= 200 && break
+
+	k >= 50 && break
     end
     vtk_save(pvd)
     return ctx
-- 
GitLab