Skip to content
Snippets Groups Projects
Commit 7daf94ff authored by Stephan Hilb's avatar Stephan Hilb
Browse files

fix energy calculation

parent 754a746b
No related branches found
No related tags found
No related merge requests found
*.vtu *.vtu
*.pvd *.pvd
data/*
...@@ -21,6 +21,12 @@ git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da" ...@@ -21,6 +21,12 @@ git-tree-sha1 = "ded953804d019afa9a3f98981d99b33e3db7b6da"
uuid = "944b1d66-785c-5afd-91f1-9de20f533193" uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
version = "0.7.0" version = "0.7.0"
[[ColorTypes]]
deps = ["FixedPointNumbers", "Random"]
git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597"
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
version = "0.11.0"
[[CommonSubexpressions]] [[CommonSubexpressions]]
deps = ["MacroTools", "Test"] deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
...@@ -71,12 +77,24 @@ version = "0.8.5" ...@@ -71,12 +77,24 @@ version = "0.8.5"
deps = ["ArgTools", "LibCURL", "NetworkOptions"] deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
[[FileIO]]
deps = ["Pkg", "Requires", "UUIDs"]
git-tree-sha1 = "256d8e6188f3f1ebfa1a5d17e072a0efafa8c5bf"
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
version = "1.10.1"
[[FillArrays]] [[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"] deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a" git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.7" version = "0.11.7"
[[FixedPointNumbers]]
deps = ["Statistics"]
git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc"
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.4"
[[ForwardDiff]] [[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6" git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
...@@ -193,6 +211,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" ...@@ -193,6 +211,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"] deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Requires]]
deps = ["UUIDs"]
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.1.3"
[[SHA]] [[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
......
...@@ -4,6 +4,8 @@ authors = ["Stephan Hilb <stephan@ecshi.net>"] ...@@ -4,6 +4,8 @@ authors = ["Stephan Hilb <stephan@ecshi.net>"]
version = "0.1.0" version = "0.1.0"
[deps] [deps]
ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
......
...@@ -202,7 +202,7 @@ function integrate(mesh::Mesh, expr; params...) ...@@ -202,7 +202,7 @@ function integrate(mesh::Mesh, expr; params...)
for k in axes(qx, 2) for k in axes(qx, 2)
xhat = qx[:, k] xhat = qx[:, k]
x = elmap(mesh, cell)(xhat) x = elmap(mesh, cell)(xhat)
opvalues = map(f -> evaluate(f, x), opparams) opvalues = map(f -> evaluate(f, xhat), opparams)
val += qw[k] * expr(x; opvalues...) * intel val += qw[k] * expr(x; opvalues...) * intel
end end
...@@ -211,6 +211,7 @@ function integrate(mesh::Mesh, expr; params...) ...@@ -211,6 +211,7 @@ function integrate(mesh::Mesh, expr; params...)
end end
function bind!(f::FeFunction, cell) function bind!(f::FeFunction, cell)
# TODO: make this non-allocating
f.ldata .= vec(f.data[f.space.dofmap[:, :, cell]]) f.ldata .= vec(f.data[f.space.dofmap[:, :, cell]])
return f return f
end end
......
export myrun, denoise, inpaint, optflow, solve_primal!, estimate! export myrun, denoise, inpaint, optflow, solve_primal!, estimate!, loadimg, saveimg
using LinearAlgebra: norm using LinearAlgebra: norm
# avoid world-age-issues by preloading ColorTypes
import ColorTypes
import FileIO
struct L1L2TVContext{M, Ttype, Stype} struct L1L2TVContext{M, Ttype, Stype}
name::String name::String
...@@ -179,9 +183,9 @@ function solve_primal!(u::FeFunction, ctx::L1L2TVContext) ...@@ -179,9 +183,9 @@ function solve_primal!(u::FeFunction, ctx::L1L2TVContext)
u.data .= A \ b u.data .= A \ b
end end
function estimate!(ctx::L1L2TVContext) huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
huber(x, gamma) = abs(x) < gamma ? x^2 / (2 * gamma) : abs(x) - gamma / 2
function estimate!(ctx::L1L2TVContext)
function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata) function estf(x_; g, u, p1, p2, nablau, w, nablaw, tdata)
alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * ( alpha1part = iszero(ctx.alpha1) ? 0. : ctx.alpha1 * (
huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) + huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
...@@ -252,7 +256,7 @@ function mark(ctx::L1L2TVContext; theta=0.5) ...@@ -252,7 +256,7 @@ function mark(ctx::L1L2TVContext; theta=0.5)
end end
function save(ctx::L1L2TVContext, filename, fs...) function output(ctx::L1L2TVContext, filename, fs...)
print("save \"$filename\" ... ") print("save \"$filename\" ... ")
vtk = vtk_mesh(filename, ctx.mesh) vtk = vtk_mesh(filename, ctx.mesh)
vtk_append!(vtk, fs...) vtk_append!(vtk, fs...)
...@@ -260,67 +264,18 @@ function save(ctx::L1L2TVContext, filename, fs...) ...@@ -260,67 +264,18 @@ function save(ctx::L1L2TVContext, filename, fs...)
return vtk return vtk
end end
function test_refine(mesh = nothing) toimg(arr) = Gray.(clamp.(arr, 0., 1.))
m = 1 loadimg(x) = reverse(transpose(Float64.(FileIO.load(x))); dims = 2)
saveimg(io, x) = FileIO.save(io, toimg(x))
if isnothing(mesh) function primal_energy(ctx::L1L2TVContext)
mesh = init_grid(2, 2) function integrand(x; g, u, nablau, tdata)
return ctx.alpha1 * huber(norm(ctx.T(tdata, u) - g), ctx.gamma1) +
ctx.alpha2 / 2 * norm(ctx.T(tdata, u) - g)^2 +
ctx.beta / 2 * norm(ctx.S(u, nablau))^2 +
ctx.lambda * huber(norm(nablau), ctx.gamma2)
end end
return integrate(ctx.mesh, integrand; ctx.g, ctx.u, ctx.nablau, ctx.tdata)
T(tdata, u) = u
S(u, nablau) = u
ctx = L1L2TVContext("test", mesh, m; T, tdata = nothing, S,
alpha1=0., alpha2=5., lambda=0.1, beta=1e-3, gamma1=1e-3, gamma2=1e-3)
interpolate!(ctx.g, x -> x[1] + x[2] > 1.01)
mysave(ctx, i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
pvd = paraview_collection("$(ctx.name).pvd")
for i in 1:10
step!(ctx)
end
pvd[0] = mysave(ctx, 0)
ctx2 = refine(ctx, [1, 2, 3, 4, 5, 6, 7, 8])
#interpolate!(ctx2.g, x -> x[1] + x[2] < 2.01)
#interpolate!(ctx2.g, x -> norm(x .- m) < norm(m) / 3)
#ctx2.est.data .= 0
#ctx2.g.data .= 0
#ctx2.u.data .= 0
#ctx2.p1.data .= 0
#ctx2.p2.data .= 0
#ctx2.du.data .= 0
#ctx2.dp1.data .= 0
#ctx2.dp2.data .= 0
for i in 1:10
step!(ctx2)
end
pvd[1] = mysave(ctx2, 1)
test_mesh(ctx2.mesh)
return ctx, ctx2
for i in 1:5
step!(ctx)
estimate!(ctx)
pvd[i] = mysave(i)
println()
end
estmean = mean(ctx.est.data)
marked_cells = findall(>(2*estmean), ctx.est.data)
println(marked_cells)
ctx = refine(ctx, marked_cells)
interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
for i in 6:11
step!(ctx)
estimate!(ctx)
pvd[i] = mysave(i)
println()
end
return ctx
end end
norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f)) norm_l2(f) = sqrt(integrate(f.space.mesh, (x; f) -> dot(f, f); f))
...@@ -336,16 +291,17 @@ function denoise(img; name, params...) ...@@ -336,16 +291,17 @@ function denoise(img; name, params...)
interpolate!(ctx.g, x -> interpolate_bilinear(img, x)) interpolate!(ctx.g, x -> interpolate_bilinear(img, x))
m = (size(img) .- 1) ./ 2 .+ 1 m = (size(img) .- 1) ./ 2 .+ 1
interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
save_denoise(ctx, i) = save_denoise(ctx, i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est) ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est)
pvd = paraview_collection("$(ctx.name).pvd") pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_denoise(ctx, 0) pvd[0] = save_denoise(ctx, 0)
k = 0 k = 0
println("primal energy: $(primal_energy(ctx))")
while true while true
while true while true
k += 1 k += 1
...@@ -354,6 +310,7 @@ function denoise(img; name, params...) ...@@ -354,6 +310,7 @@ function denoise(img; name, params...)
pvd[k] = save_denoise(ctx, k) pvd[k] = save_denoise(ctx, k)
println() println()
println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))") println("ndofs: $(ndofs(ctx.u.space)), est: $(norm_l2(ctx.est)))")
println("primal energy: $(primal_energy(ctx))")
norm_step = sqrt(norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2)) norm_step = sqrt(norm_l2(ctx.du)^2 + norm_l2(ctx.dp1)^2 + norm_l2(ctx.dp2))
norm_step <= 1e-3 && break norm_step <= 1e-3 && break
...@@ -363,7 +320,7 @@ function denoise(img; name, params...) ...@@ -363,7 +320,7 @@ function denoise(img; name, params...)
println("refining ...") println("refining ...")
ctx = refine(ctx, marked_cells) ctx = refine(ctx, marked_cells)
test_mesh(ctx.mesh) test_mesh(ctx.mesh)
interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3) #interpolate!(ctx.g, x -> norm(x .- m) < norm(m .- 1) / 3)
k >= 200 && break k >= 200 && break
end end
vtk_save(pvd) vtk_save(pvd)
...@@ -394,10 +351,10 @@ function inpaint(img, imgmask; name, params...) ...@@ -394,10 +351,10 @@ function inpaint(img, imgmask; name, params...)
interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3) interpolate!(ctx.g, x -> norm(x .- m) < norm(m) / 3)
save_inpaint(i) = save_inpaint(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask) ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, mask)
pvd = paraview_collection("$(ctx.name).pvd") pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_inpaint(0) pvd[0] = save_inpaint(0)
for i in 1:3 for i in 1:3
step!(ctx) step!(ctx)
...@@ -437,10 +394,10 @@ function optflow(imgf0, imgf1; name, params...) ...@@ -437,10 +394,10 @@ function optflow(imgf0, imgf1; name, params...)
interpolate!(ctx.g, g_optflow; ctx.u, f0, fw, nablafw) interpolate!(ctx.g, g_optflow; ctx.u, f0, fw, nablafw)
save_optflow(i) = save_optflow(i) =
save(ctx, "$(ctx.name)_$(lpad(i, 5, '0')).vtu", output(ctx, "output/$(ctx.name)_$(lpad(i, 5, '0')).vtu",
ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, f0, f1, fw) ctx.g, ctx.u, ctx.p1, ctx.p2, ctx.est, f0, f1, fw)
pvd = paraview_collection("$(ctx.name).pvd") pvd = paraview_collection("output/$(ctx.name).pvd")
pvd[0] = save_optflow(0) pvd[0] = save_optflow(0)
for i in 1:10 for i in 1:10
step!(ctx) step!(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment