From b618597f51a1ece4f8e9eebb825611ad69595659 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Mon, 11 May 2020 16:15:02 +0200
Subject: [PATCH] implement specialized chambolle

---
 Manifest.toml    |  34 ++++++++++++
 Project.toml     |   5 ++
 src/DualTVDD.jl  |   4 +-
 src/chambolle.jl | 137 +++++++++++++++++++++++++++++++++++++++++++++++
 src/dualtvdd.jl  |  24 +++++++++
 src/types.jl     |  40 ++++++++++++++
 test/repl.jl     |   5 +-
 7 files changed, 247 insertions(+), 2 deletions(-)
 create mode 100644 Manifest.toml
 create mode 100644 src/chambolle.jl
 create mode 100644 src/dualtvdd.jl
 create mode 100644 src/types.jl

diff --git a/Manifest.toml b/Manifest.toml
new file mode 100644
index 0000000..9a8fa70
--- /dev/null
+++ b/Manifest.toml
@@ -0,0 +1,34 @@
+# This file is machine-generated - editing it directly is not advised
+
+[[Libdl]]
+uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+
+[[LinearAlgebra]]
+deps = ["Libdl"]
+uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+
+[[Random]]
+deps = ["Serialization"]
+uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+
+[[Serialization]]
+uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+
+[[SparseArrays]]
+deps = ["LinearAlgebra", "Random"]
+uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+
+[[StaticArrays]]
+deps = ["LinearAlgebra", "Random", "Statistics"]
+git-tree-sha1 = "5c06c0aeb81bef54aed4b3f446847905eb6cbda0"
+uuid = "90137ffa-7385-5640-81b9-e52037218182"
+version = "0.12.3"
+
+[[StaticKernels]]
+path = "/home/stev47/.julia/dev/StaticKernels"
+uuid = "4c63dfa8-a427-4548-bd2f-4c19e87a7dc7"
+version = "0.4.0"
+
+[[Statistics]]
+deps = ["LinearAlgebra", "SparseArrays"]
+uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
diff --git a/Project.toml b/Project.toml
index b60374f..4ea8ab6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,6 +3,11 @@ uuid = "93adc0ee-851f-4b8b-8bf8-c8a87ded093b"
 authors = ["Stephan Hilb <stephan@ecshi.net>"]
 version = "0.1.0"
 
+[deps]
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
+StaticKernels = "4c63dfa8-a427-4548-bd2f-4c19e87a7dc7"
+
 [compat]
 julia = "1.4.0"
 
diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index d1742f1..78038b7 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -1,5 +1,7 @@
 module DualTVDD
 
-greet() = print("Hello World!")
+include("types.jl")
+include("chambolle.jl")
+
 
 end # module
