From 91eb8523b2ff13c034c5ed2cd0aad1dd975c1de7 Mon Sep 17 00:00:00 2001
From: Stephan Hilb <stephan@ecshi.net>
Date: Wed, 19 Aug 2020 10:02:34 +0200
Subject: [PATCH] reworked numerical examples

---
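Notes on the rework (summary of the changes below): the subdomain
solver of DualTVDDAlgorithm is now pluggable through a new `subalg`
keyword instead of a hard-coded ChambolleAlgorithm, and the per-cell
loop in step! is threaded and uses views instead of copies.  A small
NStepAlgorithm wrapper (src/nstep.jl) runs a fixed number of steps of
a wrapped algorithm per outer step.  normB now returns the operator
norm of B, and DualTVL1ROFOpProblem gains an energy() evaluation.  An
experimental Newton-type TV solver is added in src/tvnewton.jl but is
not yet included.  test/optflow.jl is switched to the domain
decomposition algorithm with fewer outer iterations.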
 src/DualTVDD.jl |   3 +
 src/common.jl   |   4 +-
 src/dualtvdd.jl |  23 ++++----
 src/nstep.jl    |  24 ++++++++
 src/problems.jl |  24 ++++++--
 src/tvnewton.jl | 154 ++++++++++++++++++++++++++++++++++++++++++++++++
 test/optflow.jl |   6 +-
 7 files changed, 218 insertions(+), 20 deletions(-)
 create mode 100644 src/nstep.jl
 create mode 100644 src/tvnewton.jl
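Usage sketch (illustration only, not part of the patch): assuming a
problem instance `prob` accepted by DualTVDDAlgorithm (e.g. a
DualTVL1ROFOpProblem) and the one-argument ChambolleAlgorithm
constructor used before this change, the new `subalg` hook and the
NStep wrapper are expected to compose roughly like this:

    alg = DualTVDDAlgorithm(prob; M=(2,2), overlap=(4,4),
                            subalg = p -> NStepAlgorithm(ChambolleAlgorithm(p), 10))
    st = init(alg)
    for i in 1:100
        step!(st)
    end
    sol = fetch(st)

M and overlap follow test/optflow.jl; the iteration count is arbitrary.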

diff --git a/src/DualTVDD.jl b/src/DualTVDD.jl
index 220d7c8..98a625f 100644
--- a/src/DualTVDD.jl
+++ b/src/DualTVDD.jl
@@ -4,9 +4,12 @@ export init, step!, fetch, run
 
 include("common.jl")
 include("problems.jl")
+
+include("nstep.jl")
 include("chambolle.jl")
 include("dualtvdd.jl")
 include("projgrad.jl")
+#include("tvnewton.jl")
 
 #using Plots: heatmap
 #
diff --git a/src/common.jl b/src/common.jl
index 431f0ec..60958d8 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -42,8 +42,8 @@ abstract type Algorithm{P<:Problem} end
 """
     <:State
 
-Algorithm state and storage containing sufficient information to act as a
-checkpoint for continuing the corresponding algorithm at that point.
+The algorithm workspace. Together with the algorithm, it contains sufficient
+information to act as a checkpoint from which the algorithm can be resumed.
 """
 abstract type State end
 
diff --git a/src/dualtvdd.jl b/src/dualtvdd.jl
index 77edede..9745045 100644
--- a/src/dualtvdd.jl
+++ b/src/dualtvdd.jl
@@ -11,8 +11,10 @@ struct DualTVDDAlgorithm{P,d} <: Algorithm{P}
     σ::Float64
     "number of inner iterations"
     ninner::Int
-    function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=1/4, ninner=10)
-        return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner)
+    "constructor for the subdomain algorithm: prob -> Algorithm(prob, ...)"
+    subalg::Function
+    function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=1/4, ninner=10, subalg=x->ProjGradAlgorithm(x))
+        return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner, subalg)
     end
 end
 
@@ -52,9 +54,7 @@ function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
     subprobs = [DualTVL1ROFOpProblem(subg[i], op_restrict(alg.problem.B, ax, subax[i]), subλ[i])
         for i in CartesianIndices(subax)]
 
-    # TODO: make inner algorithm a parameter
-    #subalg = [ProjGradAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
-    subalg = [ChambolleAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
+    subalg = [alg.subalg(subprobs[i]) for i in CartesianIndices(subax)]
 
     subctx = [init(x) for x in subalg]
 
@@ -77,23 +77,24 @@ function step!(ctx::DualTVDDState)
     overlap = ctx.algorithm.overlap
 
     # call run! on each cell (this can be threaded)
-    for (i, sax) in pairs(ctx.subax)
+    Threads.@threads for i in eachindex(ctx.subax)
+        sax = ctx.subax[i]
         li = LinearIndices(ctx.subax)[i]
         sg = ctx.subctx[i].algorithm.problem.g # julia-bug workaround
         sq = ctx.q[i] # julia-bug workaround
 
-        sg .= alg.problem.g[sax...]
+        sg .= view(alg.problem.g, sax...)
         if alg.parallel
-            sq .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* ctx.p[sax...]
+            sq .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* view(ctx.p, sax...)
         else
             sq .= Ref(zero(eltype(sq)))
             # contributions from previous domains
             for (lj, saxj) in enumerate(ctx.subax)
                 ids, idsi, idsj = intersectin(CartesianIndices(sax), CartesianIndices(saxj))
                 if lj < li
-                    sq[idsi] .+= ctx.subctx[lj].p[idsj]
+                    sq[idsi] .+= view(ctx.subctx[lj].p, idsj)
                 elseif lj > li
-                    sq[idsi] .+= theta.(Ref(ax), Ref(saxj), Ref(overlap), ids) .* ctx.p[ids]
+                    sq[idsi] .+= theta.(Ref(ax), Ref(saxj), Ref(overlap), ids) .* view(ctx.p, ids)
                 end
             end
         end
@@ -111,7 +112,7 @@ function step!(ctx::DualTVDDState)
     # aggregate (not thread-safe!)
     ctx.p .*= 1 - σ
     for (i, sax) in pairs(ctx.subax)
-        ctx.p[sax...] .+= σ .* ctx.subctx[i].p
+        view(ctx.p, sax...) .+= σ .* ctx.subctx[i].p
     end
 
     return ctx
diff --git a/src/nstep.jl b/src/nstep.jl
new file mode 100644
index 0000000..9671260
--- /dev/null
+++ b/src/nstep.jl
@@ -0,0 +1,24 @@
+struct NStepAlgorithm{P,A} <: Algorithm{P}
+    problem::P
+    parent::A
+    n::Int
+end
+
+NStepAlgorithm(alg::Algorithm, n::Int) = NStepAlgorithm(alg.problem, alg, n)
+
+struct NStepState{A,S} <: State
+    algorithm::A
+    parent::S
+end
+
+Base.parent(x::NStepAlgorithm) = x.parent
+Base.parent(x::NStepState) = x.parent
+
+init(x::NStepAlgorithm) = NStepState(x, init(parent(x)))
+function step!(st::NStepState)
+    for i in 1:st.algorithm.n
+        step!(parent(st))
+    end
+    return st
+end
+fetch(x::NStepState) = fetch(parent(x))
diff --git a/src/problems.jl b/src/problems.jl
index 63957f4..d37dfa3 100644
--- a/src/problems.jl
+++ b/src/problems.jl
@@ -1,4 +1,5 @@
-using LinearAlgebra: UniformScaling, Diagonal
+using LinearAlgebra: UniformScaling, Diagonal, opnorm
+using StaticKernels: ExtensionNothing
 
 #"min_u 1/2 * |Au-f|_2^2 + λ*TV(u) + β/2 |u|_2 + γ |u|_1"
 "min_p 1/2 * |p₂ - div(p₁) - g|_B^2 + χ_{|p₁|≤λ} + χ_{|p₂|≤γ}"
@@ -16,6 +17,21 @@ end
 DualTVL1ROFOpProblem(g, B, λ) = DualTVL1ROFOpProblem(g, B, λ, nothing)
 DualTVL1ROFOpProblem(g, B, λ::Real) = DualTVL1ROFOpProblem(g, B, fill!(similar(g, typeof(λ)), λ))
 
-normB(p::DualTVL1ROFOpProblem{<:Any,<:UniformScaling}) = p.B.λ * sqrt(length(p.g))
-normB(p::DualTVL1ROFOpProblem) = norm(p.B)
-normB(p::DualTVL1ROFOpProblem{<:Any,<:Diagonal}) = sqrt(sum(norm.(diag(p.B)) .^ 2))
+function energy(p, prob::DualTVL1ROFOpProblem)
+    d = ndims(p)
+
+    @inline kfΛ(w) = @inbounds divergence(w) + prob.g[w.position]
+    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+
+    # v = div(p) + g
+    v = map(kΛ, extend(p, ExtensionNothing()))
+
+    # |v|_B^2 / 2
+    u = prob.B * vec(v)
+    return sum(dot.(u, vec(v))) / 2
+end
+
+# operator norm of B
+normB(p::DualTVL1ROFOpProblem{<:Any,<:UniformScaling}) = p.B.λ
+normB(p::DualTVL1ROFOpProblem{<:Any,<:Diagonal}) = maximum(opnorm.(diag(p.B)))
+normB(p::DualTVL1ROFOpProblem) = opnorm(p.B)
diff --git a/src/tvnewton.jl b/src/tvnewton.jl
new file mode 100644
index 0000000..4f98b0f
--- /dev/null
+++ b/src/tvnewton.jl
@@ -0,0 +1,154 @@
+using LinearAlgebra: dot, pinv, norm, cond, eigmax, eigmin, svdvals
+
+import ForwardDiff
+using ForwardDiff: jacobian
+using Krylov
+using LinearOperators: LinearOperator
+using StaticArrays
+using StaticKernels: Kernel, ExtensionReplicate, ExtensionNothing
+
+
+struct TVNewtonAlgorithm{P} <: Algorithm{P}
+    problem::P
+
+    function TVNewtonAlgorithm(problem)
+        return new{typeof(problem)}(problem)
+    end
+end
+
+struct TVNewtonState{A,Tx,Tu,Tn,K1,K2} <: State
+    algorithm::A
+
+    "combined primal/dual variable"
+    x::Tx
+    "primal view on `x`"
+    u::Tu
+    "dual view on `x`"
+    n::Tn
+
+    kdiv::K1
+    kgrad::K2
+end
+
+function init(alg::TVNewtonAlgorithm{<:DualTVL1ROFOpProblem})
+    g = alg.problem.g
+    λ = alg.problem.λ
+    sz = size(g)
+    d = ndims(g)
+    k = eltype(g) <: StaticArray ? length(eltype(g)) : 1
+
+    len = k * prod(sz), d * k * prod(sz)
+    x = Vector{eltype(eltype(g))}(undef, sum(len))
+    u, n = interpret(x, sz)
+
+    #x .= rand(size(x))
+    u .= g
+    n .= Ref(zero(eltype(n)))
+    #n .= rand(eltype(n), size(n))
+
+    kgrad = Kernel{ntuple(_->0:1, d)}(gradient)
+    kdiv = Kernel{ntuple(_->-1:1, d)}(mydivergence)
+
+    return TVNewtonState(alg, x, u, n, kdiv, kgrad)
+end
+
+function interpret(x, sz)
+    d = length(sz)
+    n = length(x)
+    k = n ÷ prod(sz) ÷ (d + 1)
+
+    T = eltype(x)
+    Tu = k == 1 ? T : SVector{k,T}
+    Tn = SVector{d,Tu}
+
+    len = k * prod(sz), d * k * prod(sz)
+
+    uview = extend(reshape(reinterpret(Tu, @view(x[begin:len[1]])), sz), ExtensionReplicate())
+    nview = extend(reshape(reinterpret(Tn, @view(x[len[1]+1:end])), sz), ExtensionNothing())
+
+    return uview, nview
+end
+
+
+function step!(st::TVNewtonState)
+    alg = st.algorithm
+    prob = alg.problem
+
+    g, λ, B = prob.g, prob.λ, prob.B
+    x = st.x
+
+    function f(x)
+        u, n = interpret(x, size(prob.g))
+        ugrad = map(st.kgrad, u)
+        ndiv = map(st.kdiv, n)
+
+        fu = u .- reshape(B * vec(λ .* ndiv .+ prob.g), size(u))
+        #fu = reshape(B \ vec(u), size(u)) .- prob.g .- λ .* ndiv
+        #fn = sqrt.(norm2.(ugrad) .+ floatmin()) .* ugrad .- norm2.(ugrad) .* n
+        fn = ugrad .- sqrt.(norm2.(ugrad) .+ floatmin()) .* n
+
+        T = eltype(x)
+        return vcat(reinterpret(T, vec(fu)), reinterpret(T, vec(fn)))
+    end
+
+    fx = f(x)
+    A = jacobian(f, x)
+    #any(isnan, A) && error("jacobian contains NaN")
+
+    Amap = LinearOperator(eltype(x), length(fx), length(x), false, false,
+        y -> jvp(f, x, y), y -> vjp(f, x, y), y -> vjp(f, x, y))
+
+    #println("cond(A) = $(cond(A))")
+    #println("svdvals(A) = $(svdvals(A))")
+    #dir = -dqgmres(Amap, fx)[1]
+    dir = -cgs(Amap, fx)[1]
+    #dir = -cg_lanczos(Amap, fx)[1]
+    #dir = -lsqr(Amap, fx)
+    #dir = -(A \ fx)
+    #dir = -(pinv(A) * fx)
+    τ = 1.
+    y = x + τ * dir
+    project!(y, size(g))
+    #while norm(f(y)) >= norm(fx)
+    #    τ /= 2
+    #    y = x + τ * dir
+    #    project!(y, size(g))
+    #    τ == 0 && break
+    #end
+    println("residual norm : $(norm(f(y))/length(fx)) with τ = $τ")
+    st.x .= y
+
+    return st
+end
+
+fetch(ctx::TVNewtonState) = ctx.algorithm.problem.λ .* ctx.n
+
+@generated function mydivergence(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
+    i0 = ntuple(_->0, N)
+    i1(k) = ntuple(i->Int(k==i), N)
+
+    wi = (:((isnothing(w[$(i1(k)...)]) ? zero(T) : w[$(i0...)][$k]) -
+            (isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k])) for k in 1:N)
+    return quote
+        Base.@_inline_meta
+        return @inbounds +($(wi...))
+    end
+end
+
+relu(x) = max(zero(x), x)
+norm2(x) = dot(x,x)
+
+function project!(x, sz)
+    _, n = interpret(x, sz)
+    fp(y) = norm(y) > 1 ? y/norm(y) : y
+    n .= fp.(n)
+end
+
+jvp(f, x, v) = ForwardDiff.derivative(ε -> f(x .+ ε .* v), 0.)
+vjp(f, x, v) = ForwardDiff.gradient(φ -> dot(f(φ), v), x)
+
+function newton(f, x, g)
+end
+
+using Revise
+Revise.track(@__FILE__)
diff --git a/test/optflow.jl b/test/optflow.jl
index 63742f2..da5c449 100644
--- a/test/optflow.jl
+++ b/test/optflow.jl
@@ -28,12 +28,12 @@ end
 function run_optflow(f0, f1, λ=0.01, β=0.001)
     prob = OptFlowProblem(f0, f1, λ, β)
     #alg = ChambolleAlgorithm(prob, τ=1/10000)
-    alg = ProjGradAlgorithm(prob, τ=1/10000)
+    #alg = ProjGradAlgorithm(prob, τ=1/10000)
     #alg = ProjGradAlgorithm(prob)
-    #alg = DualTVDD.DualTVDDAlgorithm(prob, M=(2,2), overlap=(4,4))
+    alg = DualTVDD.DualTVDDAlgorithm(prob, M=(2,2), overlap=(4,4))
 
     ctx = init(alg)
-    for i in 1:10000
+    for i in 1:1000
         step!(ctx)
     end
     plot_optflow(ctx)
-- 
GitLab