diff --git a/src/surrogate.jl b/src/surrogate.jl index dcdacf911f7c7a1ccd95f75b7abaa4e72abe0181..9914b9995940e0edea5bf01bb6a797c870343e13 100644 --- a/src/surrogate.jl +++ b/src/surrogate.jl @@ -3,16 +3,18 @@ using LinearAlgebra: I struct SurrogateAlgorithm{P} <: Algorithm{P} problem::P subalg::Function + "number of inner iterations" + ninner::Int τ::Float64 function SurrogateAlgorithm(problem::DualTVL1ROFOpProblem; - τ=2*normB(problem), subalg=x->ProjectedGradient(x)) - return new{typeof(problem)}(problem, subalg, τ) + ninner=1, + τ=2*normB(problem), subalg=x->ProjGradAlgorithm(x)) + return new{typeof(problem)}(problem, subalg, ninner, τ) end end -struct SurrogateState{A,SP,SS,Wv,R,S,K1} <: State +struct SurrogateState{A,SS,Wv,R,S,K1} <: State algorithm::A - subproblem::SP substate::SS "scalar temporary 1" @@ -43,33 +45,45 @@ function init(alg::SurrogateAlgorithm) @inline kf1(pw) = @inbounds -divergence(pw) k1 = Kernel{ntuple(_->-1:1, d)}(kf1) - subprob = DualTVL1ROFOpProblem(copy(prob.g), I, prob.λ) - substate = init(alg.subalg(subprob)) + subg = copy(g) + subprob = DualTVL1ROFOpProblem(subg, I, prob.λ) + subalg = alg.subalg(subprob) + substate = init(subalg) - return SurrogateState(alg, subprob, substate, rv, sv, r, s, k1) + # we don't want to modify ourselves if alg.subalg gives us a bad reference + @assert subalg.problem.g !== g + + return SurrogateState(alg, substate, rv, sv, r, s, k1) end function update_g!(subg, st::SurrogateState) alg = st.algorithm g = alg.problem.g - p = extend(fetch(st.substate), StaticKernels.ExtensionReplicate()) + + p = extend(fetch(st.substate), StaticKernels.ExtensionNothing()) + + sv = vec(parent(st.s)) + rv = vec(st.r) # r = Λ*p map!(st.k1, st.r, p) # s = r - g - st.s .= st.r - g - # r = r + 1/τ * B * s - mul!(st.rv, alg.problem.B, st.sv, -1. / alg.τ, 1.) + st.s .= st.r .- g + # r = r - 1/τ * B * s + mul!(rv, alg.problem.B, sv, -1. / alg.τ, 1.) subg .= st.r end function step!(st::SurrogateState) - update_g!(st.subproblem.g, st) + subst = st.substate + + # update data g of the subproblem + update_g!(subst.algorithm.problem.g, st) - for i in 1:10 - step!(st.substate) + for i in 1:st.algorithm.ninner + step!(subst) end return st