From 7d9307a16dc775fe69ccb34c0abcbfd181f3f301 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Tue, 7 Jul 2020 15:37:04 +0200
Subject: [PATCH] move to new problem interface

---
 src/DualTVDD.jl  |  92 ++++++++++++----------------------
 src/chambolle.jl |  76 +++++++++++++++-------------
 src/common.jl    |  78 ++++++++++++++++++++++++-----
 src/dualtvdd.jl  | 128 +++++++++++++++++++++--------------------------
 src/models.jl    |  31 ------------
 src/problems.jl  |  36 +++++++++++++
 src/projgrad.jl  |  41 ++++++++-------
 test/runtests.jl |  13 ++---
 8 files changed, 263 insertions(+), 232 deletions(-)
 delete mode 100644 src/models.jl
 create mode 100644 src/problems.jl

diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index ab2cac4..cdbf66d 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -3,7 +3,7 @@ module DualTVDD
 
 include("common.jl")
-include("models.jl")
+include("problems.jl")
 include("chambolle.jl")
 include("dualtvdd.jl")
 include("projgrad.jl")
@@ -25,7 +25,7 @@ function run()
 
     display(norm(B))
 
-    md = DualTVDD.DualTVL1ROFOpModel(g, B, λ)
+    md = DualTVDD.DualTVL1ROFOpProblem(g, B, λ)
     ctx = DualTVDD.init(md, DualTVDD.ChambolleAlgorithm())
     ctx2 = DualTVDD.init(md, DualTVDD.ProjGradAlgorithm(τ=1/sqrt(8)/norm(B)))
 
@@ -59,73 +59,45 @@ function run()
 end
 
 function rundd()
+    λ = 1/2
     β = 0
-    f = zeros(4,4)
-    f[1,1] = 1
+    f = zeros(10,10)
+    f[1] = 1
     #f = [0. 2; 1 0.]
 
-    #A = diagm(vcat(fill(1, length(f)÷2), fill(1/10, length(f)÷2)))
-    #A = rand(length(f), length(f))
     A = 0. * rand(length(f), length(f))
     A .+= diagm(ones(length(f)))
-
-    #display(A)
 
     B = inv(A'*A + β*I)
-    println(norm(A))
-
-    #println(norm(sqrt(B)))
 
     g = similar(f)
     vec(g) .= A' * vec(f)
 
-    λ = 1/4
+    prob = DualTVDD.DualTVL1ROFOpProblem(g, B, λ)
+    alg_dd = DualTVDD.DualTVDDAlgorithm(prob, M=(2,2), overlap=(2,2), σ=0.25)
+    alg_ref = DualTVDD.ChambolleAlgorithm(prob)
 
-    md = DualTVDD.DualTVL1ROFOpModel(f, A, λ, 0., 0.)
-    alg = DualTVDD.DualTVDDAlgorithm(M=(2,2), overlap=(2,2), σ=0.25)
-    ctx = DualTVDD.init(md, alg)
-
-    md2 = DualTVDD.DualTVL1ROFOpModel(g, B, λ)
-    alg2 = DualTVDD.ChambolleAlgorithm()
-    ctx2 = DualTVDD.init(md2, alg2)
+    (p, ctx) = iterate(alg_ref)
+    for i in 1:100000
+        (p, ctx) = iterate(alg_ref, ctx)
+    end
+    display(ctx.p)
+    display(recover_u(p, prob))
 
+    (p, ctx) = iterate(alg_dd)
     for i in 1:1000
-        step!(ctx)
-        #println(energy(ctx))
+        (p, ctx) = iterate(alg_dd, ctx)
     end
-    for i in 1:10000
-        step!(ctx2)
-        #println(energy(ctx2))
-    end
-
-    #scene = heatmap(ctx.s,
-    #    colorrange=(0,1), colormap=:gray, scale_plot=false)
-
-    #display(scene)
-
-    #hm = last(scene)
-    #for i in 1:100
-    #    step!(ctx)
-    #    hm[1] = ctx.s
-    #    yield()
-    #    #sleep(0.2)
-    #end
-    #hm[1] = ctx.s
-    #yield()
-
-    println("\np result")
     display(ctx.p)
-    display(ctx2.p)
-
-    println("\nu result")
-    display(recover_u!(ctx))
-    display(recover_u!(ctx2))
-
-    println(energy(ctx))
-    println(energy(ctx2))
-
-    ctx, ctx2
+    #display(ctx.subctx[1].algorithm.problem.g)
+    #display(ctx.subctx[2].algorithm.problem.g)
+    display(ctx.subctx[1].p)
+    display(ctx.subctx[2].p)
+    display(recover_u(p, prob))
+
+    #println(energy(ctx))
+    #println(energy(ctx2))
 end
 
 function run3()
@@ -142,12 +114,12 @@ function run3()
     λ = 0.1
 
     # Chambolle
-    md = DualTVDD.DualTVL1ROFOpModel(g, B, λ)
+    md = DualTVDD.DualTVL1ROFOpProblem(g, B, λ)
     alg = DualTVDD.ChambolleAlgorithm()
     ctx = DualTVDD.init(md, alg)
 
     # Projected Gradient
-    md = DualTVDD.DualTVL1ROFOpModel(f, A, λ, 0., 0.)
+    md = DualTVDD.DualTVL1ROFOpProblem(f, A, λ, 0., 0.)
     alg = DualTVDD.ProjGradAlgorithm(λ = 1/norm(A)^2)
     ctx2 = DualTVDD.init(md, alg)
@@ -168,7 +140,7 @@ function run3()
 
 end
 
-function energy(ctx::Union{DualTVDDContext,ProjGradContext})
+function energy(ctx::Union{DualTVDDState,ProjGradState})
     d = ndims(ctx.p)
 
     @inline kfΛ(w) = @inbounds divergence(w)
@@ -186,26 +158,26 @@ function run3()
     return sum(u .* vec(v)) / 2
 end
 
-function energy(ctx::ChambolleContext)
+function energy(ctx::ChambolleState)
     d = ndims(ctx.p)
 
     @inline kfΛ(w) = @inbounds divergence(w)
     kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
-    v = similar(ctx.model.g)
+    v = similar(ctx.problem.g)
     # v = div(p) + g
     map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
-    v .+= ctx.model.g
+    v .+= ctx.problem.g
 
     #display(v)
 
     # |v|_B^2 / 2
-    u = ctx.model.B * vec(v)
+    u = ctx.problem.B * vec(v)
 
     return sum(u .* vec(v)) / 2
 end
 
-function energy(md::DualTVL1ROFOpModel, u::AbstractMatrix)
+function energy(md::DualTVL1ROFOpProblem, u::AbstractMatrix)
     @inline kf(w) = @inbounds 1/2 * (w[0,0] - md.g[w.position])^2 + md.λ * sqrt((w[1,0] - w[0,0])^2 + (w[0,1] - w[0,0])^2)
     k = Kernel{(0:1, 0:1)}(kf, StaticKernels.ExtensionReplicate())
diff --git a/src/chambolle.jl b/src/chambolle.jl
index 5d5b446..0edc303 100644
--- a/src/chambolle.jl
+++ b/src/chambolle.jl
@@ -13,20 +13,24 @@ using StaticKernels
 using StaticArrays: SVector
 using LinearAlgebra
 
-struct ChambolleAlgorithm <: Algorithm
+struct ChambolleAlgorithm{P} <: Algorithm{P}
+    problem::P
+
     "fixed point inertia parameter"
     τ::Float64
-    function ChambolleAlgorithm(; τ=1/8)
-        return new(τ)
+
+    function ChambolleAlgorithm(problem; τ)
+        return new{typeof(problem)}(problem, τ)
     end
 end
 
-struct ChambolleContext{M,A,T,R,S,Sv,K1,K2} <: Context
-    "model data"
-    model::M
-    "algorithm data"
+ChambolleAlgorithm(problem::DualTVL1ROFOpProblem) =
+    ChambolleAlgorithm(problem, τ=inv(8 * normB(problem)))
+
+struct ChambolleState{A,T,R,S,Sv,K1,K2} <: State
     algorithm::A
-    "matrix view on model.f"
+
+    "matrix view on .f"
    p::T
 
     "matrix view on rv"
@@ -45,59 +49,63 @@ struct ChambolleContext{M,A,T,R,S,Sv,K1,K2} <: Context
     k2::K2
 end
 
-function init(md::DualTVL1ROFOpModel, alg::ChambolleAlgorithm)
-    d = ndims(md.g)
-    pv = zeros(d * length(md.g))
-    rv = zeros(length(md.g))
+# unused projection helper, made self-contained so it runs as written
+function proj(p, sgrad, λ)
+    @inbounds iszero(λ) && return zero(p)
+    return @inbounds p ./ (1 + norm(sgrad) / λ)
+end
+
+function Base.iterate(alg::ChambolleAlgorithm{<:DualTVL1ROFOpProblem})
+    g = alg.problem.g
+    λ = alg.problem.λ
+    d = ndims(g)
+    pv = zeros(d * length(g))
+    rv = zeros(length(g))
     sv = zero(rv)
-    p = extend(reshape(reinterpret(SVector{d,Float64}, pv), size(md.g)), StaticKernels.ExtensionNothing())
-    r = reshape(reinterpret(Float64, rv), size(md.g))
-    s = extend(reshape(reinterpret(Float64, sv), size(md.g)), StaticKernels.ExtensionReplicate())
+    p = extend(reshape(reinterpret(SVector{d,Float64}, pv), size(g)), StaticKernels.ExtensionNothing())
+    r = reshape(reinterpret(Float64, rv), size(g))
+    s = extend(reshape(reinterpret(Float64, sv), size(g)), StaticKernels.ExtensionReplicate())
 
-    @inline kf1(pw) = @inbounds divergence(pw) + md.g[pw.position]
+    @inline kf1(pw) = @inbounds divergence(pw) + g[pw.position]
     k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
 
-    my_norm(B::UniformScaling{Bool}) = B.λ * sqrt(length(md.g))
-    my_norm(B) = norm(B)
-
-    normB = my_norm(md.B)
-
     @inline function kf2(sw)
-        @inbounds iszero(md.λ[sw.position]) && return zero(p[sw.position])
-        sgrad = (alg.τ / normB) * gradient(sw)
-        return @inbounds (p[sw.position] + sgrad) / (1 + norm(sgrad) / md.λ[sw.position])
+        @inbounds iszero(λ[sw.position]) && return zero(p[sw.position])
+        sgrad = alg.τ * gradient(sw)
+        return @inbounds (p[sw.position] + sgrad) / (1 + norm(sgrad) / λ[sw.position])
     end
     k2 = Kernel{ntuple(_->0:1, d)}(kf2)
 
-    return ChambolleContext(md, alg, p, r, s, rv, sv, k1, k2)
+    return (p, ChambolleState(alg, p, r, s, rv, sv, k1, k2))
 end
 
-function reset!(ctx::ChambolleContext)
+function reset!(ctx::ChambolleState)
     fill!(ctx.p, zero(eltype(ctx.p)))
 end
 
-function step!(ctx::ChambolleContext)
+function Base.iterate(alg::ChambolleAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
     # r = div(p) + g
     map!(ctx.k1, ctx.r, ctx.p)
     # s = B * r
-    mul!(ctx.sv, ctx.model.B, ctx.rv)
+    mul!(ctx.sv, alg.problem.B, ctx.rv)
     # p = (p + τ*grad(s)) / (1 + τ/λ|grad(s)|)
     map!(ctx.k2, ctx.p, ctx.s)
+
+    return (ctx.p, ctx)
 end
 
-recover_u(ctx::ChambolleContext) = recover_u(ctx.p, ctx.model)
+recover_u(ctx::ChambolleState) = recover_u(ctx.p, ctx.problem)
 
 #
-#function solve(md::ROFModel, alg::Chambolle;
+#function solve(md::ROFProblem, alg::Chambolle;
 #        maxiters = 1000,
 #        save_log = false)
 #
-#    p = zeros(ndims(md.g), size(md.g)...)
+#    p = zeros(ndims(g), size(g)...)
 #    q = zero(p)
-#    s = zero(md.g)
+#    s = zero(g)
 #
 #    div = div_op(p)
-#    grad = grad_op(md.g)
-#    ctx = ChambolleContext(p, q, s, div, grad)
+#    grad = grad_op(g)
+#    ctx = ChambolleState(p, q, s, div, grad)
 #
 #    log = Log()
 #
diff --git a/src/common.jl b/src/common.jl
index 3befb14..3643599 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -3,21 +3,75 @@ using StaticKernels
 
 # Solver Interface Notes:
 #
-# - a concrete subtype `<:Model` describes one model type and each instance
+# - a concrete subtype `<:Problem` describes one model type and each instance
 #   represents a fully specified problem, i.e. including data and model parameters
-# - a concrete subtype `<:Algorithm` describes an algorithm type, maybe
-#   applicable to different models with special implementations
-# - abstract subtypes `<:Model`, `<:Algorithm` may exist
-# - `init(::Model, ::Algorithm)::Context` initializes a runnable algorithm context
-# - `step!(::Context)::Context` performs one non-allocating step
-# - `solve(::Model, ::Algorithm)` must lead to the same deterministic
-#   outcome depending only on the arguments given.
+# - a concrete subtype `<:Algorithm` describes an algorithm, constructed from a
+#   model and algorithm parameters as keyword arguments
+# - a concrete subtype `<:Context` contains preallocated data for the
+#   algorithm
+#
+# - `init(::Problem, ::Algorithm)::Context` initializes the algorithm context
+# - `iterate!(::Context)::Union{Context,Nothing}` performs one non-allocating step
+#
+# - iterating over the same context must lead to the same deterministic
+#   sequence of iterates
 # - `run(::Context)` must continue on from the given context
 # - `(<:Algorithm)(...)` constructors must accept keyword arguments only
+# - types `<:Context` satisfy a stateful iterator interface, returning the
+#   current context at each step of the algorithm
+
+"""
+    <:Problem
+
+The abstract type hierarchy specifies the problem model interface (problem,
+available data, available oracles).
+"""
+abstract type Problem end
+
+"""
+    <:Algorithm{<:Problem}
+    (<:Algorithm)(::Problem; params...)
+
+An algorithm represents a fully specified iterative process of usually infinite
+or a-priori unknown finite length, producing iterates that approximate the
+solution to some problem (see `::Problem`).
+
+The hierarchy of abstract types `<:Algorithm` is based on the algorithm
+interface (e.g. accepted parameters) and the specific numerical scheme.
+
+A concrete type `<:Algorithm` represents an implementation of its supertype
+algorithm. An instance is a complete specification of the algorithm with all
+its inputs and should guarantee to produce a deterministic sequence of iterates
+(randomized algorithms should use a seeded pseudorandom number generator).
+
+An algorithm needs to implement
+
+  - `(x0, state) = Base.iterate(::Algorithm)` allocates and initializes the
+    algorithm state and returns the initial iterate `x0`.
+  - `(x, state) = Base.iterate(::Algorithm, state)` performs one iteration of the
+    algorithm by updating `state` and returning the updated iterate `x`. This
+    method should not allocate dynamic memory.
+"""
+abstract type Algorithm{P<:Problem} end
+
+Base.IteratorSize(::Type{<:Algorithm}) = Base.SizeUnknown()
+
+function execute(alg::Algorithm; maxiters=nothing)
+    k = 0
+    y = iterate(alg)
+    @assert !isnothing(y)
+    (x, state) = y
+    while true
+        k += 1
+        !isnothing(maxiters) && k > maxiters && break
+        y = iterate(alg, state)
+        isnothing(y) && break
+        (x, state) = y
+    end
+    return x
+end
 
-abstract type Model end
-abstract type Algorithm end
-abstract type Context end
+abstract type State end
 
 "helper struct to prevent allocations"
@@ -44,7 +98,7 @@ end
     i0 = ntuple(_->0, N)
     i1(k) = ntuple(i->Int(k==i), N)
 
-    wi = (:((isnothing(w[$(i1(k)...)]) ? zero(T) : w[$(i0...)][$k]) -
+    wi = (:((isnothing(w[$(i0...)]) ? zero(T) : w[$(i0...)][$k]) -
             (isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k]))
           for k in 1:N)
     return quote
         Base.@_inline_meta
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index 7077bf1..456ebae 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -1,100 +1,99 @@
-struct DualTVDDAlgorithm{d} <: Algorithm
+struct DualTVDDAlgorithm{P,d} <: Algorithm{P}
+    problem::P
+
     "number of subdomains in each dimension"
     M::NTuple{d,Int}
     "overlap in pixels per dimension"
     overlap::NTuple{d,Int}
     "inertia parameter"
     σ::Float64
-    function DualTVDDAlgorithm(; M, overlap, σ)
-        return new{length(M)}(M, overlap, σ)
+    function DualTVDDAlgorithm(problem; M, overlap, σ)
+        return new{typeof(problem), length(M)}(problem, M, overlap, σ)
     end
 end
 
-struct DualTVDDContext{M,A,G,d,U,V,Vtmp,VV,SAx,SC}
-    model::M
+struct DualTVDDState{A,d,V,SV,SAx,SC}
     algorithm::A
-    "precomputed A'f"
-    g::G
-    "global dual optimization variable"
+
+    "global variable"
     p::V
-    "global dual temporary variable"
-    ptmp::Vtmp
-    "precomputed (A'A + βI)^(-1)"
-    B::VV
+    "local buffer" # TODO: get rid of this
+    q::Array{SV,d}
     "subdomain axes wrt global indices"
     subax::SAx
-    "subproblem data, subg[i] == subctx[i].model.g"
-    subg::Array{U,d}
     "context for subproblems"
     subctx::Array{SC,d}
 end
 
-function init(md::DualTVL1ROFOpModel, alg::DualTVDDAlgorithm)
-    d = ndims(md.f)
-    ax = axes(md.f)
+function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
+    g = alg.problem.g
+    d = ndims(g)
+    ax = axes(g)
 
     # subdomain axes
-    subax = subaxes(size(md.f), alg.M, alg.overlap)
+    subax = subaxes(size(g), alg.M, alg.overlap)
     # preallocated data for subproblems
-    subg = [Array{Float64, d}(undef, length.(ax)) for i in CartesianIndices(subax)]
+    subg = [Array{Float64, d}(undef, length.(x)) for x in subax]
     # locally dependent tv parameter
-    subλ = [md.λ .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(ax))
+    subλ = [alg.problem.λ[subax[i]...] .* theta.(Ref(ax), Ref(subax[i]), Ref(alg.overlap), CartesianIndices(subax[i]))
             for i in CartesianIndices(subax)]
 
-    # this is the global g, the local gs are getting initialized in step!()
-    g = reshape(md.A' * vec(md.f), size(md.f))
-
-    # global dual variables
-    p = zeros(SVector{d,Float64}, size(md.f))
-    ptmp = extend(zeros(SVector{d,Float64}, size(md.f)), StaticKernels.ExtensionNothing())
-
-    # precomputed global B
-    B = inv(md.A' * md.A + md.β * I)
-    #B = diagm(ones(length(md.f))) + md.β * I
+    # global dual variable
+    p = zeros(SVector{d,Float64}, size(g))
+    # local buffer variables
+    q = [extend(zeros(SVector{d,Float64}, length.(x)), StaticKernels.ExtensionNothing()) for x in subax]
 
     # create subproblem contexts
-    li = LinearIndices(size(md.f))
-    submds = [DualTVL1ROFOpModel(subg[i], B, subλ[i])
+    li = LinearIndices(ax)
+    subprobs = [DualTVL1ROFOpProblem(subg[i], op_restrict(alg.problem.B, ax, subax[i]), subλ[i])
         for i in CartesianIndices(subax)]
-    subalg = ProjGradAlgorithm(λ=1/sqrt(8)/norm(B))
-    #subalg = ChambolleAlgorithm()
-    subctx = [init(submds[i], subalg) for i in CartesianIndices(subax)]
+    # TODO: make inner algorithm a parameter
+    #subalg = [ProjGradAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
+    subalg = [ChambolleAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
 
-    # subcontext B is identity
-    #B = inv(md.A' * md.A + md.β * I)
+    subctx = [iterate(x)[2] for x in subalg]
 
-    return DualTVDDContext(md, alg, g, p, ptmp, B, subax, subg, subctx)
+    return p, DualTVDDState(alg, p, q, subax, subctx)
 end
 
-function step!(ctx::DualTVDDContext)
+function Base.iterate(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
     σ = ctx.algorithm.σ
     d = ndims(ctx.p)
     ax = axes(ctx.p)
     overlap = ctx.algorithm.overlap
 
-    @inline kfΛ(w) = @inbounds divergence_global(w)
-    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
     # call run! on each cell (this can be threaded)
     for (i, sax) in pairs(ctx.subax)
-        sg = ctx.subg[i] # julia-bug workaround
+        sg = ctx.subctx[i].algorithm.problem.g # julia-bug workaround
+        sq = ctx.q[i] # julia-bug workaround
+
+        #println("# subax $sax")
+
+        sg .= alg.problem.g[sax...]
+        sq .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* ctx.p[sax...]
 
-        # TODO: make p computation local!
-        ctx.ptmp .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p))) .* ctx.p
-        #map!(kΛ, sg, ctx.ptmp)
-        #sg .+= ctx.g
+        @inline kfΛ(pw) = @inbounds sg[pw.position] + divergence(pw)
+        kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
 
-        for j in 1:1000
-            step!(ctx.subctx[i])
+        map!(kΛ, sg, sq)
+
+        for j in 1:10
+            (_, ctx.subctx[i]) = iterate(ctx.subctx[i].algorithm, ctx.subctx[i])
         end
+        #display(ctx.subctx[i].algorithm.problem.g .- alg.problem.g[sax...])
+        #display(ctx.subctx[i].p)
     end
 
+    #display(ctx.p)
     # aggregate (not thread-safe!)
     ctx.p .*= 1 - σ
     for (i, sax) in pairs(ctx.subax)
-        ctx.p .+= σ .* ctx.subctx[i].p
+        ctx.p[sax...] .+= σ .* ctx.subctx[i].p
     end
+    #display(ctx.p)
+
+    return ctx.p, ctx
 end
 
 @generated function divergence_global(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
@@ -109,37 +108,26 @@ end
     end
 end
 
-
-    #FD.GridFunction(grid, (A'*A + β*I) \ (FD.divergence_z(p).data[:] .+ A'*f.data[:]))
-
-function recover_u!(ctx::DualTVDDContext)
-    d = ndims(ctx.g)
-    u = similar(ctx.g)
-    v = similar(ctx.g)
-
-    @inline kfΛ(w) = @inbounds divergence(w)
-    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
-
-    # v = div(p) + A'*f
-    map!(kΛ, v, extend(ctx.p, StaticKernels.ExtensionNothing()))
-    v .+= ctx.g
-    # u = B * v
-    mul!(vec(u), ctx.B, vec(v))
-    return u
+op_restrict(B::UniformScaling, ax, sax) = B
+function op_restrict(B, ax, sax)
+    li = LinearIndices(ax)
+    si = vec(li[sax...])
+    return B[si, si]
 end
+
 """
     theta(ax, sax, overlap, i)
 
 Return value of the partition function at index `i` given global axes `ax`,
 subdomain axes `sax` and overlap count `overlap`.
-This assumes that subdomains have size at least 2 .* overlap.
+This assumes that interior subdomains have size at least `2 .* overlap`.
 """
 theta(ax, sax, overlap, i::CartesianIndex) = prod(theta.(ax, sax, overlap, Tuple(i)))
 theta(ax, sax, overlap::Int, i::Int) =
     max(0., min(1.,
-        first(ax) == first(sax) && i < first(ax) + overlap ? 1. : (i - first(sax) + 1) / (overlap + 1),
-        last(ax) == last(sax) && i > last(ax) - overlap ? 1. : (last(sax) - i + 1) / (overlap + 1)))
+        first(ax) == first(sax) && i < first(ax) + overlap ? 1. : (i - first(sax)) / (overlap - 1),
+        last(ax) == last(sax) && i > last(ax) - overlap ? 1. : (last(sax) - i) / (overlap - 1)))
 
 
 """
diff --git a/src/models.jl b/src/models.jl
deleted file mode 100644
index c13bbc0..0000000
--- a/src/models.jl
+++ /dev/null
@@ -1,31 +0,0 @@
-#"min_u 1/2 * |Au-f|_2^2 + λ*TV(u) + β/2 |u|_2 + γ |u|_1"
-"min_p 1/2 * |p₂ - div(p₂) - g|_B^2 + χ_{|p₁|≤λ} + χ_{|p₂|≤γ}"
-struct DualTVL1ROFOpModel{Tg,TB,Tλ,Tγ} <: Model
-    "dual data"
-    g::Tg
-    "B operator"
-    B::TB
-    "TV regularization parameter"
-    λ::Tλ
-    "L1 regularization parameter"
-    γ::Tγ
-end
-
-DualTVL1ROFOpModel(g, B, λ) = DualTVL1ROFOpModel(g, B, λ, nothing)
-DualTVL1ROFOpModel(g, B, λ::Real) = DualTVL1ROFOpModel(g, B, fill!(similar(g), λ))
-
-function recover_u(p, md::DualTVL1ROFOpModel)
-    d = ndims(md.g)
-    u = similar(md.g)
-    v = similar(md.g)
-
-    @inline kfΛ(w) = @inbounds divergence(w)
-    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
-
-    # v = div(p) + A'*f
-    map!(kΛ, v, p) # extension: nothing
-    v .+= md.g
-    # u = B * v
-    mul!(vec(u), md.B, vec(v))
-    return u
-end
diff --git a/src/problems.jl b/src/problems.jl
new file mode 100644
index 0000000..7e7fc7d
--- /dev/null
+++ b/src/problems.jl
@@ -0,0 +1,36 @@
+using LinearAlgebra: UniformScaling
+
+#"min_u 1/2 * |Au-f|_2^2 + λ*TV(u) + β/2 |u|_2 + γ |u|_1"
+"min_p 1/2 * |p₂ - div(p₁) - g|_B^2 + χ_{|p₁|≤λ} + χ_{|p₂|≤γ}"
+struct DualTVL1ROFOpProblem{Tg,TB,Tλ,Tγ} <: Problem
+    "dual data"
+    g::Tg
+    "B operator"
+    B::TB
+    "TV regularization parameter"
+    λ::Tλ
+    "L1 regularization parameter"
+    γ::Tγ
+end
+
+DualTVL1ROFOpProblem(g, B, λ) = DualTVL1ROFOpProblem(g, B, λ, nothing)
+DualTVL1ROFOpProblem(g, B, λ::Real) = DualTVL1ROFOpProblem(g, B, fill!(similar(g), λ))
+
+normB(p::DualTVL1ROFOpProblem{<:Any,<:UniformScaling}) = p.B.λ * sqrt(length(p.g))
+normB(p::DualTVL1ROFOpProblem) = norm(p.B)
+
+function recover_u(p, md::DualTVL1ROFOpProblem)
+    d = ndims(md.g)
+    u = similar(md.g)
+    v = similar(md.g)
+
+    @inline kfΛ(w) = @inbounds divergence(w)
+    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+    # v = div(p) + A'*f
+    map!(kΛ, v, extend(p, StaticKernels.ExtensionNothing())) # extension: nothing
+    v .+= md.g
+    # u = B * v
+    mul!(vec(u), md.B, vec(v))
+    return u
+end
diff --git a/src/projgrad.jl b/src/projgrad.jl
index 6071fbf..1de6503 100644
--- a/src/projgrad.jl
+++ b/src/projgrad.jl
@@ -1,13 +1,17 @@
-struct ProjGradAlgorithm <: Algorithm
+struct ProjGradAlgorithm{P} <: Algorithm{P}
+    problem::P
+
     "gradient step size"
     τ::Float64
-    function ProjGradAlgorithm(; τ)
-        return new(τ)
+    function ProjGradAlgorithm(problem; τ)
+        return new{typeof(problem)}(problem, τ)
     end
 end
 
-struct ProjGradContext{M,A,Tp,Wv,R,S,K1,K2}
-    model::M
+ProjGradAlgorithm(problem::DualTVL1ROFOpProblem) =
+    ProjGradAlgorithm(problem, τ=inv(8 * normB(problem)))
+
+struct ProjGradState{A,Tp,Wv,R,S,K1,K2}
     algorithm::A
     "dual optimization variable"
     p::Tp
@@ -26,40 +30,43 @@ struct ProjGradContext{M,A,Tp,Wv,R,S,K1,K2}
     k2::K2
 end
 
-function init(md::DualTVL1ROFOpModel, alg::ProjGradAlgorithm)
-    d = ndims(md.g)
-    ax = axes(md.g)
+function Base.iterate(alg::ProjGradAlgorithm{<:DualTVL1ROFOpProblem})
+    g = alg.problem.g
+    d = ndims(g)
+    ax = axes(g)
 
     p = extend(zeros(SVector{d,Float64}, ax), StaticKernels.ExtensionNothing())
-    g = extend(md.g, StaticKernels.ExtensionNothing())
+    g = extend(g, StaticKernels.ExtensionNothing())
 
-    rv = zeros(length(md.g))
-    sv = zeros(length(md.g))
+    rv = zeros(length(g))
+    sv = zeros(length(g))
 
     r = reshape(rv, ax)
     s = extend(reshape(sv, ax), StaticKernels.ExtensionReplicate())
 
     z = zero(CartesianIndex{d})
 
-    @inline kf1(pw) = @inbounds -divergence(pw) - md.g[pw.position]
+    @inline kf1(pw) = @inbounds -divergence(pw) - g[pw.position]
     k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
 
     @inline function kf2(pw, sw)
         @inbounds q = pw[z] - alg.τ * gradient(sw)
-        return @inbounds q / max(norm(q) / md.λ[sw.position], 1)
+        return @inbounds q / max(norm(q) / alg.problem.λ[sw.position], 1)
     end
     k2 = Kernel{ntuple(_->0:1, d)}(kf2)
 
-    return ProjGradContext(md, alg, p, rv, sv, r, s, k1, k2)
+    return (p, ProjGradState(alg, p, rv, sv, r, s, k1, k2))
 end
 
-function step!(ctx::ProjGradContext)
+function Base.iterate(alg::ProjGradAlgorithm{<:DualTVL1ROFOpProblem}, ctx)
     # r = Λ*p - g
     map!(ctx.k1, ctx.r, ctx.p)
     # s = B * r
-    mul!(ctx.sv, ctx.model.B, ctx.rv)
+    mul!(ctx.sv, alg.problem.B, ctx.rv)
     # p = proj(p - λΛ's)
     map!(ctx.k2, ctx.p, ctx.p, ctx.s)
+
+    return (ctx.p, ctx)
 end
 
-recover_u(ctx::ProjGradContext) = recover_u(ctx.p, ctx.model)
+recover_u(ctx::ProjGradState) = recover_u(ctx.p, ctx.problem)
diff --git a/test/runtests.jl b/test/runtests.jl
index a540d6c..9e8d0d0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,16 +1,13 @@
 using Test
 using LinearAlgebra
 using DualTVDD:
-    DualTVL1ROFOpModel, ProjGradAlgorithm, ChambolleAlgorithm,
-    init, step!, recover_u
+    DualTVL1ROFOpProblem, ProjGradAlgorithm, ChambolleAlgorithm,
+    recover_u
 
 g = Float64[0 2; 1 0]
-md = DualTVL1ROFOpModel(g, I, 1e-10)
+prob = DualTVL1ROFOpProblem(g, I, 1e-10)
 
-for alg in (ProjGradAlgorithm(τ=1/8), ChambolleAlgorithm())
-    ctx = init(md, alg)
-    for i in 1:100
-        step!(ctx)
-    end
-    @test recover_u(ctx) ≈ g
+for alg in (ProjGradAlgorithm(prob, τ=1/8), ChambolleAlgorithm(prob))
+    its = collect(Iterators.take(alg, 100))
+    @test recover_u(last(its), prob) ≈ g
 end
-- 
GitLab
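
Usage sketch (not part of the patch above): the reworked interface drives an algorithm through `Base.iterate` instead of `init`/`step!`. The snippet mirrors the updated `test/runtests.jl` and `rundd()`; the `drive` helper and the iteration count of 100 are illustrative only.

    using LinearAlgebra
    using DualTVDD: DualTVL1ROFOpProblem, ChambolleAlgorithm, recover_u

    # A fully specified dual problem: dual data g, operator B = I, TV weight λ.
    g = Float64[0 2; 1 0]
    prob = DualTVL1ROFOpProblem(g, I, 1e-10)

    # Algorithms own their problem and act as stateful iterators over dual iterates p.
    alg = ChambolleAlgorithm(prob)

    # Drive the iteration by hand: the first iterate call allocates the state,
    # subsequent calls update it in place.
    function drive(alg, n)
        (p, state) = iterate(alg)
        for k in 1:n
            (p, state) = iterate(alg, state)
        end
        return p
    end

    p = drive(alg, 100)
    u = recover_u(p, prob)    # u = B * (div(p) + g)

    # Equivalent, using the iterator protocol directly:
    # p = last(collect(Iterators.take(alg, 100)))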