Commit 91eb8523 authored by Stephan Hilb

reworked numerical examples

parent ce1fb5d8
@@ -4,9 +4,12 @@ export init, step!, fetch, run
 include("common.jl")
 include("problems.jl")
+include("nstep.jl")
 include("chambolle.jl")
 include("dualtvdd.jl")
 include("projgrad.jl")
+#include("tvnewton.jl")

 #using Plots: heatmap
 #
@@ -42,8 +42,8 @@ abstract type Algorithm{P<:Problem} end
 """
     <:State

-Algorithm state and storage containing sufficient information to act as a
-checkpoint for continuing the corresponding algorithm at that point.
+The algorithm workspace. Together with the algorithm it contains sufficient
+information to act as a checkpoint for continuing the algorithm at that point.
 """
 abstract type State end
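The `Algorithm`/`State` split is what makes the checkpoint semantics work: `init` allocates the mutable workspace, `step!` advances it in place, `fetch` reads out the current iterate. A minimal driver loop along the lines of `run_optflow` further down; `prob` here is a placeholder for any constructed problem:

# Sketch of the init/step!/fetch protocol; `prob` is a placeholder problem.
alg = ChambolleAlgorithm(prob)
st = init(alg)      # allocate the workspace; st doubles as a resumable checkpoint
for i in 1:100
    step!(st)       # advance one iteration in place
end
p = fetch(st)       # extract the current iterate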
@@ -11,8 +11,10 @@ struct DualTVDDAlgorithm{P,d} <: Algorithm{P}
     σ::Float64
     "number of inner iterations"
     ninner::Int
-    function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=1/4, ninner=10)
-        return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner)
+    "prob -> Algorithm(::Problem, ...)"
+    subalg::Function
+    function DualTVDDAlgorithm(problem; M, overlap, parallel=true, σ=1/4, ninner=10, subalg=x->ProjGradAlgorithm(x))
+        return new{typeof(problem), length(M)}(problem, M, overlap, parallel, σ, ninner, subalg)
     end
 end
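The new `subalg` field injects the subdomain solver as a constructor closure instead of hard-wiring it (the TODO is resolved in the `init` hunk below). A usage sketch, with `prob` a placeholder; keyword arguments for the inner solver can be captured in the closure:

# Recover the previous hard-wired behavior: Chambolle on every subdomain.
alg = DualTVDDAlgorithm(prob; M=(2,2), overlap=(4,4),
                        subalg = p -> ChambolleAlgorithm(p))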
@@ -52,9 +54,7 @@ function init(alg::DualTVDDAlgorithm{<:DualTVL1ROFOpProblem})
     subprobs = [DualTVL1ROFOpProblem(subg[i], op_restrict(alg.problem.B, ax, subax[i]), subλ[i])
         for i in CartesianIndices(subax)]

-    # TODO: make inner algorithm a parameter
-    #subalg = [ProjGradAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
-    subalg = [ChambolleAlgorithm(subprobs[i]) for i in CartesianIndices(subax)]
+    subalg = [alg.subalg(subprobs[i]) for i in CartesianIndices(subax)]
     subctx = [init(x) for x in subalg]
@@ -77,23 +77,24 @@ function step!(ctx::DualTVDDState)
     overlap = ctx.algorithm.overlap

     # call run! on each cell (this can be threaded)
-    for (i, sax) in pairs(ctx.subax)
+    Threads.@threads for i in eachindex(ctx.subax)
+        sax = ctx.subax[i]
         li = LinearIndices(ctx.subax)[i]
         sg = ctx.subctx[i].algorithm.problem.g # julia-bug workaround
         sq = ctx.q[i] # julia-bug workaround

-        sg .= alg.problem.g[sax...]
+        sg .= view(alg.problem.g, sax...)
         if alg.parallel
-            sq .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* ctx.p[sax...]
+            sq .= (1 .- theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(sax))) .* view(ctx.p, sax...)
         else
             sq .= Ref(zero(eltype(sq)))
             # contributions from previous domains
             for (lj, saxj) in enumerate(ctx.subax)
                 ids, idsi, idsj = intersectin(CartesianIndices(sax), CartesianIndices(saxj))
                 if lj < li
-                    sq[idsi] .+= ctx.subctx[lj].p[idsj]
+                    sq[idsi] .+= view(ctx.subctx[lj].p, idsj)
                 elseif lj > li
-                    sq[idsi] .+= theta.(Ref(ax), Ref(saxj), Ref(overlap), ids) .* ctx.p[ids]
+                    sq[idsi] .+= theta.(Ref(ax), Ref(saxj), Ref(overlap), ids) .* view(ctx.p, ids)
                 end
             end
         end
@@ -111,7 +112,7 @@ function step!(ctx::DualTVDDState)
     # aggregate (not thread-safe!)
     ctx.p .*= 1 - σ
     for (i, sax) in pairs(ctx.subax)
-        ctx.p[sax...] .+= σ .* ctx.subctx[i].p
+        view(ctx.p, sax...) .+= σ .* ctx.subctx[i].p
     end

     return ctx
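Read together, the scaling and the loop implement a damped aggregation of the subdomain solutions; in formula form (my reading of the two lines above, with $R_i^T$ denoting extension-by-zero from subdomain $i$):

p^{n+1} = (1 - \sigma)\, p^n + \sigma \sum_i R_i^T\, p_i^{n+1}

On overlap regions several subdomains contribute to the same entries of `p`, which is why this update is marked not thread-safe and stays serial while the subdomain solves above are threaded.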
nstep.jl (new file):

struct NStepAlgorithm{P,A} <: Algorithm{P}
    problem::P
    parent::A
    n::Int
end

NStepAlgorithm(alg::Algorithm, n::Int) = NStepAlgorithm(alg.problem, alg, n)

struct NStepState{A,S} <: State
    algorithm::A
    parent::S
end

Base.parent(x::NStepAlgorithm) = x.parent
Base.parent(x::NStepState) = x.parent

init(x::NStepAlgorithm) = NStepState(x, init(parent(x)))

function step!(st::NStepState)
    for i in 1:st.algorithm.n
        step!(parent(st))
    end
    return parent(st)
end

fetch(x::NStepState) = fetch(parent(x))
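`NStepAlgorithm` wraps any existing algorithm so that one outer `step!` performs `n` inner iterations, which is convenient when convergence is only monitored every so often. A usage sketch (`prob` is a placeholder):

inner = ChambolleAlgorithm(prob)
alg = NStepAlgorithm(inner, 10)   # one outer step! = 10 inner iterations
st = init(alg)
step!(st)                         # advances the wrapped state ten times
p = fetch(st)                     # delegates to the wrapped state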
-using LinearAlgebra: UniformScaling, Diagonal
+using LinearAlgebra: UniformScaling, Diagonal, opnorm
+using StaticKernels: ExtensionNothing

 #"min_u 1/2 * |Au-f|_2^2 + λ*TV(u) + β/2 |u|_2 + γ |u|_1"
 "min_p 1/2 * |p₂ - div(p₁) - g|_B^2 + χ_{|p₁|≤λ} + χ_{|p₂|≤γ}"
@@ -16,6 +17,21 @@
 DualTVL1ROFOpProblem(g, B, λ) = DualTVL1ROFOpProblem(g, B, λ, nothing)
 DualTVL1ROFOpProblem(g, B, λ::Real) = DualTVL1ROFOpProblem(g, B, fill!(similar(g, typeof(λ)), λ))

-normB(p::DualTVL1ROFOpProblem{<:Any,<:UniformScaling}) = p.B.λ * sqrt(length(p.g))
-normB(p::DualTVL1ROFOpProblem) = norm(p.B)
-normB(p::DualTVL1ROFOpProblem{<:Any,<:Diagonal}) = sqrt(sum(norm.(diag(p.B)) .^ 2))
+function energy(p, prob::DualTVL1ROFOpProblem)
+    d = ndims(p)
+    @inline kfΛ(w) = @inbounds divergence(w) + prob.g[w.position]
+    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)
+    # v = div(p) + g
+    v = map(kΛ, extend(p, ExtensionNothing()))
+    # |v|_B^2 / 2
+    u = prob.B * vec(v)
+    return sum(dot.(u, vec(v))) / 2
+end
+
+# operator norm of B
+normB(p::DualTVL1ROFOpProblem{<:Any,<:UniformScaling}) = p.B.λ
+normB(p::DualTVL1ROFOpProblem{<:Any,<:Diagonal}) = maximum(opnorm.(diag(p.B)))
+normB(p::DualTVL1ROFOpProblem) = opnorm(p.B)
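The replaced `normB` methods returned Frobenius-type norms; the new ones consistently return the operator norm of `B`, which is the quantity step-size bounds actually need. For `B = λI` this is `|λ|` independent of the dimension, so the `UniformScaling` method relies on `λ ≥ 0`. A quick sanity check with hypothetical values:

using LinearAlgebra: UniformScaling, opnorm

B = UniformScaling(0.5)
# The operator norm of λI is |λ|, independent of the dimension,
# so normB(...) = p.B.λ is exact for nonnegative λ.
@assert opnorm(Matrix(B, 3, 3)) ≈ abs(B.λ)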
tvnewton.jl (new file):

using LinearAlgebra: dot, pinv, norm, cond, eigmax, eigmin, svdvals
import ForwardDiff
using ForwardDiff: jacobian, derivative, gradient
using Krylov
using LinearOperators: LinearOperator
using StaticArrays
using StaticKernels: Kernel, ExtensionReplicate, ExtensionNothing

struct TVNewtonAlgorithm{P} <: Algorithm{P}
    problem::P
    function TVNewtonAlgorithm(problem)
        return new{typeof(problem)}(problem)
    end
end
struct TVNewtonState{A,Tx,Tu,Tn,K1,K2}
    algorithm::A
    "combined primal/dual variable"
    x::Tx
    "primal view on `x`"
    u::Tu
    "dual view on `x`"
    n::Tn
    kdiv::K1
    kgrad::K2
end
function init(alg::TVNewtonAlgorithm{<:DualTVL1ROFOpProblem})
    g = alg.problem.g
    λ = alg.problem.λ
    sz = size(g)
    d = ndims(g)
    k = eltype(g) <: StaticArray ? length(eltype(g)) : 1
    len = k * prod(sz), d * k * prod(sz)

    x = Vector{eltype(eltype(g))}(undef, sum(len))
    u, n = interpret(x, sz)
    #x .= rand(size(x))
    u .= g
    n .= Ref(zero(eltype(n)))
    #n .= rand(eltype(n), size(n))

    kgrad = Kernel{ntuple(_->0:1, d)}(gradient)
    kdiv = Kernel{ntuple(_->-1:1, d)}(mydivergence)

    return TVNewtonState(alg, x, u, n, kdiv, kgrad)
end
function interpret(x, sz)
    d = length(sz)
    n = length(x)
    k = n ÷ prod(sz) ÷ (d + 1)
    T = eltype(x)
    Tu = k == 1 ? T : SVector{k,T}
    Tn = SVector{d,Tu}
    len = k * prod(sz), d * k * prod(sz)

    uview = extend(reshape(reinterpret(Tu, @view(x[begin:len[1]])), sz), ExtensionReplicate())
    nview = extend(reshape(reinterpret(Tn, @view(x[len[1]+1:end])), sz), ExtensionNothing())
    return uview, nview
end
function step!(st::TVNewtonState)
    alg = st.algorithm
    prob = alg.problem
    g, λ, B = prob.g, prob.λ, prob.B
    x = st.x

    function f(x)
        u, n = interpret(x, size(prob.g))
        ugrad = map(st.kgrad, u)
        ndiv = map(st.kdiv, n)

        fu = u .- reshape(B * vec(λ .* ndiv .+ prob.g), size(u))
        #fu = reshape(B \ vec(u), size(u)) .- prob.g .- λ .* ndiv
        #fn = sqrt.(norm2.(ugrad) .+ floatmin()) .* ugrad .- norm2.(ugrad) .* n
        fn = ugrad .- sqrt.(norm2.(ugrad) .+ floatmin()) .* n

        T = eltype(x)
        return vcat(reinterpret(T, vec(fu)), reinterpret(T, vec(fn)))
    end

    fx = f(x)
    # dense Jacobian, only needed for the direct solves commented out below
    #A = jacobian(f, x)
    #any(isnan, A) && error("jacobian contains NaN")

    Amap = LinearOperator(eltype(x), length(fx), length(x), false, false,
        y -> jvp(f, x, y), y -> vjp(f, x, y), y -> vjp(f, x, y))

    #println("cond(A) = $(cond(A))")
    #println("svdvals(A) = $(svdvals(A))")
    #dir = -dqgmres(Amap, fx)[1]
    dir = -cgs(Amap, fx)[1]
    #dir = -cg_lanczos(Amap, fx)[1]
    #dir = -lsqr(Amap, fx)
    #dir = -(A \ fx)
    #dir = -(pinv(A) * fx)

    τ = 1.0
    y = x + τ * dir
    project!(y, size(g))
    #while norm(f(y)) >= norm(fx)
    #    τ /= 2
    #    y = x + τ * dir
    #    project!(y, size(g))
    #    τ == 0 && break
    #end
    println("residual norm: $(norm(f(y))/length(fx)) with τ = $τ")

    st.x .= y
    return st
end
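As far as the residual `f` goes, each iteration takes one Newton step on the coupled primal-dual system (with $\varepsilon$ = `floatmin()` regularizing the norm) and then projects `n` back onto the unit ball:

\begin{aligned}
u - B\,(\lambda\,\operatorname{div} n + g) &= 0, \\
\nabla u - \sqrt{|\nabla u|^2 + \varepsilon}\;\, n &= 0,
\end{aligned}

so at a root n = ∇u/|∇u| wherever ∇u ≠ 0, and λ·n is exactly the dual variable that `fetch` returns below.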
fetch(ctx::TVNewtonState) = ctx.algorithm.problem.λ .* ctx.n

@generated function mydivergence(w::StaticKernels.Window{SVector{N,T},N}) where {N,T}
    i0 = ntuple(_->0, N)
    i1(k) = ntuple(i->Int(k==i), N)
    wi = (:((isnothing(w[$(i1(k)...)]) ? zero(T) : w[$(i0...)][$k]) -
            (isnothing(w[$((.-i1(k))...)]) ? zero(T) : w[$((.-i1(k))...)][$k])) for k in 1:N)
    return quote
        Base.@_inline_meta
        return @inbounds +($(wi...))
    end
end
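For `N = 2` the generated expression unrolls to backward differences per component, dropping terms where the `ExtensionNothing` boundary yields `nothing`; written out by hand (my expansion, as a plain function for illustration):

# Hand-expanded N = 2 body of mydivergence (window w over SVector{2,T} entries).
mydivergence2(w, T) =
    (isnothing(w[1, 0])  ? zero(T) : w[0, 0][1]) -
    (isnothing(w[-1, 0]) ? zero(T) : w[-1, 0][1]) +
    (isnothing(w[0, 1])  ? zero(T) : w[0, 0][2]) -
    (isnothing(w[0, -1]) ? zero(T) : w[0, -1][2])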
relu(x) = x < 0 ? zero(x) : x
norm2(x) = dot(x, x)

function project!(x, sz)
    _, n = interpret(x, sz)
    fp(y) = norm(y) > 1 ? y / norm(y) : y
    n .= fp.(n)
end

jvp(f, x, v) = ForwardDiff.derivative(ε -> f(x .+ ε .* v), 0.0)
vjp(f, x, v) = ForwardDiff.gradient(φ -> dot(f(φ), v), x)
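`jvp` and `vjp` are the matrix-free Jacobian-vector and vector-Jacobian products that the `LinearOperator` in `step!` hands to the Krylov solver. Their contract on a toy residual (the residual and values here are mine):

# Toy residual: f(x) = [x₁² - x₂, x₂], with Jacobian [4 -1; 0 1] at x₀.
ftoy(x) = [x[1]^2 - x[2], x[2]]
x₀ = [2.0, 1.0]
v  = [1.0, 0.0]
@assert jvp(ftoy, x₀, v) ≈ [4.0, 0.0]   # J v
@assert vjp(ftoy, x₀, v) ≈ [4.0, -1.0]  # Jᵀ v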
function newton(f, x, g)
end
using Revise
Revise.track(@__FILE__)
@@ -28,12 +28,12 @@ end
 function run_optflow(f0, f1, λ=0.01, β=0.001)
     prob = OptFlowProblem(f0, f1, λ, β)

     #alg = ChambolleAlgorithm(prob, τ=1/10000)
-    alg = ProjGradAlgorithm(prob, τ=1/10000)
+    #alg = ProjGradAlgorithm(prob, τ=1/10000)
     #alg = ProjGradAlgorithm(prob)
-    #alg = DualTVDD.DualTVDDAlgorithm(prob, M=(2,2), overlap=(4,4))
+    alg = DualTVDD.DualTVDDAlgorithm(prob, M=(2,2), overlap=(4,4))
     ctx = init(alg)
-    for i in 1:10000
+    for i in 1:1000
         step!(ctx)
     end
     plot_optflow(ctx)