Skip to content

Commit ba378dd

Browse files
committed
ADMM version JuliaFirstOrder#1
1 parent 6f626c7 commit ba378dd

File tree

11 files changed

+793
-5
lines changed

11 files changed

+793
-5
lines changed

src/ProximalAlgorithms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ include("utilities/get_assumptions.jl")
165165

166166
# algorithm implementations
167167

168+
include("algorithms/cg.jl")
169+
include("algorithms/admm.jl")
168170
include("algorithms/forward_backward.jl")
169171
include("algorithms/fast_forward_backward.jl")
170172
include("algorithms/zerofpr.jl")

src/algorithms/admm.jl

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
#
2+
# This file contains code that is derived from RegularizedLeastSquares.jl.
3+
# Original source: https://github.com/JuliaImageRecon/RegularizedLeastSquares.jl
4+
#
5+
# RegularizedLeastSquares.jl is licensed under the MIT License:
6+
#
7+
# Copyright (c) 2018: Tobias Knopp
8+
#
9+
# Permission is hereby granted, free of charge, to any person obtaining a copy
10+
# of this software and associated documentation files (the "Software"), to deal
11+
# in the Software without restriction, including without limitation the rights
12+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13+
# copies of the Software, and to permit persons to whom the Software is
14+
# furnished to do so, subject to the following conditions:
15+
#
16+
# The above copyright notice and this permission notice shall be included in all
17+
# copies or substantial portions of the Software.
18+
#
19+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25+
# SOFTWARE.
26+
27+
struct ADMMIteration{R,Tx,TAHb,Tg<:Tuple,TB<:Tuple,TP,TC,TCGS}
28+
x0::Tx
29+
AHb::TAHb
30+
g::Tg
31+
B::TB
32+
rho::Vector{R}
33+
P::TP
34+
P_is_inverse::Bool
35+
cg_operator::TC
36+
cg_tol::R
37+
cg_maxiter::Int
38+
y0::Vector{Tx}
39+
z0::Vector{Tx}
40+
cg_state::TCGS
41+
end
42+
43+
"""
44+
ADMMIteration(; <keyword-arguments>)
45+
46+
Iterator implementing the Alternating Direction Method of Multipliers (ADMM) algorithm.
47+
48+
This iterator solves optimization problems of the form
49+
50+
minimize ½‖Ax - b‖²₂ + ∑ᵢ gᵢ(Bᵢx)
51+
52+
where:
53+
- `A` is a linear operator
54+
- `b` is the measurement vector
55+
- `gᵢ` are proximable functions with associated linear operators `Bᵢ`
56+
57+
See also: [`ADMM`](@ref).
58+
59+
# Arguments
60+
- `x0`: initial point
61+
- `A=nothing`: forward operator. If `A` is not provided, ½‖Ax - b‖²₂ is not computed, and the algorithm will only minimize the regularization terms.
62+
- `b=nothing`: measurement vector. If `A` is provided, `b` must also be provided.
63+
- `g=()`: tuple of proximable regularization functions
64+
- `B=()`: tuple of regularization operators
65+
- `rho=ones(length(g))`: vector of augmented Lagrangian parameters (one per regularizer)
66+
- `P=nothing`: preconditioner for CG (optional)
67+
- `cg_tol=1e-6`: CG tolerance
68+
- `cg_maxiter=100`: maximum CG iterations
69+
- `y0=nothing`: initial dual variables
70+
- `z0=nothing`: initial auxiliary variables
71+
"""
72+
function ADMMIteration(;
73+
x0,
74+
A = nothing,
75+
b = nothing,
76+
g = (),
77+
B = nothing,
78+
rho = nothing,
79+
P = nothing,
80+
P_is_inverse = false,
81+
cg_tol = 1e-6,
82+
cg_maxiter = 100,
83+
y0 = nothing,
84+
z0 = nothing,
85+
)
86+
if isnothing(A) && !isnothing(b)
87+
throw(ArgumentError("A must be provided if b is given"))
88+
end
89+
if !isnothing(A) && isnothing(b)
90+
throw(ArgumentError("b must be provided if A is given"))
91+
end
92+
if !(g isa Tuple)
93+
g = (g,)
94+
end
95+
if isnothing(B)
96+
B = tuple(fill(LinearAlgebra.I, length(g))...) # Default to identity operators
97+
elseif !(B isa Tuple)
98+
B = (B,)
99+
end
100+
if length(B) != length(g)
101+
throw(ArgumentError("B and g must have the same length"))
102+
end
103+
if isnothing(rho)
104+
rho = ones(real(eltype(x0)), length(g))
105+
elseif rho isa Number
106+
rho = fill(rho, length(g))
107+
elseif !(rho isa Vector) || !all(isreal, rho)
108+
throw(ArgumentError("rho must be a vector of real numbers"))
109+
end
110+
if length(rho) != length(g)
111+
throw(ArgumentError("rho must have the same length as g"))
112+
end
113+
# Build the CG operator for the x update
114+
# If A is not provided, we assume a simple identity operator
115+
# cg_operator = A'*A + sum(rho[i] * (B[i]' * B[i]) for i in eachindex(g))
116+
cg_operator = isnothing(A) ? nothing : A' * A
117+
for i in eachindex(g)
118+
new_op = rho[i] * (B[i]' * B[i])
119+
if isnothing(cg_operator)
120+
cg_operator = new_op
121+
else
122+
cg_operator += new_op
123+
end
124+
end
125+
if isnothing(y0)
126+
y0 = [zero(x0) for _ in 1:length(g)]
127+
elseif length(y0) != length(g)
128+
throw(ArgumentError("y0 must have the same length as g"))
129+
end
130+
if isnothing(z0)
131+
z0 = [zero(x0) for _ in 1:length(g)]
132+
elseif length(z0) != length(g)
133+
throw(ArgumentError("z0 must have the same length as g"))
134+
end
135+
AHb = isnothing(A) ? nothing : (A' * b)
136+
if size(AHb) != size(x0)
137+
throw(ArgumentError("A'b must have the same size as x0"))
138+
end
139+
140+
# Create initial CGState
141+
cg_state = isnothing(P) ? CGState(x0) : PCGState(x0)
142+
143+
return ADMMIteration{eltype(rho),typeof(x0),typeof(AHb),typeof(g),typeof(B),
144+
typeof(P),typeof(cg_operator),typeof(cg_state)}(
145+
x0, AHb, g, B, rho, P, P_is_inverse, cg_operator, cg_tol, cg_maxiter, y0, z0, cg_state
146+
)
147+
end
148+
149+
Base.@kwdef mutable struct ADMMState{R,Tx}
150+
x::Tx # primal variable
151+
y::Vector{Tx} # scaled dual variables
152+
z::Vector{Tx} # auxiliary variables
153+
u::Tx # temporary variable for x update
154+
v::Tx # temporary variable for normal equations
155+
w::Vector{Tx} # temporary variables for residuals
156+
res_primal::Vector{R} # primal residual norms
157+
res_dual::Vector{R} # dual residual norms
158+
end
159+
160+
function ADMMState(iter::ADMMIteration)
161+
n_reg = length(iter.g)
162+
163+
# Initialize variables and CG state
164+
x = iter.cg_state.x # Start with initial guess
165+
y = isnothing(iter.y0) ? [zero(x) for _ in 1:n_reg] : copy.(iter.y0)
166+
z = isnothing(iter.z0) ? [zero(x) for _ in 1:n_reg] : copy.(iter.z0)
167+
168+
# Allocate temporary variables
169+
u = similar(x)
170+
v = similar(x)
171+
w = [similar(x) for _ in 1:n_reg]
172+
173+
# Initialize residuals
174+
res_primal = zeros(real(eltype(x)), n_reg)
175+
res_dual = zeros(real(eltype(x)), n_reg)
176+
177+
return ADMMState(;x, y, z, u, v, w, res_primal, res_dual)
178+
end
179+
180+
function Base.iterate(iter::ADMMIteration, state::ADMMState = ADMMState(iter))
181+
# Store old z for computing dual residuals
182+
z_old = copy.(state.z)
183+
184+
# Update x using CG
185+
if !isnothing(iter.AHb)
186+
copyto!(state.v, iter.AHb) # v = A'b
187+
else
188+
fill!(state.v, 0) # no least squares term
189+
end
190+
191+
# Add contributions from regularizers
192+
fill!(state.u, 0)
193+
for i in eachindex(iter.g)
194+
mul!(state.w[i], adjoint(iter.B[i]), state.z[i] .- state.y[i])
195+
state.u .+= iter.rho[i] .* state.w[i]
196+
end
197+
state.v .+= state.u
198+
199+
# Create new CGIteration but reuse state
200+
cg = CG(
201+
x0 = state.x,
202+
A = iter.cg_operator,
203+
b = state.v,
204+
P = iter.P,
205+
P_is_inverse = iter.P_is_inverse,
206+
state = iter.cg_state,
207+
tol = iter.cg_tol,
208+
maxit = iter.cg_maxiter,
209+
)
210+
cg() # this works in-place, updating state.x == iter.cg_state.x
211+
212+
# z-updates
213+
for i in eachindex(iter.g)
214+
mul!(state.w[i], iter.B[i], state.x)
215+
state.w[i] .+= state.y[i]
216+
prox!(state.z[i], iter.g[i], state.w[i], 1/iter.rho[i])
217+
end
218+
219+
# Update dual variables and compute residuals
220+
for i in eachindex(iter.g)
221+
mul!(state.w[i], iter.B[i], state.x)
222+
state.w[i] .-= state.z[i]
223+
state.y[i] .+= state.w[i]
224+
225+
state.res_primal[i] = norm(state.w[i])
226+
state.res_dual[i] = iter.rho[i] * norm(state.z[i] - z_old[i])
227+
end
228+
229+
return state, state
230+
end
231+
232+
default_stopping_criterion(tol, ::ADMMIteration, state::ADMMState) =
233+
all(r -> r <= tol, state.res_primal) && all(r -> r <= tol, state.res_dual)
234+
default_solution(::ADMMIteration, state::ADMMState) = state.x
235+
default_display(it, ::ADMMIteration, state::ADMMState) =
236+
@printf("%5d | Primal: %.3e, Dual: %.3e\n", it,
237+
maximum(state.res_primal), maximum(state.res_dual))
238+
239+
"""
240+
ADMM(; <keyword-arguments>)
241+
242+
Create an instance of the ADMM algorithm.
243+
244+
This algorithm solves optimization problems of the form
245+
246+
minimize ½‖Ax - b‖²₂ + ∑ᵢ gᵢ(Bᵢx)
247+
248+
where `A` is a linear operator, `b` is the measurement vector, and `gᵢ` are proximable functions with associated linear operators `Bᵢ`.
249+
250+
The returned object has type `IterativeAlgorithm{ADMMIteration}`,
251+
and can called be with the problem's arguments to trigger its solution.
252+
253+
# Arguments
254+
- `x0`: initial point
255+
- `A=nothing`: forward operator. If `A` is not provided, ½‖Ax - b‖²₂ is not computed, and the algorithm will only minimize the regularization terms.
256+
- `b=nothing`: measurement vector. If `A` is provided, `b` must also be provided.
257+
- `g=()`: tuple of proximable regularization functions
258+
- `B=()`: tuple of regularization operators
259+
- `rho=ones(length(g))`: vector of augmented Lagrangian parameters (one per regularizer)
260+
- `P=nothing`: preconditioner for CG (optional)
261+
- `cg_tol=1e-6`: CG tolerance
262+
- `cg_maxiter=100`: maximum CG iterations
263+
- `y0=nothing`: initial dual variables
264+
- `z0=nothing`: initial auxiliary variables
265+
"""
266+
ADMM(;
267+
maxit = 10_000,
268+
tol = 1e-8,
269+
stop = (iter, state) -> default_stopping_criterion(tol, iter, state),
270+
solution = default_solution,
271+
verbose = false,
272+
freq = 100,
273+
display = default_display,
274+
kwargs...,
275+
) = IterativeAlgorithm(
276+
ADMMIteration;
277+
maxit,
278+
stop,
279+
solution,
280+
verbose,
281+
freq,
282+
display,
283+
kwargs...,
284+
)
285+
286+
get_assumptions(::Type{<:ADMMIteration}) = (
287+
LeastSquaresTerm(:A => (is_linear,), :b),
288+
RepeatedOperatorTerm(:g => (is_proximable,), :B => (is_linear,)),
289+
)

0 commit comments

Comments
 (0)