1+ #
2+ # This file contains code that is derived from RegularizedLeastSquares.jl.
3+ # Original source: https://github.com/JuliaImageRecon/RegularizedLeastSquares.jl
4+ #
5+ # RegularizedLeastSquares.jl is licensed under the MIT License:
6+ #
7+ # Copyright (c) 2018: Tobias Knopp
8+ #
9+ # Permission is hereby granted, free of charge, to any person obtaining a copy
10+ # of this software and associated documentation files (the "Software"), to deal
11+ # in the Software without restriction, including without limitation the rights
12+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13+ # copies of the Software, and to permit persons to whom the Software is
14+ # furnished to do so, subject to the following conditions:
15+ #
16+ # The above copyright notice and this permission notice shall be included in all
17+ # copies or substantial portions of the Software.
18+ #
19+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25+ # SOFTWARE.
26+
# Immutable description of an ADMM problem instance. All per-iteration mutable
# data lives in `ADMMState`; the one exception is `cg_state`, which is handed to
# the inner CG solver on every outer iteration.
struct ADMMIteration{R,Tx,TAHb,Tg<:Tuple,TB<:Tuple,TP,TC,TCGS}
    x0::Tx              # initial point
    AHb::TAHb           # A' * b, or `nothing` when no forward operator was given
    g::Tg               # tuple of regularization functions
    B::TB               # tuple of regularization operators, one per entry of `g`
    rho::Vector{R}      # penalty parameters, one per entry of `g`
    P::TP               # preconditioner passed to CG (`nothing` for plain CG)
    P_is_inverse::Bool  # passed to CG unchanged alongside `P`
    cg_operator::TC     # normal-equations operator used for the x update
    cg_tol::R           # passed to CG as `tol`
    cg_maxiter::Int     # passed to CG as `maxit`
    y0::Vector{Tx}      # initial scaled dual variables, one per entry of `g`
    z0::Vector{Tx}      # initial auxiliary variables, one per entry of `g`
    cg_state::TCGS      # CG state object passed to CG as `state`
end
42+
"""
    ADMMIteration(; <keyword-arguments>)

Iterator implementing the Alternating Direction Method of Multipliers (ADMM) algorithm.

This iterator solves optimization problems of the form

    minimize ½‖Ax - b‖²₂ + ∑ᵢ gᵢ(Bᵢx)

where:
- `A` is a linear operator
- `b` is the measurement vector
- `gᵢ` are proximable functions with associated linear operators `Bᵢ`

See also: [`ADMM`](@ref).

# Arguments
- `x0`: initial point
- `A=nothing`: forward operator. If `A` is not provided, ½‖Ax - b‖²₂ is not computed, and
  the algorithm will only minimize the regularization terms.
- `b=nothing`: measurement vector. `A` and `b` must be given together.
- `g=()`: tuple of proximable regularization functions. A non-tuple value is wrapped
  into a one-element tuple.
- `B=nothing`: tuple of regularization operators. `nothing` selects `LinearAlgebra.I`
  for every entry of `g`; a non-tuple value is wrapped into a one-element tuple.
- `rho=nothing`: augmented Lagrangian parameters, one per regularizer. `nothing` selects
  `ones(real(eltype(x0)), length(g))`; a single number is repeated for every regularizer;
  anything else must be a `Vector` of real numbers.
- `P=nothing`: preconditioner for CG (optional)
- `P_is_inverse=false`: passed to CG unchanged together with `P`
- `cg_tol=1e-6`: CG tolerance
- `cg_maxiter=100`: maximum CG iterations
- `y0=nothing`: initial dual variables. `nothing` selects one `zero(x0)` per regularizer.
- `z0=nothing`: initial auxiliary variables. `nothing` selects one `zero(x0)` per regularizer.

# Throws
- `ArgumentError` if only one of `A` and `b` is given, if `B`, `rho`, `y0` or `z0` does not
  have the same length as `g`, if `rho` is neither a number nor a `Vector` of real numbers,
  or if `A' * b` does not have the same size as `x0`.
"""
function ADMMIteration(;
    x0,
    A = nothing,
    b = nothing,
    g = (),
    B = nothing,
    rho = nothing,
    P = nothing,
    P_is_inverse = false,
    cg_tol = 1e-6,
    cg_maxiter = 100,
    y0 = nothing,
    z0 = nothing,
)
    # A and b only make sense as a pair.
    if isnothing(A) && !isnothing(b)
        throw(ArgumentError(" A must be provided if b is given"))
    end
    if !isnothing(A) && isnothing(b)
        throw(ArgumentError(" b must be provided if A is given"))
    end

    # Normalize g and B to tuples of equal length.
    if !(g isa Tuple)
        g = (g,)
    end
    if isnothing(B)
        B = ntuple(_ -> LinearAlgebra.I, length(g)) # Default to identity operators
    elseif !(B isa Tuple)
        B = (B,)
    end
    if length(B) != length(g)
        throw(ArgumentError(" B and g must have the same length"))
    end

    # Normalize rho to a vector with one entry per regularizer.
    if isnothing(rho)
        rho = ones(real(eltype(x0)), length(g))
    elseif rho isa Number
        rho = fill(rho, length(g))
    elseif !(rho isa Vector) || !all(isreal, rho)
        throw(ArgumentError(" rho must be a vector of real numbers"))
    end
    if length(rho) != length(g)
        throw(ArgumentError(" rho must have the same length as g"))
    end

    # Build the CG operator for the x update:
    #   cg_operator = A'*A + sum(rho[i] * (B[i]' * B[i]) for i in eachindex(g))
    # Without A only the regularizer terms are summed; with neither A nor any
    # regularizer the operator stays `nothing`.
    cg_operator = isnothing(A) ? nothing : A' * A
    for i in eachindex(g)
        new_op = rho[i] * (B[i]' * B[i])
        if isnothing(cg_operator)
            cg_operator = new_op
        else
            cg_operator += new_op
        end
    end

    if isnothing(y0)
        y0 = [zero(x0) for _ in eachindex(g)]
    elseif length(y0) != length(g)
        throw(ArgumentError(" y0 must have the same length as g"))
    end
    if isnothing(z0)
        z0 = [zero(x0) for _ in eachindex(g)]
    elseif length(z0) != length(g)
        throw(ArgumentError(" z0 must have the same length as g"))
    end

    AHb = isnothing(A) ? nothing : (A' * b)
    # The size check only applies when a least-squares term exists; calling
    # `size(nothing)` would raise a MethodError in the regularizer-only case.
    if !isnothing(AHb) && size(AHb) != size(x0)
        throw(ArgumentError(" A'b must have the same size as x0"))
    end

    # Create initial CGState
    cg_state = isnothing(P) ? CGState(x0) : PCGState(x0)

    return ADMMIteration{
        eltype(rho),
        typeof(x0),
        typeof(AHb),
        typeof(g),
        typeof(B),
        typeof(P),
        typeof(cg_operator),
        typeof(cg_state),
    }(
        x0, AHb, g, B, rho, P, P_is_inverse, cg_operator, cg_tol, cg_maxiter, y0, z0, cg_state,
    )
end
148+
# Working set of one ADMM run. `Base.@kwdef` provides the keyword constructor
# used by `ADMMState(iter)`; the arrays are updated in place by `iterate`.
Base.@kwdef mutable struct ADMMState{R,Tx}
    # --- iterates ---
    x::Tx                  # primal variable
    y::Vector{Tx}          # scaled dual variables, one per regularizer
    z::Vector{Tx}          # auxiliary variables, one per regularizer
    # --- scratch buffers ---
    u::Tx                  # temporary variable for the x update
    v::Tx                  # temporary variable for the normal equations
    w::Vector{Tx}          # temporary variables for the residuals
    # --- convergence monitoring ---
    res_primal::Vector{R}  # primal residual norms
    res_dual::Vector{R}    # dual residual norms
end
159+
# Build the initial state for `iter`.
#
# The primal variable is the array held by `iter.cg_state` (shared, not copied),
# while the dual and auxiliary variables are element-wise copies of `iter.y0`
# and `iter.z0`. Scratch buffers are allocated uninitialized with `similar`.
function ADMMState(iter::ADMMIteration)
    regs = 1:length(iter.g)

    # Primal variable: same array as the one stored in the CG state
    primal = iter.cg_state.x

    # Dual / auxiliary variables: copies of the initial values, or zeros as a fallback
    duals = if isnothing(iter.y0)
        [zero(primal) for _ in regs]
    else
        map(copy, iter.y0)
    end
    auxiliaries = if isnothing(iter.z0)
        [zero(primal) for _ in regs]
    else
        map(copy, iter.z0)
    end

    # Scratch storage
    x_buffer = similar(primal)
    rhs_buffer = similar(primal)
    residual_buffers = [similar(primal) for _ in regs]

    # Residual norms start at zero
    R = real(eltype(primal))
    primal_residuals = zeros(R, length(regs))
    dual_residuals = zeros(R, length(regs))

    return ADMMState(;
        x = primal,
        y = duals,
        z = auxiliaries,
        u = x_buffer,
        v = rhs_buffer,
        w = residual_buffers,
        res_primal = primal_residuals,
        res_dual = dual_residuals,
    )
end
179+
# One ADMM sweep. Mutates `state` in place and returns `(state, state)`:
#   1. x update: solve  cg_operator * x = A'b + Σᵢ rho[i] * B[i]' * (z[i] - y[i])  with CG
#   2. z update: z[i] = prox of g[i] at B[i]*x + y[i] with parameter 1 / rho[i]
#   3. y update: y[i] += B[i]*x - z[i], recording primal and dual residual norms
function Base.iterate(iter::ADMMIteration, state::ADMMState = ADMMState(iter))
    regs = eachindex(iter.g)

    # Snapshot of z before this sweep, needed for the dual residuals
    z_prev = map(copy, state.z)

    # Right-hand side of the x update starts from the least-squares term, if any
    rhs = state.v
    if isnothing(iter.AHb)
        fill!(rhs, 0) # no least squares term
    else
        copyto!(rhs, iter.AHb) # v = A'b
    end

    # Accumulate the regularizer contributions in u, then add them to the rhs
    acc = state.u
    fill!(acc, 0)
    for i in regs
        wi = state.w[i]
        mul!(wi, adjoint(iter.B[i]), state.z[i] .- state.y[i])
        acc .+= iter.rho[i] .* wi
    end
    rhs .+= acc

    # Fresh CG solver object on top of the CG state stored in `iter`
    cg = CG(
        x0 = state.x,
        A = iter.cg_operator,
        b = rhs,
        P = iter.P,
        P_is_inverse = iter.P_is_inverse,
        state = iter.cg_state,
        tol = iter.cg_tol,
        maxit = iter.cg_maxiter,
    )
    cg() # this works in-place, updating state.x == iter.cg_state.x

    # z updates
    for i in regs
        wi = state.w[i]
        mul!(wi, iter.B[i], state.x)
        wi .+= state.y[i]
        prox!(state.z[i], iter.g[i], wi, 1 / iter.rho[i])
    end

    # Dual variable updates and residual norms
    for i in regs
        wi = state.w[i]
        mul!(wi, iter.B[i], state.x)
        wi .-= state.z[i] # wi = B[i]*x - z[i]
        state.y[i] .+= wi

        state.res_primal[i] = norm(wi)
        state.res_dual[i] = iter.rho[i] * norm(state.z[i] - z_prev[i])
    end

    return state, state
end
231+
# Converged once every primal residual norm and every dual residual norm
# stored in `state` is at most `tol`.
function default_stopping_criterion(tol, ::ADMMIteration, state::ADMMState)
    within_tol = <=(tol)
    return all(within_tol, state.res_primal) && all(within_tol, state.res_dual)
end
# The reported solution is the primal variable held in `state`.
function default_solution(::ADMMIteration, state::ADMMState)
    return state.x
end
# Print the iteration counter together with the largest primal and the largest
# dual residual norm currently stored in `state`.
function default_display(it, ::ADMMIteration, state::ADMMState)
    worst_primal = maximum(state.res_primal)
    worst_dual = maximum(state.res_dual)
    return @printf(" %5d | Primal: %.3e, Dual: %.3e\n ", it, worst_primal, worst_dual)
end
238+
"""
    ADMM(; <keyword-arguments>)

Create an instance of the ADMM algorithm.

This algorithm solves optimization problems of the form

    minimize ½‖Ax - b‖²₂ + ∑ᵢ gᵢ(Bᵢx)

where `A` is a linear operator, `b` is the measurement vector, and `gᵢ` are proximable
functions with associated linear operators `Bᵢ`.

The returned object has type `IterativeAlgorithm{ADMMIteration}`,
and can be called with the problem's arguments to trigger its solution.

See also: [`ADMMIteration`](@ref).

# Arguments
- `x0`: initial point
- `A=nothing`: forward operator. If `A` is not provided, ½‖Ax - b‖²₂ is not computed, and
  the algorithm will only minimize the regularization terms.
- `b=nothing`: measurement vector. If `A` is provided, `b` must also be provided.
- `g=()`: tuple of proximable regularization functions
- `B=nothing`: tuple of regularization operators (identity for every `gᵢ` when `nothing`)
- `rho=nothing`: augmented Lagrangian parameters, one per regularizer (all ones when `nothing`)
- `P=nothing`: preconditioner for CG (optional)
- `cg_tol=1e-6`: CG tolerance
- `cg_maxiter=100`: maximum CG iterations
- `y0=nothing`: initial dual variables
- `z0=nothing`: initial auxiliary variables

# Algorithm options
The following keywords are passed to `IterativeAlgorithm` unchanged:
- `maxit=10_000`
- `tol=1e-8`: tolerance used by the default `stop`
- `stop`: by default `(iter, state) -> default_stopping_criterion(tol, iter, state)`
- `solution=default_solution`
- `verbose=false`
- `freq=100`
- `display=default_display`
"""
function ADMM(;
    maxit = 10_000,
    tol = 1e-8,
    stop = (iter, state) -> default_stopping_criterion(tol, iter, state),
    solution = default_solution,
    verbose = false,
    freq = 100,
    display = default_display,
    kwargs...,
)
    # `tol` is only consumed by the default `stop` closure above; it is not
    # passed on by name.
    return IterativeAlgorithm(
        ADMMIteration;
        maxit = maxit,
        stop = stop,
        solution = solution,
        verbose = verbose,
        freq = freq,
        display = display,
        kwargs...,
    )
end
285+
# Problem structure declared for ADMMIteration: a least-squares term over
# (`A`, `b`) with `A` linear, plus repeated (`g`, `B`) pairs with `g` proximable
# and `B` linear.
function get_assumptions(::Type{<:ADMMIteration})
    least_squares = LeastSquaresTerm(:A => (is_linear,), :b)
    regularizers = RepeatedOperatorTerm(:g => (is_proximable,), :B => (is_linear,))
    return (least_squares, regularizers)
end
0 commit comments