Skip to content
Snippets Groups Projects
Commit 286238f6 authored by Stephan Hilb's avatar Stephan Hilb
Browse files

merge rhs assembly with operator assembly

als huge performance win due to rather obvious loop reordering
parent f3c21d63
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ using LinearAlgebra: det, dot
using StaticArrays: SA, SArray, MArray
using ForwardDiff: jacobian
export Poisson, L2Projection, init_point!, assemble, assemble_rhs
export Poisson, L2Projection, init_point!, assemble
abstract type Operator end
......@@ -42,10 +42,12 @@ function elmap(mesh, cell, x)
return A * SA[1 - x[1] - x[2], x[1], x[2]]
end
assemble(op::Operator) = assemble(
op.space, (x...; y...) -> a(op, x...; y...); params(op)...)
assemble(op::Operator) = assemble(op.space,
(x...; y...) -> a(op, x...; y...),
(x...; y...) -> l(op, x...; y...);
params(op)...)
function assemble(space::FeSpace, a; params...)
function assemble(space::FeSpace, a, l; params...)
mesh = space.mesh
opparams = NamedTuple(params)
......@@ -58,51 +60,57 @@ function assemble(space::FeSpace, a; params...)
nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions)
nqpts = length(qw) # number of quadrature points
qphi_ = zeros(nrdims, nqpts, nrdims, nldofs)
dqphi_ = zeros(nrdims, d, nqpts, nrdims, nldofs)
qphi_ = zeros(nrdims, nrdims, nldofs, nqpts)
dqphi_ = zeros(nrdims, d, nrdims, nldofs, nqpts)
for r in 1:nrdims
for k in axes(qx, 2)
qphi_[r, k, r, :] .= evaluate_basis(space.element, qx[:, k])
dqphi_[r, :, k, r, :] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k]))
qphi_[r, r, :, k] .= evaluate_basis(space.element, qx[:, k])
dqphi_[r, :, r, :, k] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k]))
end
end
qphi = SArray{Tuple{nrdims, nqpts, nrdims, nldofs}}(qphi_)
dqphi = SArray{Tuple{nrdims, d, nqpts, nrdims, nldofs}}(dqphi_)
qphi = SArray{Tuple{nrdims, nrdims, nldofs, nqpts}}(qphi_)
dqphi = SArray{Tuple{nrdims, d, nrdims, nldofs, nqpts}}(dqphi_)
I = Float64[]
J = Float64[]
V = Float64[]
b = zeros(space.ndofs)
gdof = LinearIndices((nrdims, space.ndofs))
# mesh cells
for cell in axes(mesh.cells, 2)
for f in opparams
bind!(f, cell)
end
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])::SArray # constant on element
delmapinv = inv(delmap) # constant on element
foreach(f -> bind!(f, cell), opparams)
# cell map is assumed to be constant per cell
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])
delmapinv = inv(delmap)
intel = abs(det(delmap))
# local dofs
for idim in 1:nrdims, ldofi in 1:nldofs
gdofi = space.dofmap[idim, ldofi, cell]
for jdim in 1:nrdims, ldofj in 1:nldofs
gdofj = space.dofmap[jdim, ldofj, cell]
# quadrature points
for k in axes(qx, 2)
xhat = qx[:, k]::SArray
opvalues = map(f -> evaluate(f, xhat), opparams)
phii = SArray{Tuple{space.size...}}(qphi[:, k, idim, ldofi])
dphii = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, idim, ldofi] * delmapinv)
phij = SArray{Tuple{space.size...}}(qphi[:, k, jdim, ldofj])
dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, jdim, ldofj] * delmapinv)
# local test-function dofs
for jdim in 1:nrdims, ldofj in 1:nldofs
gdofj = space.dofmap[jdim, ldofj, cell]
phij = SArray{Tuple{space.size...}}(qphi[:, jdim, ldofj, k])
dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, jdim, ldofj, k] * delmapinv)
lv = qw[k] * l(xhat, phij, dphij; opvalues...) * intel
b[gdofj] += lv
# local trial-function dofs
for idim in 1:nrdims, ldofi in 1:nldofs
gdofi = space.dofmap[idim, ldofi, cell]
gdofv = qw[k] * a(xhat, phii, dphii, phij, dphij; opvalues...) * intel
phii = SArray{Tuple{space.size...}}(qphi[:, idim, ldofi, k])
dphii = SArray{Tuple{space.size..., d}}(dqphi[:, :, idim, ldofi, k] * delmapinv)
av = qw[k] * a(xhat, phii, dphii, phij, dphij; opvalues...) * intel
push!(I, gdofi)
push!(J, gdofj)
push!(V, gdofv)
push!(V, av)
end
end
end
......@@ -110,68 +118,5 @@ function assemble(space::FeSpace, a; params...)
ngdofs = space.ndofs
A = sparse(I, J, V, ngdofs, ngdofs)
return A
return A, b
end
assemble_rhs(op::Operator) = assemble_rhs(
op.space, (x...; y...) -> l(op, x...; y...); params(op)...)
function assemble_rhs(space::FeSpace, l; params...)
mesh = space.mesh
opparams = NamedTuple(params)
# precompute basis at quadrature points
qw, qx = quadrature()
d = size(qx, 1) # domain dimension
nrdims = prod(space.size)
nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions)
nqpts = length(qw) # number of quadrature points
qphi_ = zeros(nrdims, nqpts, nrdims, nldofs)
dqphi_ = zeros(nrdims, d, nqpts, nrdims, nldofs)
for r in 1:nrdims
for k in axes(qx, 2)
qphi_[r, k, r, :] .= evaluate_basis(space.element, qx[:, k])
dqphi_[r, :, k, r, :] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k]))
end
end
qphi = SArray{Tuple{nrdims, nqpts, nrdims, nldofs}}(qphi_)
dqphi = SArray{Tuple{nrdims, d, nqpts, nrdims, nldofs}}(dqphi_)
b = zeros(space.ndofs)
gdof = LinearIndices((nrdims, space.ndofs))
for cell in axes(mesh.cells, 2)
for f in opparams
bind!(f, cell)
end
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])::SArray # constant on element
delmapinv = inv(delmap) # constant on element
intel = abs(det(delmap))
# local dofs
for jdim in 1:nrdims, ldofj in 1:nldofs
gdofj = space.dofmap[jdim, ldofj, cell]
# quadrature points
for k in axes(qx, 2)
xhat = qx[:, k]
opvalues = map(f -> evaluate(f, xhat), opparams)
phij = SArray{Tuple{space.size...}}(qphi[:, k, jdim, ldofj])
dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, jdim, ldofj] * delmapinv)
gdofv = qw[k] * l(xhat, phij, dphij; opvalues...) * intel
b[gdofj] += gdofv
end
end
end
return b
end
......@@ -94,7 +94,7 @@ function step!(ctx::L1L2TVContext)
return a1 + a2 + aB
end
function du_l(x, phi, nablaphi; g, u, nablau, tdata)
function du_l(x, phi, nablaphi; g, u, nablau, p1, p2, tdata)
aB = alpha2 * dot(T(tdata, u), T(tdata, phi)) +
beta * dot(S(u, nablau), S(phi, nablaphi))
m1 = max(gamma1, norm(T(tdata, u) - g))
......@@ -108,8 +108,7 @@ function step!(ctx::L1L2TVContext)
# solve du
print("assemble ... ")
A = assemble(ctx.du.space, du_a; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata)
b = assemble_rhs(ctx.du.space, du_l; ctx.g, ctx.u, ctx.nablau, ctx.tdata)
A, b = assemble(ctx.du.space, du_a, du_l; ctx.g, ctx.u, ctx.nablau, ctx.p1, ctx.p2, ctx.tdata)
print("solve ... ")
ctx.du.data .= A \ b
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment