Commit de855f9a authored by Stephan Hilb

add projected gradient

parent e2d9a050
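For orientation, the new projgrad.jl added below appears to implement a projected gradient iteration on the dual total-variation problem. Reading the kernels and comments in its step! (r = Λ*p - g, s = B*r, p = proj(p - λΛ's)) together with the precomputed quantities g = A'f and B = (A'A + βI)^(-1), one step can be sketched as follows; the exact form of the dual energy is inferred from the code, not stated in the commit:

\[
p^{k+1} = \Pi_\alpha\!\Bigl( p^k - \lambda\, \nabla\bigl[(A^\top A + \beta I)^{-1}\bigl(-\operatorname{div} p^k - A^\top f\bigr)\bigr] \Bigr),
\qquad
\Pi_\alpha(q) = \frac{q}{\max(\lVert q\rVert/\alpha,\ 1)},
\]

where Λ = -div acts on the dual variable p, its adjoint Λ' = ∇ is the gradient kernel, and Π_α is the pointwise projection onto the ball of radius α.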
@@ -5,6 +5,7 @@ module DualTVDD
 include("types.jl")
 include("chambolle.jl")
 include("dualtvdd.jl")
+include("projgrad.jl")

 using Makie: heatmap
@@ -47,8 +48,8 @@ function rundd()
     f = zeros(2,2)
     f[1,:] .= 1
     #g = [0. 2; 1 0.]
-    A = diagm(vcat(fill(1, length(f)÷2), fill(1/1000, length(f)÷2)))
-    #A = rand(length(f), length(f))
+    #A = diagm(vcat(fill(1, length(f)÷2), fill(1/10, length(f)÷2)))
+    A = rand(length(f), length(f))
     display(A)
     println(cond(A))
     display(eigen(A))
@@ -61,7 +62,7 @@ function rundd()
     g = similar(f)
     vec(g) .= A' * vec(f)
-    α = .25
+    α = .025
     md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
     alg = DualTVDD.DualTVDDAlgorithm(M=(1,1), overlap=(1,1), σ=1)
@@ -75,7 +76,7 @@ function rundd()
     for i in 1:1
         step!(ctx)
     end
-    for i in 1:100000
+    for i in 1:1000000
         step!(ctx2)
     end
@@ -108,7 +109,47 @@ function rundd()
     ctx, ctx2
 end

-function energy(ctx::DualTVDDContext)
+function run3()
+    f = rand(20,20)
+    A = rand(length(f), length(f))
+    A .+= diagm(ones(length(f)))
+    g = reshape(A'*vec(f), size(f))
+
+    β = 0
+    B = inv(A'*A + β*I)
+    println(norm(A))
+
+    α = 0.1
+
+    # Chambolle
+    md = DualTVDD.OpROFModel(g, B, α)
+    alg = DualTVDD.ChambolleAlgorithm()
+    ctx = DualTVDD.init(md, alg)
+
+    # Projected Gradient
+    md = DualTVDD.DualTVDDModel(f, A, α, 0., 0.)
+    alg = DualTVDD.ProjGradAlgorithm(λ = 1/norm(A)^2)
+    ctx2 = DualTVDD.init(md, alg)
+
+    for i in 1:100000
+        step!(ctx)
+        step!(ctx2)
+    end
+
+    #display(ctx.p)
+    #display(ctx2.p)
+    display(recover_u!(ctx))
+    display(recover_u!(ctx2))
+    println(energy(ctx))
+    println(energy(ctx2))
+
+    ctx, ctx2
+end
+
+function energy(ctx::Union{DualTVDDContext,ProjGradContext})
     d = ndims(ctx.p)

     @inline kfΛ(w) = @inbounds divergence(w)
...
@@ -73,6 +73,10 @@ function init(md::OpROFModel, alg::ChambolleAlgorithm)
     return ChambolleContext(md, alg, g, λ, p, r, s, rv, sv, k1, k2)
 end

+function reset!(ctx::ChambolleContext)
+    fill!(ctx.p, zero(eltype(ctx.p)))
+end
+
 @generated function gradient(w::StaticKernels.Window{<:Any,N}) where N
     i0 = ntuple(_->0, N)
     i1(k) = ntuple(i->Int(k==i), N)
...
@@ -95,11 +95,10 @@ function step!(ctx::DualTVDDContext)
         ctx.subg[i] .+= ctx.g[sax...]

         # set sensible starting value
-        ctx.subctx[i].p .= Ref(zero(eltype(ctx.subctx[i].p)))
+        reset!(ctx.subctx[i])

         # precomputed: B/λ * (A'f - Λ(1-θ_i)p^n)
-        gloc = similar(ctx.subg[i])
-        vec(gloc) .= ctx.subctx[i].model.B * vec(ctx.subg[i])
+        gloc = copy(ctx.subg[i])

         # v_0
         ctx.ptmp .= theta.(Ref(ax), Ref(sax), Ref(overlap), CartesianIndices(ctx.p)) .* ctx.p
@@ -109,11 +108,11 @@ function step!(ctx::DualTVDDContext)
         subIB = I - ctx.B[vec(li[sax...]), vec(li[sax...])]./λ
         subB = ctx.B[vec(li[sax...]), vec(li[sax...])]./λ
-        for j in 1:1000000
+        for j in 1:10000
             subΛp = map(, ctx.subctx[i].p)
             vec(ctx.subg[i]) .= subIB * vec(subΛp) .+ subB * vec(gloc)
-            for k in 1:10000
+            for k in 1:1000
                 step!(ctx.subctx[i])
             end
         end
...
projgrad.jl (new file):

struct ProjGradAlgorithm <: Algorithm
    "gradient step size"
    λ::Float64
    function ProjGradAlgorithm(; λ)
        return new(λ)
    end
end
struct ProjGradContext{M,A,V,W,Wv,WvWv,R,S,K1,K2}
    model::M
    algorithm::A
    "dual optimization variable"
    p::V
    "precomputed A'f"
    g::W
    "scalar temporary 1"
    rv::Wv
    "scalar temporary 2"
    sv::Wv
    "precomputed (A'A + βI)^(-1)"
    B::WvWv
    "matrix view on rv"
    r::R
    "matrix view on sv"
    s::S
    k1::K1
    k2::K2
end
function init(md::DualTVDDModel, alg::ProjGradAlgorithm)
    # FIXME: A is assumed square
    d = ndims(md.f)
    ax = axes(md.f)

    p = extend(zeros(SVector{d,Float64}, ax), StaticKernels.ExtensionNothing())
    gtmp = reshape(md.A' * vec(md.f), size(md.f))
    g = extend(gtmp, StaticKernels.ExtensionNothing())
    rv = zeros(length(md.f))
    sv = zeros(length(md.f))
    B = inv(md.A' * md.A + md.β * I)
    r = reshape(rv, ax)
    s = extend(reshape(sv, ax), StaticKernels.ExtensionReplicate())

    z = zero(CartesianIndex{d})
    @inline kf1(pw, gw) = @inbounds -divergence(pw) - gw[z]
    k1 = Kernel{ntuple(_->-1:1, d)}(kf1)

    @inline function kf2(pw, sw)
        Base.@_inline_meta
        q = pw[z] - alg.λ * gradient(sw)
        return q / max(norm(q) / md.α, 1)
    end
    k2 = Kernel{ntuple(_->0:1, d)}(kf2)

    return ProjGradContext(md, alg, p, g, rv, sv, B, r, s, k1, k2)
end
function step!(ctx::ProjGradContext)
    # r = Λ*p - g
    map!(ctx.k1, ctx.r, ctx.p, ctx.g)
    # s = B * r
    mul!(ctx.sv, ctx.B, ctx.rv)
    # p = proj(p - λΛ's)
    map!(ctx.k2, ctx.p, ctx.p, ctx.s)
end
function recover_u!(ctx::ProjGradContext)
    d = ndims(ctx.g)
    u = similar(ctx.g)
    v = similar(ctx.g)

    @inline kfΛ(w) = @inbounds divergence(w)
    kΛ = Kernel{ntuple(_->-1:1, d)}(kfΛ)  # divergence kernel

    # v = div(p) + A'*f
    map!(kΛ, v, ctx.p)  # extension: nothing
    v .+= ctx.g
    # u = B * v
    mul!(vec(u), ctx.B, vec(v))
    return u
end
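The pointwise projection used in kf2 above can be exercised in isolation. The helper below (proj_ball) is a hypothetical sketch written for illustration only and is not part of this commit:

using LinearAlgebra: norm
using StaticArrays: SVector

# Hypothetical helper mirroring the `q / max(norm(q) / md.α, 1)` expression in kf2:
# projects q onto the closed ball of radius α and leaves shorter vectors unchanged.
proj_ball(q, α) = q / max(norm(q) / α, 1)

proj_ball(SVector(3.0, 4.0), 1.0)  # SVector(0.6, 0.8): norm clipped from 5 to α = 1
proj_ball(SVector(0.1, 0.0), 1.0)  # unchanged, already inside the ball

recover_u! then maps the dual iterate back to a primal estimate via u = (A'A + βI)^(-1)(div p + A'f), matching the `# v = div(p) + A'*f` and `# u = B * v` comments above.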