diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl index af27db1891d5deb99efbb610f5be9bc69e3e05d3..4fcafbd60c766079504ab15c871ce4fed0cd361d 100644 --- a/src/DualTVDD.jl +++ b/src/DualTVDD.jl @@ -71,13 +71,13 @@ end function alg_error(alg, pmin, niter, ninner=1) print("run ...") res = Float64[] - (p, ctx) = iterate(alg) - push!(res, error(ctx.p, pmin, ctx.algorithm.problem)) + ctx = init(alg) + push!(res, error(fetch(ctx), pmin, ctx.algorithm.problem)) for i in 1:niter for j in 1:ninner - (p, ctx) = iterate(alg, ctx) + step!(ctx) end - push!(res, error(ctx.p, pmin, ctx.algorithm.problem)) + push!(res, error(fetch(ctx), pmin, ctx.algorithm.problem)) end println(" finished") return res @@ -87,9 +87,9 @@ function calc_energy(prob, niter) alg_ref = DualTVDD.ChambolleAlgorithm(prob) ctx = init(alg_ref) for i in 1:niter - (p, ctx) = iterate(alg_ref, ctx) + ctx = step!(ctx) end - return energy(ctx), p + return energy(ctx), fetch(ctx) end function rundd() @@ -132,8 +132,8 @@ function rundd() lognan.(alg_error(alg_dd2, pmin, n)), ] - plot(y, xaxis=:log, yaxis=:log) - + plt = plot(y, xaxis=:log, yaxis=:log) + display(plt) #display(energy(ctx)) #display(ctx.p) diff --git a/src/chambolle.jl b/src/chambolle.jl index 0a183cd803de5f663db841001fdb7c40c0ece8b9..148937967f4d2b48efa56b6eb5872561470cef3c 100644 --- a/src/chambolle.jl +++ b/src/chambolle.jl @@ -90,8 +90,7 @@ function step!(ctx::ChambolleState) return ctx end -fetch(ctx::ChambolleState) = - recover_u(ctx.p, ctx.algorithm.problem) +fetch(ctx::ChambolleState) = ctx.p function recover_u(p, md::DualTVL1ROFOpProblem) d = ndims(md.g) @@ -102,7 +101,7 @@ function recover_u(p, md::DualTVL1ROFOpProblem) kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ) # v = div(p) + A'*f - map!(kΛ, v, extend(p, StaticKernels.ExtensionNothing())) # extension: nothing + map!(kΛ, v, StaticKernels.extend(p, StaticKernels.ExtensionNothing())) # extension: nothing v .+= md.g # u = B * v mul!(vec(u), md.B, vec(v)) diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl index b54df22cb5bd48975da28f4d564730ad32137783..f3925fff20f64b6949777a759208b2eebc6640a5 100644 --- a/src/dualtvdd.jl +++ b/src/dualtvdd.jl @@ -29,7 +29,7 @@ struct DualTVDDState{A,d,V,SV,SAx,SC} subctx::Array{SC,d} end -function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}) +function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}) g = alg.problem.g d = ndims(g) ax = axes(g) @@ -56,9 +56,9 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}) #subalg = [ProjGradAlgorithm(subprobs[i]) for i in CartesianIndices(subax)] subalg = [ChambolleAlgorithm(subprobs[i]) for i in CartesianIndices(subax)] - subctx = [iterate(x)[2] for x in subalg] + subctx = [init(x) for x in subalg] - return p, DualTVDDState(alg, p, q, subax, subctx) + return DualTVDDState(alg, p, q, subax, subctx) end function intersectin(a, b) @@ -68,7 +68,8 @@ function intersectin(a, b) return (c, c .- az, c .- bz) end -function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx) +function step!(ctx::DualTVDDState) + alg = ctx.algorithm # σ = 1 takes care of sequential updates σ = alg.parallel ? ctx.algorithm.σ : 1. d = ndims(ctx.p) @@ -103,7 +104,7 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx) map!(kΛ, sg, sq) for j in 1:alg.ninner - (_, ctx.subctx[i]) = iterate(ctx.subctx[i].algorithm, ctx.subctx[i]) + step!(ctx.subctx[i]) end end @@ -113,9 +114,11 @@ function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx) ctx.p[sax...] .+= σ .* ctx.subctx[i].p end - return ctx.p, ctx + return ctx end +fetch(ctx::DualTVDDState) = ctx.p + @generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T} i0 = ntuple(_->0, N) i1(k) = ntuple(i->Int(k==i), N) diff --git a/src/projgrad.jl b/src/projgrad.jl index 6bbc2ed80a07e34894c37c88c8813852798516e2..e45080d956906ebc8bbe8db57fa5af6f608f9321 100644 --- a/src/projgrad.jl +++ b/src/projgrad.jl @@ -70,4 +70,4 @@ function step!(ctx::ProjGradState) return ctx end -fetch(ctx::ProjGradState) = recover_u(ctx.p, ctx.algorithm.problem) +fetch(ctx::ProjGradState) = ctx.p diff --git a/test/runtests.jl b/test/runtests.jl index fa69c23dd3914e7c776f63fd9ff2e1f315db00fa..e2debb324629c3ee87f20ee63d8d7e694bf11aa9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,17 +2,35 @@ using Test, BenchmarkTools using LinearAlgebra using DualTVDD: DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm, - init, step!, fetch + init, step!, fetch, recover_u -g = Float64[0 2; 1 0] -prob = DualTVL1ROFOpProblem(g, I, 1e-10) +@testset "B = I" begin + g = Float64[0 2; 1 0] + prob = DualTVL1ROFOpProblem(g, I, 1e-10) -@testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob)) - ctx = init(alg) - @test 0 == @ballocated step!($ctx) - for i in 1:100 - step!(ctx) + @testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob)) + ctx = init(alg) + @test 0 == @ballocated step!($ctx) + for i in 1:100 + step!(ctx) + end + u = recover_u(fetch(ctx), ctx.algorithm.problem) + @test u ≈ g + end +end + +@testset "B = rand(...)" begin + g = Float64[0 2; 1 0] + B = rand(length(g), length(g)) + prob = DualTVL1ROFOpProblem(g, B, 1e-10) + + @testset for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob)) + ctx = init(alg) + @test 0 == @ballocated step!($ctx) + for i in 1:100 + step!(ctx) + end + u = recover_u(fetch(ctx), ctx.algorithm.problem) + @test vec(u) ≈ B * vec(g) end - u = fetch(ctx) - @test u ≈ g end