diff --git a/src/function.jl b/src/function.jl index 554792b35740e40699c4cd2696d8b1d2017420f3..cff7100f89264590d8ca3b10b9c3132ec760d535 100644 --- a/src/function.jl +++ b/src/function.jl @@ -1,5 +1,5 @@ using Statistics: mean -using StaticArrays: SA, SArray +using StaticArrays: SA, SArray, MVector export FeSpace, Mapper, FeFunction, P1, DP0, DP1 export interpolate!, sample, bind!, evaluate, nabla @@ -86,19 +86,19 @@ end # dof ordering for vector valued functions: -# (ldof, fdims...) +# (rsize..., eldofs) # Array-valued function -struct FeFunction{Sp} +struct FeFunction{Sp,Ld} space::Sp data::Vector{Float64} # gdof -> data name::String - ldata::Vector{Float64} # ldof -> data + ldata::Ld # ldof -> data end function FeFunction(space; name=string(gensym("f"))) data = Vector{Float64}(undef, space.ndofs) - ldata = Vector{Float64}(undef, prod(size(space.dofmap)[1:2])) + ldata = zero(MVector{prod(space.size) * ndofs(space.element)}) return FeFunction(space, data, name, ldata) end @@ -164,7 +164,7 @@ end evaluate(f::FeFunction, x) = evaluate(f.space, f.ldata, x) # allow any non-function to act as a constant function -bind!(c, cell) = nothing +bind!(c, cell) = c evaluate(c, xloc) = c # TODO: inherit from some abstract function type diff --git a/src/run.jl b/src/run.jl index 54d220149ce0ed1f216fcfc786ac55dc17a7119c..5e22ef98a89c3e6a73384b5d00bf8304b9434a10 100644 --- a/src/run.jl +++ b/src/run.jl @@ -77,14 +77,14 @@ function step!(ctx::L1L2TVContext) m1 = max(gamma1, norm(T(tdata, u) - g)) cond1 = norm(T(tdata, u) - g) > gamma1 ? dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : - zeros(size(p1)) + zero(p1) a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) - dot(cond1, T(tdata, phi)) m2 = max(gamma2, norm(nablau)) cond2 = norm(nablau) > gamma2 ? dot(nablau, nabladu) / norm(nablau)^2 * p2 : - zeros(size(p2)) + zero(p2) a2 = lambda / m2 * dot(nabladu, nablaphi) - dot(cond2, nablaphi) @@ -118,7 +118,7 @@ function step!(ctx::L1L2TVContext) m1 = max(gamma1, norm(T(tdata, u) - g)) cond = norm(T(tdata, u) - g) > gamma1 ? dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : - zeros(size(p1)) + zero(p1) return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond end interpolate!(ctx.dp1, dp1_update; ctx.g, ctx.u, ctx.p1, ctx.du, ctx.tdata) @@ -128,7 +128,7 @@ function step!(ctx::L1L2TVContext) m2 = max(gamma2, norm(nablau)) cond = norm(nablau) > gamma2 ? dot(nablau, nabladu) / norm(nablau)^2 * p2 : - zeros(size(p2)) + zero(p2) return -p2 + lambda / m2 * (nablau + nabladu) - cond end interpolate!(ctx.dp2, dp2_update; ctx.u, ctx.nablau, ctx.p2, ctx.du, ctx.nabladu) @@ -334,7 +334,7 @@ function myrun() m1 = max(gamma1, norm(T(tdata, u) - g)) cond = norm(T(tdata, u) - g) > gamma1 ? dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : - zeros(size(p1)) + zero(p1) return -p1 + alpha1 / m1 * (T(tdata, u) + T(tdata, du) - g) - cond end @@ -343,7 +343,7 @@ function myrun() m2 = max(gamma2, norm(nablau)) cond = norm(nablau) > gamma2 ? dot(nablau, nabladu) / norm(nablau)^2 * p2 : - zeros(size(p2)) + zero(p2) return -p2 + lambda / m2 * (nablau + nabladu) - cond end @@ -367,14 +367,14 @@ function myrun() m1 = max(gamma1, norm(T(tdata, u) - g)) cond1 = norm(T(tdata, u) - g) > gamma1 ? dot(T(tdata, u) - g, T(tdata, du)) / norm(T(tdata, u) - g)^2 * p1 : - zeros(size(p1)) + zero(p1) a1 = alpha1 / m1 * dot(T(tdata, du), T(tdata, phi)) - dot(cond1, T(tdata, phi)) m2 = max(gamma2, norm(nablau)) cond2 = norm(nablau) > gamma2 ? dot(nablau, nabladu) / norm(nablau)^2 * p2 : - zeros(size(p2)) + zero(p2) a2 = lambda / m2 * dot(nabladu, nablaphi) - dot(cond2, nablaphi)