Skip to content
Snippets Groups Projects
Commit 286238f6 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

merge rhs assembly with operator assembly

als huge performance win due to rather obvious loop reordering
parent f3c21d63
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ using LinearAlgebra: det, dot ...@@ -3,7 +3,7 @@ using LinearAlgebra: det, dot
using StaticArrays: SA, SArray, MArray using StaticArrays: SA, SArray, MArray
using ForwardDiff: jacobian using ForwardDiff: jacobian
export Poisson, L2Projection, init_point!, assemble, assemble_rhs export Poisson, L2Projection, init_point!, assemble
abstract type Operator end abstract type Operator end
...@@ -42,10 +42,12 @@ function elmap(mesh, cell, x) ...@@ -42,10 +42,12 @@ function elmap(mesh, cell, x)
return A * SA[1 - x[1] - x[2], x[1], x[2]] return A * SA[1 - x[1] - x[2], x[1], x[2]]
end end
assemble(op::Operator) = assemble( assemble(op::Operator) = assemble(op.space,
op.space, (x...; y...) -> a(op, x...; y...); params(op)...) (x...; y...) -> a(op, x...; y...),
(x...; y...) -> l(op, x...; y...);
params(op)...)
function assemble(space::FeSpace, a; params...) function assemble(space::FeSpace, a, l; params...)
mesh = space.mesh mesh = space.mesh
opparams = NamedTuple(params) opparams = NamedTuple(params)
...@@ -58,51 +60,57 @@ function assemble(space::FeSpace, a; params...) ...@@ -58,51 +60,57 @@ function assemble(space::FeSpace, a; params...)
nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions) nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions)
nqpts = length(qw) # number of quadrature points nqpts = length(qw) # number of quadrature points
qphi_ = zeros(nrdims, nqpts, nrdims, nldofs) qphi_ = zeros(nrdims, nrdims, nldofs, nqpts)
dqphi_ = zeros(nrdims, d, nqpts, nrdims, nldofs) dqphi_ = zeros(nrdims, d, nrdims, nldofs, nqpts)
for r in 1:nrdims for r in 1:nrdims
for k in axes(qx, 2) for k in axes(qx, 2)
qphi_[r, k, r, :] .= evaluate_basis(space.element, qx[:, k]) qphi_[r, r, :, k] .= evaluate_basis(space.element, qx[:, k])
dqphi_[r, :, k, r, :] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k])) dqphi_[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k]))
end end
end end
qphi = SArray{Tuple{nrdims, nqpts, nrdims, nldofs}}(qphi_) qphi = SArray{Tuple{nrdims, nrdims, nldofs, nqpts}}(qphi_)
dqphi = SArray{Tuple{nrdims, d, nqpts, nrdims, nldofs}}(dqphi_) dqphi = SArray{Tuple{nrdims, d, nrdims, nldofs, nqpts}}(dqphi_)
I = Float64[] I = Float64[]
J = Float64[] J = Float64[]
V = Float64[] V = Float64[]
b = zeros(space.ndofs)
gdof = LinearIndices((nrdims, space.ndofs)) gdof = LinearIndices((nrdims, space.ndofs))
# mesh cells
for cell in axes(mesh.cells, 2) for cell in axes(mesh.cells, 2)
for f in opparams foreach(f -> bind!(f, cell), opparams)
bind!(f, cell) # cell map is assumed to be constant per cell
end delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])::SArray # constant on element delmapinv = inv(delmap)
delmapinv = inv(delmap) # constant on element
intel = abs(det(delmap)) intel = abs(det(delmap))
# local dofs # quadrature points
for idim in 1:nrdims, ldofi in 1:nldofs for k in axes(qx, 2)
gdofi = space.dofmap[idim, ldofi, cell] xhat = qx[:, k]::SArray
opvalues = map(f -> evaluate(f, xhat), opparams)
# local test-function dofs
for jdim in 1:nrdims, ldofj in 1:nldofs for jdim in 1:nrdims, ldofj in 1:nldofs
gdofj = space.dofmap[jdim, ldofj, cell] gdofj = space.dofmap[jdim, ldofj, cell]
# quadrature points phij = SArray{Tuple{space.size...}}(qphi[:, jdim, ldofj, k])
for k in axes(qx, 2) dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, jdim, ldofj, k] * delmapinv)
xhat = qx[:, k]::SArray
opvalues = map(f -> evaluate(f, xhat), opparams) lv = qw[k] * l(xhat, phij, dphij; opvalues...) * intel
b[gdofj] += lv
phii = SArray{Tuple{space.size...}}(qphi[:, k, idim, ldofi]) # local trial-function dofs
dphii = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, idim, ldofi] * delmapinv) for idim in 1:nrdims, ldofi in 1:nldofs
phij = SArray{Tuple{space.size...}}(qphi[:, k, jdim, ldofj]) gdofi = space.dofmap[idim, ldofi, cell]
dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, jdim, ldofj] * delmapinv)
gdofv = qw[k] * a(xhat, phii, dphii, phij, dphij; opvalues...) * intel phii = SArray{Tuple{space.size...}}(qphi[:, idim, ldofi, k])
dphii = SArray{Tuple{space.size..., d}}(dqphi[:, :, idim, ldofi, k] * delmapinv)
av = qw[k] * a(xhat, phii, dphii, phij, dphij; opvalues...) * intel
push!(I, gdofi) push!(I, gdofi)
push!(J, gdofj) push!(J, gdofj)
push!(V, gdofv) push!(V, av)
end end
end end
end end
...@@ -110,68 +118,5 @@ function assemble(space::FeSpace, a; params...) ...@@ -110,68 +118,5 @@ function assemble(space::FeSpace, a; params...)
ngdofs = space.ndofs ngdofs = space.ndofs
A = sparse(I, J, V, ngdofs, ngdofs) A = sparse(I, J, V, ngdofs, ngdofs)
return A return A, b
end end
assemble_rhs(op::Operator) = assemble_rhs(
op.space, (x...; y...) -> l(op, x...; y...); params(op)...)
function assemble_rhs(space::FeSpace, l; params...)
mesh = space.mesh
opparams = NamedTuple(params)
# precompute basis at quadrature points
qw, qx = quadrature()
d = size(qx, 1) # domain dimension
nrdims = prod(space.size)
nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions)
nqpts = length(qw) # number of quadrature points
qphi_ = zeros(nrdims, nqpts, nrdims, nldofs)
dqphi_ = zeros(nrdims, d, nqpts, nrdims, nldofs)
for r in 1:nrdims
for k in axes(qx, 2)
qphi_[r, k, r, :] .= evaluate_basis(space.element, qx[:, k])
dqphi_[r, :, k, r, :] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k]))
end
end
qphi = SArray{Tuple{nrdims, nqpts, nrdims, nldofs}}(qphi_)
dqphi = SArray{Tuple{nrdims, d, nqpts, nrdims, nldofs}}(dqphi_)
b = zeros(space.ndofs)
gdof = LinearIndices((nrdims, space.ndofs))
for cell in axes(mesh.cells, 2)
for f in opparams
bind!(f, cell)
end
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])::SArray # constant on element
delmapinv = inv(delmap) # constant on element
intel = abs(det(delmap))
# local dofs
for jdim in 1:nrdims, ldofj in 1:nldofs
gdofj = space.dofmap[jdim, ldofj, cell]
# quadrature points
for k in axes(qx, 2)
xhat = qx[:, k]
opvalues = map(f -> evaluate(f, xhat), opparams)
phij = SArray{Tuple{space.size...}}(qphi[:, k, jdim, ldofj])
dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, jdim, ldofj] * delmapinv)
gdofv = qw[k] * l(xhat, phij, dphij; opvalues...) * intel
b[gdofj] += gdofv
end
end
end
return b
end
...@@ -94,7 +94,7 @@ function step!(ctx::L1L2TVContext) ...@@ -94,7 +94,7 @@ function step!(ctx::L1L2TVContext)
return a1 + a2 + aB return a1 + a2 + aB
end end
function du_l(x, phi, nablaphi; g, u, nablau, tdata) function du_l(x, phi, nablaphi; g, u, nablau, p1, p2, tdata)
aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) + aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) +
beta * dot(S(u, nablau), S(phi, nablaphi)) beta * dot(S(u, nablau), S(phi, nablaphi))
m1 = max(gamma1, norm(T(tdata, u) - g)) m1 = max(gamma1, norm(T(tdata, u) - g))
...@@ -108,8 +108,7 @@ function step!(ctx::L1L2TVContext) ...@@ -108,8 +108,7 @@ function step!(ctx::L1L2TVContext)
# solve du # solve du
print("assemble ... ") print("assemble ... ")
A = assemble(ctx.du.space, du_a; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata) A, b = assemble(ctx.du.space, du_a, du_l; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata)
b = assemble_rhs(ctx.du.space, du_l; ctx.g, ctx.u, ctx.nablau, ctx.tdata)
print("solve ... ") print("solve ... ")
ctx.du.data .= A \ b ctx.du.data .= A \ b
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment