Skip to content
Snippets Groups Projects
Commit 83e09cee authored by Stephan Hilb's avatar Stephan Hilb
Browse files

use static arrays for all per-element operations

parent 6f6fc20a
No related branches found
No related tags found
No related merge requests found
using Statistics: mean
using StaticArrays: SA
using StaticArrays: SA, SArray
export FeSpace, Mapper, FeFunction, P1, DP0, DP1
export interpolate!, sample, bind!, evaluate, nabla
......@@ -8,6 +8,10 @@ struct P1 end
struct DP0 end
struct DP1 end
ndofs(::P1) = 3
ndofs(::DP0) = 1
ndofs(::DP1) = 3
# FIXME: should be static vectors
# evaluate all (1-dim) local basis functions against x
evaluate_basis(::P1, x) = SA[1 - x[1] - x[2], x[1], x[2]]
......@@ -22,7 +26,6 @@ struct FeSpace{M, Fe, S}
element::Fe
dofmap::Array{Int, 3} # (rdim, eldof, cell) -> gdof
ndofs::Int # = maximum(dofmap)
size::S
end
function FeSpace(mesh, el::P1, size_=(1,))
......@@ -33,7 +36,8 @@ function FeSpace(mesh, el::P1, size_=(1,))
for i in CartesianIndices((rdims, 3, ncells))
dofmap[i] = rdims * (mesh.cells[i[2], i[3]] - 1) + i[1]
end
return FeSpace(mesh, el, dofmap, rdims * size(mesh.vertices, 2), size_)
return FeSpace{typeof(mesh), typeof(el), size_}(
mesh, el, dofmap, rdims * size(mesh.vertices, 2))
end
function FeSpace(mesh, el::DP0, size_=(1,))
......@@ -44,7 +48,8 @@ function FeSpace(mesh, el::DP0, size_=(1,))
for i in CartesianIndices((rdims, 1, ncells))
dofmap[i] = rdims * (i[3] - 1) + i[1]
end
return FeSpace(mesh, el, dofmap, rdims * ncells, size_)
return FeSpace{typeof(mesh), typeof(el), size_}(
mesh, el, dofmap, rdims * ncells)
end
function FeSpace(mesh, el::DP1, size_=(1,))
......@@ -55,17 +60,27 @@ function FeSpace(mesh, el::DP1, size_=(1,))
for i in LinearIndices((rdims, 3, ncells))
dofmap[i] = i
end
return FeSpace(mesh, el, dofmap, rdims * 3 * ncells, size_)
return FeSpace{typeof(mesh), typeof(el), size_}(
mesh, el, dofmap, rdims * 3 * ncells)
end
Base.show(io::IO, ::MIME"text/plain", x::FeSpace) =
print("$(nameof(typeof(x))), $(nameof(typeof(x.element))) elements, size $(x.size), $(x.ndofs) dofs")
function Base.getproperty(obj::FeSpace{<:Any, <:Any, S}, sym::Symbol) where S
if sym === :size
return S
else
return getfield(obj, sym)
end
end
# evaluate at local point
function evaluate(space::FeSpace, ldofs, xloc)
@inline function evaluate(space::FeSpace, ldofs, xloc)
bv = evaluate_basis(space.element, xloc)
v = reshape(ldofs, size(space.dofmap)[1:2]) * bv
return reshape(v, space.size)
ldofs_ = SArray{Tuple{prod(space.size),ndofs(space.element)}}(ldofs)
v = ldofs_ * bv
return SArray{Tuple{space.size...}}(v)
end
......@@ -107,7 +122,7 @@ function interpolate!(dst::FeFunction, ::P1, expr::Function; params...)
for eldof in axes(mesh.cells, 1)
xid = mesh.cells[eldof, cell]
x = mesh.vertices[:, xid]
xloc = [0. 1. 0.; 0. 0. 1.][:, eldof]
xloc = SA[0. 1. 0.; 0. 0. 1.][:, eldof]
opvalues = map(f -> evaluate(f, xloc), params)
......@@ -127,7 +142,7 @@ function interpolate!(dst::FeFunction, ::DP0, expr::Function; params...)
end
vertices = mesh.vertices[:, mesh.cells[:, cell]]
centroid = reshape(mean(vertices, dims = 2), 2)
lcentroid = [1/3, 1/3]
lcentroid = SA[1/3, 1/3]
opvalues = map(f -> evaluate(f, lcentroid), params)
......@@ -159,7 +174,7 @@ nabla(f) = Derivative(f)
bind!(df::Derivative, cell) = bind!(df.f, cell)
function evaluate(df::Derivative, x)
jac = jacobian(x -> evaluate(df.f.space, df.f.ldata, x), x)
return reshape(jac, df.f.space.size..., :)
return SArray{Tuple{df.f.space.size..., length(x)}}(jac)
end
......
......@@ -11,6 +11,10 @@ end
Base.show(io::IO, ::MIME"text/plain", f::Mesh) =
print("$(nameof(typeof(f))), $(size(f.cells, 2)) cells")
ndims_domain(::Mesh) = 2
ndims_space(::Mesh) = 2
nvertices_cell(::Mesh) = 3
function init_grid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
r1 = LinRange(v0[1], v1[1], m + 1)
r2 = LinRange(v0[2], v1[2], n + 1)
......
using SparseArrays: sparse
using LinearAlgebra: det, dot
using StaticArrays: SA
using StaticArrays: SA, SArray, MArray
using ForwardDiff: jacobian
export Poisson, L2Projection, init_point!, assemble, assemble_rhs
......@@ -36,8 +36,11 @@ a(op::Poisson, xloc, u, du, v, dv) = dot(du, dv)
quadrature() = SA[1/6, 1/6, 1/6], SA[1/6 4/6 1/6; 1/6 1/6 4/6]
elmap(mesh, cell, x) =
mesh.vertices[:, mesh.cells[:, cell]] * SA[1 - x[1] - x[2], x[1], x[2]]
function elmap(mesh, cell, x)
A = SArray{Tuple{ndims_space(mesh), nvertices_cell(mesh)}}(
view(mesh.vertices, :, view(mesh.cells, :, cell)))
return A * SA[1 - x[1] - x[2], x[1], x[2]]
end
assemble(op::Operator) = assemble(
op.space, (x...; y...) -> a(op, x...; y...); params(op)...)
......@@ -52,23 +55,20 @@ function assemble(space::FeSpace, a; params...)
d = size(qx, 1) # domain dimension
nrdims = prod(space.size)
nldofs = size(space.dofmap, 2) # number of local dofs (not counting range dimensions)
nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions)
nqpts = length(qw) # number of quadrature points
qphi = zeros(nrdims, nqpts, nrdims, nldofs)
dqphi = zeros(nrdims, d, nqpts, nrdims, nldofs)
qphi_ = zeros(nrdims, nqpts, nrdims, nldofs)
dqphi_ = zeros(nrdims, d, nqpts, nrdims, nldofs)
for r in 1:nrdims
for k in axes(qx, 2)
qphi[r, k, r, :] .= evaluate_basis(space.element, qx[:, k])
dqphi[r, :, k, r, :] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k])::AbstractArray{Float64,2})
qphi_[r, k, r, :] .= evaluate_basis(space.element, qx[:, k])
dqphi_[r, :, k, r, :] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k]))
end
end
qphi = SArray{Tuple{nrdims, nqpts, nrdims, nldofs}}(qphi_)
dqphi = SArray{Tuple{nrdims, d, nqpts, nrdims, nldofs}}(dqphi_)
xhat = zeros(d)
phii = zeros(space.size...)
dphii = zeros(space.size..., d)
phij = zeros(space.size...)
dphij = zeros(space.size..., d)
I = Float64[]
J = Float64[]
V = Float64[]
......@@ -77,7 +77,7 @@ function assemble(space::FeSpace, a; params...)
for f in opparams
bind!(f, cell)
end
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])::Array{Float64,2} # constant on element
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])::SArray # constant on element
delmapinv = inv(delmap) # constant on element
intel = abs(det(delmap))
......@@ -89,13 +89,13 @@ function assemble(space::FeSpace, a; params...)
# quadrature points
for k in axes(qx, 2)
xhat .= qx[:, k]
xhat = qx[:, k]::SArray
opvalues = map(f -> evaluate(f, xhat), opparams)
phii .= reshape(view(qphi, :, k, idim, ldofi), space.size)
dphii .= reshape(view(dqphi, :, :, k, idim, ldofi) * delmapinv, (space.size..., :))
phij .= reshape(view(qphi, :, k, jdim, ldofj), space.size)
dphij .= reshape(view(dqphi, :, :, k, jdim, ldofj) * delmapinv, (space.size..., :))
phii = SArray{Tuple{space.size...}}(qphi[:, k, idim, ldofi])
dphii = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, idim, ldofi] * delmapinv)
phij = SArray{Tuple{space.size...}}(qphi[:, k, jdim, ldofj])
dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, jdim, ldofj] * delmapinv)
gdofv = qw[k] * a(xhat, phii, dphii, phij, dphij; opvalues...) * intel
......@@ -126,28 +126,27 @@ function assemble_rhs(space::FeSpace, l; params...)
d = size(qx, 1) # domain dimension
nrdims = prod(space.size)
nldofs = size(space.dofmap, 2) # number of local dofs (not counting range dimensions)
nldofs = ndofs(space.element) # number of element dofs (i.e. local dofs not counting range dimensions)
nqpts = length(qw) # number of quadrature points
qphi = zeros(nrdims, nqpts, nrdims, nldofs)
dqphi = zeros(nrdims, d, nqpts, nrdims, nldofs)
qphi_ = zeros(nrdims, nqpts, nrdims, nldofs)
dqphi_ = zeros(nrdims, d, nqpts, nrdims, nldofs)
for r in 1:nrdims
for k in axes(qx, 2)
qphi[r, k, r, :] .= evaluate_basis(space.element, qx[:, k])
dqphi[r, :, k, r, :] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k])::AbstractArray{Float64,2})
qphi_[r, k, r, :] .= evaluate_basis(space.element, qx[:, k])
dqphi_[r, :, k, r, :] .= transpose(jacobian(x -> evaluate_basis(space.element, x), qx[:, k]))
end
end
qphi = SArray{Tuple{nrdims, nqpts, nrdims, nldofs}}(qphi_)
dqphi = SArray{Tuple{nrdims, d, nqpts, nrdims, nldofs}}(dqphi_)
xhat = zeros(d)
phij = zeros(space.size...)
dphij = zeros(space.size..., d)
b = zeros(space.ndofs)
gdof = LinearIndices((nrdims, space.ndofs))
for cell in axes(mesh.cells, 2)
for f in opparams
bind!(f, cell)
end
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])::Array{Float64,2} # constant on element
delmap = jacobian(x -> elmap(mesh, cell, x), SA[0., 0.])::SArray # constant on element
delmapinv = inv(delmap) # constant on element
intel = abs(det(delmap))
......@@ -157,11 +156,11 @@ function assemble_rhs(space::FeSpace, l; params...)
# quadrature points
for k in axes(qx, 2)
xhat .= qx[:, k]
xhat = qx[:, k]
opvalues = map(f -> evaluate(f, xhat), opparams)
phij .= reshape(qphi[:, k, jdim, ldofj], space.size)
dphij .= reshape(dqphi[:, :, k, jdim, ldofj] * delmapinv, (space.size..., :))
phij = SArray{Tuple{space.size...}}(qphi[:, k, jdim, ldofj])
dphij = SArray{Tuple{space.size..., d}}(dqphi[:, :, k, jdim, ldofj] * delmapinv)
gdofv = qw[k] * l(xhat, phij, dphij; opvalues...) * intel
......
......@@ -6,7 +6,7 @@ using LinearAlgebra: norm
function myrun()
name = "test"
mesh = init_grid(50)
mesh = init_grid(200)
d = 2
m = 1
......@@ -15,7 +15,7 @@ function myrun()
Vp1 = FeSpace(mesh, DP0(), (1,))
Vp2 = FeSpace(mesh, DP1(), (m, d))
g = FeFunction(Vu, name="g")
g = FeFunction(Vg, name="g")
u = FeFunction(Vu, name="u")
p1 = FeFunction(Vp1, name="p1")
p2 = FeFunction(Vp2, name="p2")
......@@ -35,7 +35,7 @@ function myrun()
alpha1 = 0.
alpha2 = 10.
beta = 1e-2
beta = 0.
lambda = 0.01
gamma1 = 1e-3
gamma2 = 1e-3
......@@ -52,7 +52,7 @@ function myrun()
cond = norm(T(u) - g) > gamma1 ?
dot(T(u) - g, T(du)) / norm(T(u) - g)^2 * p1 :
zeros(size(p1))
return -p1 + alpha1 / m1 * (T(u + du) - g) - cond
return -p1 + alpha1 / m1 * (T(u) + T(du) - g) - cond
end
......@@ -122,11 +122,14 @@ function myrun()
pvd = paraview_collection("$(name).pvd")
save_step!(pvd, 0)
for i = 1:6
for i = 1:5
print("newton step $i ...")
# solve du
A = assemble(Vu, du_a; g, u, nablau, p1, p2)
b = assemble_rhs(Vu, du_l; g, u, nablau)
print(" assembled ...")
du.data .= A \ b
println(" solved ...")
# solve dp1, dp2
interpolate!(dp1, dp1_update; g, u, p1, du)
......@@ -144,4 +147,5 @@ function myrun()
save_step!(pvd, i)
end
vtk_save(pvd)
return
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment