Skip to content
Snippets Groups Projects
Commit 1b15e99a authored by Stephan Hilb's avatar Stephan Hilb
Browse files

fixup surrogate algorithm

parent a98a2430
No related branches found
No related tags found
No related merge requests found
...@@ -3,16 +3,18 @@ using LinearAlgebra: I ...@@ -3,16 +3,18 @@ using LinearAlgebra: I
struct SurrogateAlgorithm{P} <: Algorithm{P} struct SurrogateAlgorithm{P} <: Algorithm{P}
problem::P problem::P
subalg::Function subalg::Function
"number of inner iterations"
ninner::Int
τ::Float64 τ::Float64
function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem; function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem;
τ=2*normB(problem), subalg=x->ProjectedGradient(x)) ninner=1,
return new{typeof(problem)}(problem, subalg, τ) τ=2*normB(problem), subalg=x->ProjGradAlgorithm(x))
return new{typeof(problem)}(problem, subalg, ninner, τ)
end end
end end
struct SurrogateState{A,SP,SS,Wv,R,S,K1} <: State struct SurrogateState{A,SS,Wv,R,S,K1} <: State
algorithm::A algorithm::A
subproblem::SP
substate::SS substate::SS
"scalar temporary 1" "scalar temporary 1"
...@@ -43,33 +45,45 @@ function init(alg::SurrogateAlgorithm) ...@@ -43,33 +45,45 @@ function init(alg::SurrogateAlgorithm)
@inline kf1(pw) = @inbounds -divergence(pw) @inline kf1(pw) = @inbounds -divergence(pw)
k1 = Kernel{ntuple(_->-1:1, d)}(kf1) k1 = Kernel{ntuple(_->-1:1, d)}(kf1)
subprob = DualTVL1ROFOpProblem(copy(prob.g), I, prob.λ) subg = copy(g)
substate = init(alg.subalg(subprob)) subprob = DualTVL1ROFOpProblem(subg, I, prob.λ)
subalg = alg.subalg(subprob)
substate = init(subalg)
return SurrogateState(alg, subprob, substate, rv, sv, r, s, k1) # we don't want to modify ourselves if alg.subalg gives us a bad reference
@assert subalg.problem.g !== g
return SurrogateState(alg, substate, rv, sv, r, s, k1)
end end
function update_g!(subg, st::SurrogateState) function update_g!(subg, st::SurrogateState)
alg = st.algorithm alg = st.algorithm
g = alg.problem.g g = alg.problem.g
p = extend(fetch(st.substate), StaticKernels.ExtensionReplicate())
p = extend(fetch(st.substate), StaticKernels.ExtensionNothing())
sv = vec(parent(st.s))
rv = vec(st.r)
# r = Λ*p # r = Λ*p
map!(st.k1, st.r, p) map!(st.k1, st.r, p)
# s = r - g # s = r - g
st.s .= st.r - g st.s .= st.r .- g
# r = r + 1/τ * B * s # r = r - 1/τ * B * s
mul!(st.rv, alg.problem.B, st.sv, -1. / alg.τ, 1.) mul!(rv, alg.problem.B, sv, -1. / alg.τ, 1.)
subg .= st.r subg .= st.r
end end
function step!(st::SurrogateState) function step!(st::SurrogateState)
update_g!(st.subproblem.g, st) subst = st.substate
# update data g of the subproblem
update_g!(subst.algorithm.problem.g, st)
for i in 1:10 for i in 1:st.algorithm.ninner
step!(st.substate) step!(subst)
end end
return st return st
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment