Skip to content
Snippets Groups Projects
Commit da8b54db authored by Stephan Hilb's avatar Stephan Hilb
Browse files

fix interface issues and tests

parent 3e3a4ceb
Branches
Tags
No related merge requests found
......@@ -71,13 +71,13 @@ end
function alg_error(alg, pmin, niter, ninner=1)
print("run ...")
res = Float64[]
(p, ctx) = iterate(alg)
push!(res, error(ctx.p, pmin, ctx.algorithm.problem))
ctx = init(alg)
push!(res, error(fetch(ctx), pmin, ctx.algorithm.problem))
for i in 1:niter
for j in 1:ninner
(p, ctx) = iterate(alg, ctx)
step!(ctx)
end
push!(res, error(ctx.p, pmin, ctx.algorithm.problem))
push!(res, error(fetch(ctx), pmin, ctx.algorithm.problem))
end
println(" finished")
return res
......@@ -87,9 +87,9 @@ function calc_energy(prob, niter)
alg_ref = DualTVDD.ChambolleAlgorithm(prob)
ctx = init(alg_ref)
for i in 1:niter
(p, ctx) = iterate(alg_ref, ctx)
ctx = step!(ctx)
end
return energy(ctx), p
return energy(ctx), fetch(ctx)
end
function rundd()
......@@ -132,8 +132,8 @@ function rundd()
lognan.(alg_error(alg_dd2, pmin, n)),
]
plot(y, xaxis=:log, yaxis=:log)
plt = plot(y, xaxis=:log, yaxis=:log)
display(plt)
#display(energy(ctx))
#display(ctx.p)
......
......@@ -90,8 +90,7 @@ function step!(ctx::ChambolleState)
return ctx
end
fetch(ctx::ChambolleState) =
recover_u(ctx.p, ctx.algorithm.problem)
fetch(ctx::ChambolleState) = ctx.p
function recover_u(p, md::DualTVL1ROFOpProblem)
d = ndims(md.g)
......@@ -102,7 +101,7 @@ function recover_u(p, md::DualTVL1ROFOpProblem)
= Kernel{ntuple(_->-1:1, d)}(kfΛ)
# v = div(p) + A'*f
map!(, v, extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
map!(, v, StaticKernels.extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
v .+= md.g
# u = B * v
mul!(vec(u), md.B, vec(v))
......
......@@ -29,7 +29,7 @@ struct DualTVDDState{A,d,V,SV,SAx,SC}
subctx::Array{SC,d}
end
function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
g = alg.problem.g
d = ndims(g)
ax = axes(g)
......@@ -56,9 +56,9 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
#subalg = [ProjGradAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
subalg = [ChambolleAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
subctx = [iterate(x)[2] for x in subalg]
subctx = [init(x) for x in subalg]
return p, DualTVDDState(alg, p, q, subax, subctx)
return DualTVDDState(alg, p, q, subax, subctx)
end
function intersectin(a, b)
......@@ -68,7 +68,8 @@ function intersectin(a, b)
return (c, c .- az, c .- bz)
end
function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
function step!(ctx::DualTVDDState)
alg = ctx.algorithm
# σ = 1 takes care of sequential updates
σ = alg.parallel ? ctx.algorithm.σ : 1.
d = ndims(ctx.p)
......@@ -103,7 +104,7 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
map!(, sg, sq)
for j in 1:alg.ninner
(_, ctx.subctx[i]) = iterate(ctx.subctx[i].algorithm, ctx.subctx[i])
step!(ctx.subctx[i])
end
end
......@@ -113,9 +114,11 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
ctx.p[sax...] .+= σ .* ctx.subctx[i].p
end
return ctx.p, ctx
return ctx
end
fetch(ctx::DualTVDDState) = ctx.p
@generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
i0 = ntuple(_->0, N)
i1(k) = ntuple(i->Int(k==i), N)
......
......@@ -70,4 +70,4 @@ function step!(ctx::ProjGradState)
return ctx
end
fetch(ctx::ProjGradState) = recover_u(ctx.p, ctx.algorithm.problem)
fetch(ctx::ProjGradState) = ctx.p
......@@ -2,17 +2,35 @@ using Test, BenchmarkTools
using LinearAlgebra
using DualTVDD:
DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm,
init, step!, fetch
init, step!, fetch, recover_u
g = Float64[0 2; 1 0]
prob = DualTVL1ROFOpProblem(g, I, 1e-10)
@testset "B = I" begin
g = Float64[0 2; 1 0]
prob = DualTVL1ROFOpProblem(g, I, 1e-10)
@testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
ctx = init(alg)
@test 0 == @ballocated step!($ctx)
for i in 1:100
step!(ctx)
@testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
ctx = init(alg)
@test 0 == @ballocated step!($ctx)
for i in 1:100
step!(ctx)
end
u = recover_u(fetch(ctx), ctx.algorithm.problem)
@test u g
end
end
@testset "B = rand(...)" begin
g = Float64[0 2; 1 0]
B = rand(length(g), length(g))
prob = DualTVL1ROFOpProblem(g, B, 1e-10)
@testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
ctx = init(alg)
@test 0 == @ballocated step!($ctx)
for i in 1:100
step!(ctx)
end
u = recover_u(fetch(ctx), ctx.algorithm.problem)
@test vec(u) B * vec(g)
end
u = fetch(ctx)
@test u g
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment