Commit b618597f authored by Stephan Hilb

implement specialized chambolle

parent 3ea29b47
# This file is machine-generated - editing it directly is not advised
[[Libdl]]
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[[LinearAlgebra]]
deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "5c06c0aeb81bef54aed4b3f446847905eb6cbda0"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.12.3"
[[StaticKernels]]
path = "/home/stev47/.julia/dev/StaticKernels"
uuid = "4c63dfa8-a427-4548-bd2f-4c19e87a7dc7"
version = "0.4.0"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
uuid = "93adc0ee-851f-4b8b-8bf8-c8a87ded093b"
authors = ["Stephan Hilb <stephan@ecshi.net>"]
version = "0.1.0"
[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticKernels = "4c63dfa8-a427-4548-bd2f-4c19e87a7dc7"
[compat]
julia = "1.4.0"
module DualTVDD
greet() = print("Hello World!")
include("types.jl")
include("chambolle.jl")
end # module
# Chambolle, A.
# An Algorithm for Total Variation Minimization and Applications.
# Journal of Mathematical Imaging and Vision 20, 89–97 (2004).
# https://doi.org/10.1023/B:JMIV.0000011325.36760.1e
#
# Implementation Notes:
# - TV-parameter α (stored as the field `λ` of `ChambolleModel`) instead of the paper's λ
# - feasibility constraint |p|<=α instead of |p|<=1
# - an operator B is introduced; the original Chambolle algorithm has B = Id
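#
# One iteration of the resulting fixed-point scheme (as implemented in step! below):
#   r = div(p) + g
#   s = B * r
#   p = (p + τ*grad(s)) / (1 + τ/λ * |grad(s)|)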
using Base: @_inline_meta
using StaticKernels
using StaticArrays: SVector
using LinearAlgebra
struct ChambolleAlgorithm <: Algorithm
"fixed point inertia parameter"
τ::Float64
function ChambolleAlgorithm(; τ=1/4)
return new(τ)
end
end
struct ChambolleContext{M,A,G,T,R,S,Sv,K1,K2} <: Context
"model data"
model::M
"algorithm data"
algorithm::A
"matrix view on model.f"
g::G
"matrix view on pv"
p::T
"matrix view on rv"
r::R
"matrix view on sv"
s::S
"dual variable as vector"
rv::Sv
"scalar temporary"
sv::Sv
"div(p) + g kernel"
k1::K1
"(p + τ*grad(q))/(1 + τ/α|grad(q)|) kernel"
k2::K2
end
function init(md::ChambolleModel, alg::ChambolleAlgorithm)
d = ndims(md.g)
g = extend(md.g, StaticKernels.ExtensionNothing())
pv = zeros(d * length(md.g))
rv = zeros(length(md.g))
sv = zero(rv)
p = extend(reshape(reinterpret(SVector{d,Float64}, pv), size(md.g)), StaticKernels.ExtensionNothing())
r = reshape(reinterpret(Float64, rv), size(md.g))
s = extend(reshape(reinterpret(Float64, sv), size(md.g)), StaticKernels.ExtensionConstant(0.))
z = zero(CartesianIndex{d})
@inline kf1(pw, gw) = @inbounds divergence(pw) + gw[z]
k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
@inline function kf2(pw, sw)
sgrad = alg.τ * gradient(sw)
return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / md.λ)
end
k2 = Kernel{ntuple(_->0:1, d)}(kf2)
return ChambolleContext(md, alg, g, p, r, s, rv, sv, k1, k2)
end
@generated function gradient(w::StaticKernels.Window{<:Any,N}) where N
i0 = ntuple(_->0, N)
i1(k) = ntuple(i->Int(k==i), N)
wi = (:(w[$(i1(k)...)] - w[$(i0...)]) for k in 1:N)
return quote
Base.@_inline_meta
return @inbounds SVector($(wi...))
end
end
@generated function divergence(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
i0 = ntuple(_->0, N)
i1(k) = ntuple(i->Int(k==i), N)
wi = (:((isnothing(w[$(i1(k)...)]) ? zero(T) : w[$(i0...)][$k]) -
(isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k])) for k in 1:N)
return quote
Base.@_inline_meta
return @inbounds +($(wi...))
end
end
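# For d = 2 the generated kernels expand (schematically) to
#   gradient(w)   -> SVector(w[1,0] - w[0,0], w[0,1] - w[0,0])
#   divergence(w) -> (isnothing(w[1,0]) ? 0 : w[0,0][1]) - (isnothing(w[-1,0]) ? 0 : w[-1,0][1]) +
#                    (isnothing(w[0,1]) ? 0 : w[0,0][2]) - (isnothing(w[0,-1]) ? 0 : w[0,-1][2])
# i.e. forward differences for the gradient, and a divergence that drops any
# term whose neighbour lies outside the array (extension value `nothing`).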
# x = (A'*A + β*I) \ (FD.divergence(grid, p) .+ A'*f)
# q = FD.gradient(grid, x)
function step!(ctx::ChambolleContext)
# r = div(p) + g
map!(ctx.k1, ctx.r, ctx.p, ctx.g)
# s = B * r
mul!(ctx.sv, ctx.model.B, ctx.rv)
# p = (p + τ*grad(s)) / (1 + τ/λ*|grad(s)|)
map!(ctx.k2, ctx.p, ctx.p, ctx.s)
end
function recover_u!(ctx)
# r = div(p) + g
map!(ctx.k1, ctx.r, ctx.p, ctx.g)
# s = B * r
mul!(ctx.sv, ctx.model.B, ctx.rv)
end
#
#function solve(md::ROFModel, alg::Chambolle;
# maxiters = 1000,
# save_log = false)
#
# p = zeros(ndims(md.g), size(md.g)...)
# q = zero(p)
# s = zero(md.g)
# div = div_op(p)
# grad = grad_op(md.g)
# ctx = ChambolleContext(p, q, s, div, grad)
#
# log = Log()
#
# for i in 1:maxiters
# step!(ctx, md, alg)
#
# if save_log
# recover_u!(ctx, md, alg)
# save_log && push!(log.energy, energy2(md, ctx.s))
# end
# end
#
# recover_u!(ctx, md, alg)
# return Solution(ctx.s, md, alg, ctx, log)
#end
struct DualTVDDAlgorithm{d} <: Algorithm
"number of subdomains in each dimension"
M::NTuple{d,Int}
"inertia parameter"
σ::Float64
function DualTVDDAlgorithm(; M, σ)
return new{length(M)}(M, σ)
end
end
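# e.g. DualTVDDAlgorithm(M = (2, 2), σ = 1/4) requests a 2×2 decomposition of a
# 2-dimensional domain (the σ value here is an arbitrary placeholder; the solve
# below is still a stub).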
struct DualTVDDContext{d,U,V,Vview,SC}
"global dual optimization variable"
p::V
"local views on p per subdomain"
pviews::Array{Vview,d}
"data for subproblems"
g::Array{U,d}
"context for subproblems"
subctx::Array{SC,d}
end
function solve(model::DualTVDDModel, algorithm::DualTVDDAlgorithm)
end
# Solver Interface Notes:
#
# - a concrete subtype `<:Model` describes one model type and each instance
# represents a particular solvable problem, including model parameters and data
# - a concrete subtype `<:Algorithm` describes an algorithm type, maybe
# applicable to different models with special implementations
# - abstract subtypes `<:Model`, `<:Algorithm` may exist
# - `solve(::Model, ::Algorithm)` must be deterministic, its outcome depending
# only on the arguments given.
# - `init(::Model, ::Algorithm)::Context` initializes a runnable algorithm context
# - `run(::Context)` must continue on from the given context
# - `(<:Algorithm)(...)` constructors must accept keyword arguments only
abstract type Model end
abstract type Algorithm end
abstract type Context end
"min_u 1/2 * |Au-f|_2^2 + α*TV(u) + β/2 |u|_2 + γ |u|_1"
struct DualTVDDModel{U,VV} <: Model
"given data"
f::U
"forward operator"
A::VV
"total variation parameter"
α::Float64
"L2 regularization parameter"
β::Float64
"L1 regularization parameter"
γ::Float64
end
"min_p 1/2 * |div(p) - g|_B^2 + χ_{|p|<=λ}"
struct ChambolleModel{U,VV} <: Model
"given data"
g::U
"B norm operator"
B::VV
"total variation parameter"
λ::Float64
end
using Test
using BenchmarkTools
using Revise
using Pkg
Pkg.activate(".")
using LinearAlgebra
using Revise
using DualTVDD
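# Hypothetical driving example (not part of this commit): the data, operator B
# and parameter values are made up; with B = Id this reduces to the original
# Chambolle iteration.
g = rand(16, 16)                        # synthetic input image
B = Matrix(1.0I, length(g), length(g))  # B = Id as a dense matrix
md = DualTVDD.ChambolleModel(g, B, 0.1)
alg = DualTVDD.ChambolleAlgorithm(τ = 1/4)
ctx = DualTVDD.init(md, alg)
for _ in 1:100
    DualTVDD.step!(ctx)                 # fixed-point update of the dual variable p
end
DualTVDD.recover_u!(ctx)                # recovered primal variable is left in ctx.s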