diff --git a/scripts/run_experiments.jl b/scripts/run_experiments.jl index 80e0e1de7a66b16c65b9a78be28a95289850e56e..e6497a41e78a71865a6215aa388148812c4751ac 100644 --- a/scripts/run_experiments.jl +++ b/scripts/run_experiments.jl @@ -424,39 +424,42 @@ huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2 # this computes the primal-dual error indicator which is not really useful # if not computed on a finer mesh than `u` was solved on -function estimate!(ctx::L1L2TVState) +function estimate!(st::L1L2TVState) function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata) alpha1part = - ctx.alpha1 * huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) - - dot(ctx.T(tdata, u) - g, p1) + - (iszero(ctx.alpha1) ? 0. : - ctx.gamma1 / (2 * ctx.alpha1) * norm(p1)^2) + st.alpha1 * huber(norm(st.T(tdata, u) - g), st.gamma1) - + dot(st.T(tdata, u) - g, p1) + + (iszero(st.alpha1) ? 0. : + st.gamma1 / (2 * st.alpha1) * norm(p1)^2) lambdapart = - ctx.lambda * huber(norm(nablau), ctx.gamma2) - + st.lambda * huber(norm(nablau), st.gamma2) - dot(nablau, p2) + - (iszero(ctx.lambda) ? 0. : - ctx.gamma2 / (2 * ctx.lambda) * norm(p2)^2) + (iszero(st.lambda) ? 0. : + st.gamma2 / (2 * st.lambda) * norm(p2)^2) # avoid non-negative rounding errors alpha1part = max(0, alpha1part) lambdapart = max(0, lambdapart) bpart = 1 / 2 * ( - ctx.alpha2 * dot(ctx.T(tdata, w - u), ctx.T(tdata, w - u)) + - ctx.beta * dot(ctx.S(w, nablaw) - ctx.S(u, nablau), - ctx.S(w, nablaw) - ctx.S(u, nablau))) - - return alpha1part + lambdapart + bpart + st.alpha2 * dot(st.T(tdata, w - u), st.T(tdata, w - u)) + + st.beta * dot(st.S(w, nablaw) - st.S(u, nablau), + st.S(w, nablaw) - st.S(u, nablau))) + res = alpha1part + lambdapart + bpart + @assert isfinite(res) + return res end - w = FeFunction(ctx.u.space) - solve_primal!(w, ctx) + w = FeFunction(st.u.space) + solve_primal!(w, st) #w.data .= .-w.data # TODO: find better name: is actually a cell-wise integration - project!(ctx.est, estf; ctx.g, ctx.u, ctx.p1, ctx.p2, - nablau = nabla(ctx.u), w, nablaw = nabla(w), ctx.tdata) + project!(st.est, estf; st.g, st.u, st.p1, st.p2, + nablau = nabla(st.u), w, nablaw = nabla(w), st.tdata) + + st.est.data .= sqrt.(st.est.data) end estimate_error(st::L1L2TVState) = - sum(st.est.data) / area(st.mesh) + sqrt(sum(x -> x^2, st.est.data) / area(st.mesh)) # minimal Dörfler marking function mark(ctx::L1L2TVState; theta=0.5)