diff --git a/src/chambolle.jl b/src/chambolle.jl
new file mode 100644
index 0000000..4f7fa7b
--- /dev/null
+++ b/src/chambolle.jl
@@ -0,0 +1,137 @@
+# Chambolle, A.
+# An Algorithm for Total Variation Minimization and Applications.
+# Journal of Mathematical Imaging and Vision 20, 89–97 (2004).
+# https://doi.org/10.1023/B:JMIV.0000011325.36760.1e
+#
+# Implementation Notes:
+# - TV-parameter α instead of λ
+# - feasibility constraint |p|<=α instead of |p|<=1
+# - B is introduced, the original Chambolle algorithm has B=Id
+
+using Base: @_inline_meta
+using StaticKernels
+using StaticArrays: SVector
+using LinearAlgebra
+
+struct ChambolleAlgorithm <: Algorithm
+    "fixed point inertia parameter"
+    τ::Float64
+    function ChambolleAlgorithm(; τ=1/4)
+        return new(τ)
+    end
+end
+
+struct ChambolleContext{M,A,G,T,R,S,Sv,K1,K2} <: Context
+    "model data"
+    model::M
+    "algorithm data"
+    algorithm::A
+    "matrix view on model.f"
+    g::G
+    "matrix view on pv"
+    p::T
+    "matrix view on rv"
+    r::R
+    "matrix view on sv"
+    s::S
+    "dual variable as vector"
+    rv::Sv
+    "scalar temporary"
+    sv::Sv
+    "div(p) + g kernel"
+    k1::K1
+    "(p + τ*grad(q))/(1 + τ/α|grad(q)|) kernel"
+    k2::K2
+end
+
+function init(md::ChambolleModel, alg::ChambolleAlgorithm)
+    d = ndims(md.g)
+    g = extend(md.g, StaticKernels.ExtensionNothing())
+    pv = zeros(d * length(md.g))
+    rv = zeros(length(md.g))
+    sv = zero(rv)
+    p = extend(reshape(reinterpret(SVector{d,Float64}, pv), size(md.g)), StaticKernels.ExtensionNothing())
+    r = reshape(reinterpret(Float64, rv), size(md.g))
+    s = extend(reshape(reinterpret(Float64, sv), size(md.g)), StaticKernels.ExtensionConstant(0.))
+
+    z = zero(CartesianIndex{d})
+    @inline kf1(pw, gw) = @inbounds divergence(pw) + gw[z]
+    k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
+
+    @inline function kf2(pw, sw)
+        sgrad = alg.τ * gradient(sw)
+        return @inbounds (pw[z] + sgrad) / (1 + norm(sgrad) / md.λ)
+    end
+    k2 = Kernel{ntuple(_->0:1, d)}(kf2)
+
+    return ChambolleContext(md, alg, g, p, r, s, rv, sv, k1, k2)
+end
+
+@generated function gradient(w::StaticKernels.Window{<:Any,N}) where N
+    i0 = ntuple(_->0, N)
+    i1(k) = ntuple(i->Int(k==i), N)
+
+    wi = (:(w[$(i1(k)...)] - w[$(i0...)]) for k in 1:N)
+    return quote
+        Base.@_inline_meta
+        return @inbounds SVector($(wi...))
+    end
+end
+
+@generated function divergence(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
+    i0 = ntuple(_->0, N)
+    i1(k) = ntuple(i->Int(k==i), N)
+
+    wi = (:((isnothing(w[$(i1(k)...)]) ? zero(T) : w[$(i0...)][$k]) -
+            (isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k])) for k in 1:N)
+    return quote
+        Base.@_inline_meta
+        return @inbounds +($(wi...))
+    end
+end
+
+#  x = (A'*A + β*I) \ (FD.divergence(grid, p) .+ A'*f)
+#  q = FD.gradient(grid, x)
+
+function step!(ctx::ChambolleContext)
+    # r = div(p) + g
+    map!(ctx.k1, ctx.r, ctx.p, ctx.g)
+    # s = B * r
+    mul!(ctx.sv, ctx.model.B, ctx.rv)
+    # p = (p + τ*grad(s)) / (1 + τ/α|grad(s)|)
+    map!(ctx.k2, ctx.p, ctx.p, ctx.s)
+end
+
+function recover_u!(ctx)
+    # r = div(p) + g
+    map!(ctx.k1, ctx.r, ctx.p, ctx.g)
+    # s = B * r
+    mul!(ctx.sv, ctx.model.B, ctx.rv)
+end
+
+#
+#function solve(md::ROFModel, alg::Chambolle;
+#               maxiters = 1000,
+#               save_log = false)
+#
+#    p = zeros(ndims(md.g), size(md.g)...)
+#    q = zero(p)
+#    s = zero(md.g)
+#    div = div_op(p)
+#    grad = grad_op(md.g)
+#    ctx = ChambolleContext(p, q, s, div, grad)
+#
+#    log = Log()
+#
+#    for i in 1:maxiters
+#        step!(ctx, md, alg)
+#
+#        if save_log
+#            recover_u!(ctx, md, alg)
+#            save_log && push!(log.energy, energy2(md, ctx.s))
+#        end
+#    end
+#
+#    recover_u!(ctx, md, alg)
+#    return Solution(ctx.s, md, alg, ctx, log)
+#end
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
new file mode 100644
index 0000000..0f5a781
--- /dev/null
+++ b/src/dualtvdd.jl
@@ -0,0 +1,24 @@
+struct DualTVDDAlgorithm{d} <: Algorithm
+    "number of subdomains in each dimension"
+    M::NTuple{d,Int}
+    "inertia parameter"
+    σ::Float64
+    function DualTVDDAlgorithm(; M, σ)
+        return new{length(M)}(M, σ)
+    end
+end
+
+struct DualTVDDContext{d,U,V,Vview,SC}
+    "global dual optimization variable"
+    p::V
+    "local views on p per subdomain"
+    pviews::Array{Vview,d}
+    "data for subproblems"
+    g::Array{U,d}
+    "context for subproblems"
+    subctx::Array{SC,d}
+end
+
+function solve(model::DualTVDDModel, algorithm::DualTVDDAlgorithm)
+
+end
diff --git a/src/types.jl b/src/types.jl
new file mode 100644
index 0000000..2db8925
--- /dev/null
+++ b/src/types.jl
@@ -0,0 +1,40 @@
+# Solver Interface Notes:
+#
+# - a concrete subtype `<:Model` describes one model type and each instance
+#   represents a partical solvable problem, including model parameters and data
+# - a concrete subtype `<:Algorithm` describes an algorithm type, maybe
+#   applicable to different models with special implementations
+# - abstract subtypes `<:Model`, `<:Algorithm` may exist
+# - `solve(::Model, ::Algorithm)` must lead to the same deterministic
+#   outcome depending only on the arguments given.
+# - `init(::Model, ::Algorithm)::Context` initializes a runnable algorithm context
+# - `run(::Context)` must continue on from the given context
+# - `(<:Algorithm)(...)` constructors must accept keyword arguments only
+
+abstract type Model end
+abstract type Algorithm end
+abstract type Context end
+
+"min_u  1/2 * |Au-f|_2^2 + α*TV(u) + β/2 |u|_2 + γ |u|_1"
+struct DualTVDDModel{U,VV} <: Model
+    "given data"
+    f::U
+    "forward operator"
+    A::VV
+    "total variation parameter"
+    α::Float64
+    "L2 regularization parameter"
+    β::Float64
+    "L1 regularization parameter"
+    γ::Float64
+end
+
+"min_p  1/2 * |div(p) - g|_B^2 + χ_{|p|<=λ}"
+struct ChambolleModel{U,VV} <: Model
+    "given data"
+    g::U
+    "B norm operator"
+    B::VV
+    "total variation parameter"
+    λ::Float64
+end
diff --git a/test/repl.jl b/test/repl.jl
index bbeea48..e34dbc4 100644
--- a/test/repl.jl
+++ b/test/repl.jl
@@ -1,6 +1,9 @@
+using Test
 using BenchmarkTools
-using Revise
 using Pkg
 Pkg.activate(".")
 
+using LinearAlgebra
+
+using Revise
 using DualTVDD
-- 
GitLab