@@ -3,9 +3,15 @@ module ProximalAlgorithms
33using ADTypes: ADTypes
44using DifferentiationInterface: DifferentiationInterface
55using ProximalCore
6- using ProximalCore: prox, prox!, is_smooth, is_locally_smooth, is_convex, is_strongly_convex, is_proximable
7- using OperatorCore: is_linear
6+ using ProximalCore: Zero, IndZero, convex_conjugate, prox, prox!, is_smooth, is_locally_smooth, is_convex, is_strongly_convex, is_proximable
7+ using OperatorCore: is_linear, is_symmetric, is_positive_definite
8+ using LinearAlgebra
9+ using Base. Iterators
10+ using Printf
11+
812import Base: show
13+ import Base: *
14+ import LinearAlgebra: mul!
915
1016const RealOrComplex{R} = Union{R,Complex{R}}
1117const Maybe{T} = Union{T,Nothing}
@@ -113,8 +119,39 @@ IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...)
113119 kwargs,
114120 )
115121
122+ """
123+ get_iterator(alg::IterativeAlgorithm{IteratorType}) where {IteratorType}
124+
125+ Return an iterator of type `IteratorType` constructed from the algorithm `alg`.
126+ This is a convenience function to allow for easy access to the iterator type
127+ associated with an `IterativeAlgorithm`.
128+
129+ # Example
130+ ```julia
131+ julia> using ProximalAlgorithms: CG, get_iterator
132+
133+ julia> alg = CG(maxit=3, tol=1e-8);
134+
135+ julia> iter = get_iterator(alg, A=reshape(collect(1:25)), b=collect(1:5));
136+
137+ julia> for (k, state) in enumerate(iter)
138+ if k >= alg.maxit || alg.stop(iter, state)
139+ alg.verbose && alg.display(k, iter, state)
140+ return (alg.solution(iter, state), k)
141+ end
142+ alg.verbose && mod(k, alg.freq) == 0 && alg.display(k, iter, state)
143+ end
144+ 1 | 7.416e+00
145+ 2 | 2.742e+00
146+ 3 | 2.300e+01
147+ ([0.5581699346405239, 0.31633986928104635, 0.07450980392156867, -0.16732026143790907, -0.4091503267973867], 3)
148+ ```
149+ """
150+ get_iterator (alg:: IterativeAlgorithm{IteratorType} ; kwargs... ) where {IteratorType} =
151+ IteratorType (; alg. kwargs... , kwargs... )
152+
116153function (alg:: IterativeAlgorithm{IteratorType} )(; kwargs... ) where {IteratorType}
117- iter = IteratorType (; alg. kwargs ... , kwargs... )
154+ iter = get_iterator ( alg; kwargs... )
118155 for (k, state) in enumerate (iter)
119156 if k >= alg. maxit || alg. stop (iter, state)
120157 alg. verbose && alg. display (k, iter, state)
0 commit comments