Skip to content
Snippets Groups Projects
Commit 7daf94ff authored by Stephan Hilb's avatar Stephan Hilb
Browse files

fix energy calculation

parent 754a746b
No related branches found
No related tags found
No related merge requests found
*.vtu
*.pvd
data/*
......@@ -21,6 +21,12 @@ git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.0"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.11.0"
[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
......@@ -71,12 +77,24 @@ version = "0.8.5"
deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
[[FileIO]]
deps = ["Pkg", "Requires", "UUIDs"]
git-tree-sha1 = "256d8e6188f3f1ebfa1a5d17e072a0efafa8c5bf"
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
version = "1.10.1"
[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.7"
[[FixedPointNumbers]]
deps = ["Statistics"]
git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.4"
[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
......@@ -193,6 +211,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
......
......@@ -4,6 +4,8 @@ authors = ["Stephan Hilb <stephan@ecshi.net>"]
version = "0.1.0"
[deps]
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
......
......@@ -202,7 +202,7 @@ function integrate(mesh::Mesh, expr; params...)
for k in axes(qx, 2)
xhat = qx[:, k]
x = elmap(mesh, cell)(xhat)
opvalues = map(f -> evaluate(f, x), opparams)
opvalues = map(f -> evaluate(f, xhat), opparams)
val += qw[k] * expr(x; opvalues...) * intel
end
......@@ -211,6 +211,7 @@ function integrate(mesh::Mesh, expr; params...)
end
function bind!(f::FeFunction, cell)
# TODO: make this non-allocating
f.ldata .= vec(f.data[f.space.dofmap[:, :, cell]])
return f
end
......
export myrun, denoise, inpaint, optflow, solve_primal!, estimate!
export myrun, denoise, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg
using LinearAlgebra: norm
# avoid world-age-issues by preloading ColorTypes
import ColorTypes
import FileIO
struct L1L2TVContext{M, Ttype, Stype}
name::String
......@@ -179,9 +183,9 @@ function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
u.data .= A \ b
end
function estimate!(ctx::L1L2TVContext)
huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
function estimate!(ctx::L1L2TVContext)
function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * (
huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
......@@ -252,7 +256,7 @@ function mark(ctx::L1L2TVContext; theta=0.5)
end
function save(ctx::L1L2TVContext, filename, fs...)
function output(ctx::L1L2TVContext, filename, fs...)
print("save \"$filename\" ... ")
vtk = vtk_mesh(filename, ctx.mesh)
vtk_append!(vtk, fs...)
......@@ -260,67 +264,18 @@ function save(ctx::L1L2TVContext, filename, fs...)
return vtk
end
function test_refine(mesh = nothing)
m = 1
if isnothing(mesh)
mesh = init_grid(2, 2)
end
T(tdata, u) = u
S(u, nablau) = u
ctx = L1L2TVContext("test", mesh, m; T, tdata = nothing, S,
alpha1=0., alpha2=5., lambda=0.1, beta=1e-3, gamma1=1e-3, gamma2=1e-3)
interpolate!(ctx.g, x -> x[1] + x[2] > 1.01)
mysave(ctx, i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
pvd = paraview_collection("$(ctx.name).pvd")
for i in 1:10
step!(ctx)
end
pvd[0] = mysave(ctx, 0)
ctx2 = refine(ctx, [1, 2, 3, 4, 5, 6, 7, 8])
#interpolate!(ctx2.g, x -> x[1] + x[2] < 2.01)
#interpolate!(ctx2.g, x -> norm(x .- m) < norm(m) / 3)
#ctx2.est.data .= 0
#ctx2.g.data .= 0
#ctx2.u.data .= 0
#ctx2.p1.data .= 0
#ctx2.p2.data .= 0
#ctx2.du.data .= 0
#ctx2.dp1.data .= 0
#ctx2.dp2.data .= 0
for i in 1:10
step!(ctx2)
end
pvd[1] = mysave(ctx2, 1)
test_mesh(ctx2.mesh)
toimg(arr) = Gray.(clamp.(arr, 0., 1.))
loadimg(x) = reverse(transpose(Float64.(FileIO.load(x))); dims = 2)
saveimg(io, x) = FileIO.save(io, toimg(x))
return ctx, ctx2
for i in 1:5
step!(ctx)
estimate!(ctx)
pvd[i] = mysave(i)
println()
function primal_energy(ctx::L1L2TVContext)
function integrand(x; g, u, nablau, tdata)
return ctx.alpha1 * huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
ctx.alpha2 / 2 * norm(ctx.T(tdata, u) - g)^2 +
ctx.beta / 2 * norm(ctx.S(u, nablau))^2 +
ctx.lambda * huber(norm(nablau), ctx.gamma2)
end
estmean = mean(ctx.est.data)
marked_cells = findall(>(2*estmean), ctx.est.data)
println(marked_cells)
ctx = refine(ctx, marked_cells)
interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
for i in 6:11
step!(ctx)
estimate!(ctx)
pvd[i] = mysave(i)
println()
end
return ctx
return integrate(ctx.mesh, integrand; ctx.g, ctx.u, ctx.nablau, ctx.tdata)
end
norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
......@@ -336,16 +291,17 @@ function denoise(img; name, params...)
interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
m = (size(img) .- 1) ./ 2 .+ 1
interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
save_denoise(ctx, i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
pvd = paraview_collection("$(ctx.name).pvd")
pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_denoise(ctx, 0)
k = 0
println("primal energy: $(primal_energy(ctx))")
while true
while true
k += 1
......@@ -354,6 +310,7 @@ function denoise(img; name, params...)
pvd[k] = save_denoise(ctx, k)
println()
println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
println("primal energy: $(primal_energy(ctx))")
norm_step = sqrt(norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2))
norm_step <= 1e-3 && break
......@@ -363,7 +320,7 @@ function denoise(img; name, params...)
println("refining ...")
ctx = refine(ctx, marked_cells)
test_mesh(ctx.mesh)
interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
#interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
k >= 200 && break
end
vtk_save(pvd)
......@@ -394,10 +351,10 @@ function inpaint(img, imgmask; name, params...)
interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
save_inpaint(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask)
pvd = paraview_collection("$(ctx.name).pvd")
pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_inpaint(0)
for i in 1:3
step!(ctx)
......@@ -437,10 +394,10 @@ function optflow(imgf0, imgf1; name, params...)
interpolate!(ctx.g, g_optflow; ctx.u, f0, fw, nablafw)
save_optflow(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, f0, f1, fw)
pvd = paraview_collection("$(ctx.name).pvd")
pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_optflow(0)
for i in 1:10
step!(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment