From 832d6cd0fc10cfa9a4a331d9efd4f852c578da9b Mon Sep 17 00:00:00 2001 From: Tamas Hakkel Date: Tue, 2 Dec 2025 12:03:54 +0100 Subject: [PATCH] Improve verbose mode - `summary` function added to the `IterativeAlgorithm` struct. This function returns a tuple of pairs where the key is the column title. E.g.: ("" => it, , "f(xg)" => state.f_xg, ...) - The `display` function is modified to call summary and display the result. When `it = 0` is passed, then only a table header is printed. - When `freq` in `IterativeAlgorithm` is set to 0, then only a single line is printed after the iteration stops. The format of this line is like: "total iterations = 43, f(xg) = 3.524e-3, ..." - `default_display` function now accepts `printfunc` optional argument. The default value is `println`, and this argument allows replacing it, e.g., with a proper logger. --- src/ProximalAlgorithms.jl | 70 +++++++++++++++++++++---- src/algorithms/davis_yin.jl | 23 ++++---- src/algorithms/douglas_rachford.jl | 24 +++++---- src/algorithms/drls.jl | 25 +++++---- src/algorithms/fast_forward_backward.jl | 20 ++++--- src/algorithms/forward_backward.jl | 15 +++--- src/algorithms/li_lin.jl | 15 +++--- src/algorithms/panoc.jl | 20 ++++--- src/algorithms/panocplus.jl | 22 ++++---- src/algorithms/primal_dual.jl | 33 +++++++----- src/algorithms/sfista.jl | 31 +++++++---- src/algorithms/zerofpr.jl | 26 +++++---- 12 files changed, 209 insertions(+), 115 deletions(-) diff --git a/src/ProximalAlgorithms.jl b/src/ProximalAlgorithms.jl index e248613..864ea11 100644 --- a/src/ProximalAlgorithms.jl +++ b/src/ProximalAlgorithms.jl @@ -4,6 +4,7 @@ using ADTypes: ADTypes using DifferentiationInterface: DifferentiationInterface using ProximalCore using ProximalCore: prox, prox! +using Printf const RealOrComplex{R} = Union{R,Complex{R}} const Maybe{T} = Union{T,Nothing} @@ -55,18 +56,19 @@ include("accel/noaccel.jl") # algorithm interface -struct IterativeAlgorithm{IteratorType,H,S,D,K} +struct IterativeAlgorithm{IteratorType,H,S,I,D,K} maxit::Int stop::H solution::S verbose::Bool freq::Int + summary::I display::D kwargs::K end """ - IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...) + IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...) Wrapper for an iterator type `T`, adding termination and verbosity options on top of it. @@ -75,7 +77,7 @@ The resulting "algorithm" object `alg` can be called on a set of keyword argumen to `kwargs` and passed on to `T` to construct an iterator which will be looped over. Specifically, if an algorithm is constructed as - alg = IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...) + alg = IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...) then calling it with @@ -88,7 +90,7 @@ will internally loop over an iterator constructed as # Note This constructor is not meant to be used directly: instead, algorithm-specific constructors should be defined on top of it and exposed to the user, that set appropriate default functions -for `stop`, `solution`, `display`. +for `stop`, `solution`, `summary`, `display`. # Arguments * `T::Type`: iterator type to use @@ -97,28 +99,78 @@ for `stop`, `solution`, `display`. * `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution * `verbose::Bool`: whether the algorithm state should be displayed * `freq::Int`: every how many iterations to display the algorithm state -* `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +* `summary::Function`: function returning a summary of the iteration state, `summary(k::Int, iter::T, state)` should return a vector of pairs `(name, value)` +* `display::Function`: display function, `display(k::Int, alg, iter::T, state)` should display a summary of the iteration state * `kwargs...`: keyword arguments to pass on to `T` when constructing the iterator """ -IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...) = - IterativeAlgorithm{T,typeof(stop),typeof(solution),typeof(display),typeof(kwargs)}( +IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...) = + IterativeAlgorithm{T,typeof(stop),typeof(solution),typeof(summary),typeof(display),typeof(kwargs)}( maxit, stop, solution, verbose, freq, + summary, display, kwargs, ) +function default_display(k, alg, iter, state, printfunc=println) + if alg.freq > 0 + summary = alg.summary(k, iter, state) + column_widths = map(pair -> max(length(pair.first), pair.second isa Integer ? 5 : 9), summary) + if k == 0 + keys = map(first, summary) + first_line = [_get_centered_text(key, width) for (width, key) in zip(column_widths, keys)] + printfunc(join(first_line, " | ")) + second_line = [repeat('-', width) for width in column_widths] + printfunc(join(second_line, "-|-"), "-") + else + values = map(last, summary) + parts = [_format_value(value, width) for (width, value) in zip(column_widths, values)] + printfunc(join(parts, " | ")) + end + else + summary = alg.summary(k, iter, state) + if summary[1].first == "" + summary = ("total iterations" => k, summary[2:end]...) + end + items = map(pair -> @sprintf("%s=%s", pair.first, _format_value(pair.second, 0)), summary) + printfunc(join(items, ", ")) + end +end + +function _get_centered_text(text, width) + l = length(text) + if l >= width + return text + end + left_padding = div(width - l, 2) + right_padding = width - l - left_padding + return repeat(" ", left_padding) * text * repeat(" ", right_padding) +end + +function _format_value(value, width) + if value isa Integer + return @sprintf("%*d", width, value) + elseif value isa Float64 || value isa Float32 + return @sprintf("%*.3e", width, value) + else + return @sprintf("%*s", width, string(value)) + end +end + function (alg::IterativeAlgorithm{IteratorType})(; kwargs...) where {IteratorType} iter = IteratorType(; alg.kwargs..., kwargs...) for (k, state) in enumerate(iter) + if k == 1 && alg.verbose && alg.freq > 0 + alg.display(0, alg, iter, state) + end if k >= alg.maxit || alg.stop(iter, state) - alg.verbose && alg.display(k, iter, state) + alg.verbose && alg.display(k, alg, iter, state) return (alg.solution(iter, state), k) end - alg.verbose && mod(k, alg.freq) == 0 && alg.display(k, iter, state) + alg.verbose && alg.freq > 0 && mod(k, alg.freq) == 0 && alg.display(k, alg, iter, state) end end diff --git a/src/algorithms/davis_yin.jl b/src/algorithms/davis_yin.jl index 55a4c79..a77b2d9 100644 --- a/src/algorithms/davis_yin.jl +++ b/src/algorithms/davis_yin.jl @@ -44,12 +44,14 @@ end Base.IteratorSize(::Type{<:DavisYinIteration}) = Base.IsInfinite() -struct DavisYinState{T} +struct DavisYinState{T,R} z::T xg::T + f_xg::R grad_f_xg::T z_half::T xh::T + g_xh::R res::T end @@ -58,10 +60,10 @@ function Base.iterate(iter::DavisYinIteration) xg, = prox(iter.g, z, iter.gamma) f_xg, grad_f_xg = value_and_gradient(iter.f, xg) z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg - xh, = prox(iter.h, z_half, iter.gamma) + xh, g_xh = prox(iter.h, z_half, iter.gamma) res = xh - xg z .+= iter.lambda .* res - state = DavisYinState(z, xg, grad_f_xg, z_half, xh, res) + state = DavisYinState(z, xg, f_xg, grad_f_xg, z_half, xh, g_xh, res) return state, state end @@ -79,8 +81,8 @@ end default_stopping_criterion(tol, ::DavisYinIteration, state::DavisYinState) = norm(state.res, Inf) <= tol default_solution(::DavisYinIteration, state::DavisYinState) = state.xh -default_display(it, ::DavisYinIteration, state::DavisYinState) = - @printf("%5d | %.3e\n", it, norm(state.res, Inf)) +default_iteration_summary(it, ::DavisYinIteration, state::DavisYinState) = + ("" => it, "f(xg)" => state.f_xg, "g(xh)" => state.g_xh, "‖xg - xh‖" => norm(state.res, Inf)) """ DavisYin(; ) @@ -101,11 +103,12 @@ See also: [`DavisYinIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=10_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `DavisYinIteration` constructor upon call # References @@ -118,6 +121,7 @@ DavisYin(; solution = default_solution, verbose = false, freq = 100, + summary=default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -127,6 +131,7 @@ DavisYin(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/douglas_rachford.jl b/src/algorithms/douglas_rachford.jl index a3fca42..97768d1 100644 --- a/src/algorithms/douglas_rachford.jl +++ b/src/algorithms/douglas_rachford.jl @@ -42,11 +42,13 @@ end Base.IteratorSize(::Type{<:DouglasRachfordIteration}) = Base.IsInfinite() -Base.@kwdef struct DouglasRachfordState{Tx} +Base.@kwdef struct DouglasRachfordState{Tx,R} x::Tx y::Tx = similar(x) + f_y::R = real(eltype(x))(0) r::Tx = similar(x) z::Tx = similar(x) + g_z::R = real(eltype(x))(0) res::Tx = similar(x) end @@ -54,11 +56,12 @@ function Base.iterate( iter::DouglasRachfordIteration, state::DouglasRachfordState = DouglasRachfordState(x = copy(iter.x0)), ) - prox!(state.y, iter.f, state.x, iter.gamma) + f_y = prox!(state.y, iter.f, state.x, iter.gamma) state.r .= 2 .* state.y .- state.x - prox!(state.z, iter.g, state.r, iter.gamma) + g_z = prox!(state.z, iter.g, state.r, iter.gamma) state.res .= state.y .- state.z state.x .-= state.res + state = DouglasRachfordState(state.x, state.y, f_y, state.r, state.z, g_z, state.res) return state, state end @@ -68,8 +71,8 @@ default_stopping_criterion( state::DouglasRachfordState, ) = norm(state.res, Inf) / iter.gamma <= tol default_solution(::DouglasRachfordIteration, state::DouglasRachfordState) = state.y -default_display(it, iter::DouglasRachfordIteration, state::DouglasRachfordState) = - @printf("%5d | %.3e\n", it, norm(state.res, Inf) / iter.gamma) +default_iteration_summary(it, iter::DouglasRachfordIteration, state::DouglasRachfordState) = + ("" => it, "f(y)" => state.f_y, "g(z)" => state.g_z, "‖y - z‖" => norm(state.res, Inf) / iter.gamma) """ DouglasRachford(; ) @@ -88,11 +91,12 @@ See also: [`DouglasRachfordIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=1_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `DouglasRachfordIteration` constructor upon call # References @@ -105,6 +109,7 @@ DouglasRachford(; solution = default_solution, verbose = false, freq = 100, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -114,6 +119,7 @@ DouglasRachford(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/drls.jl b/src/algorithms/drls.jl index f7aed73..a9da79b 100644 --- a/src/algorithms/drls.jl +++ b/src/algorithms/drls.jl @@ -197,13 +197,13 @@ end default_stopping_criterion(tol, ::DRLSIteration, state::DRLSState) = norm(state.res, Inf) / state.gamma <= tol default_solution(::DRLSIteration, state::DRLSState) = state.v -default_display(it, ::DRLSIteration, state::DRLSState) = @printf( - "%5d | %.3e | %.3e | %.3e\n", - it, - state.gamma, - norm(state.res, Inf) / state.gamma, - state.tau, -) +default_iteration_summary(it, ::DRLSIteration, state::DRLSState) = + ("" => it, + "f(u)" => state.f_u, + "g(v)" => state.g_v, + "γ" => state.gamma, + "‖u - v‖/γ" => norm(state.res, Inf) / state.gamma, + "τ" => state.tau) """ DRLS(; ) @@ -224,11 +224,12 @@ See also: [`DRLSIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=1_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=10`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=10`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `DRLSIteration` constructor upon call # References @@ -241,6 +242,7 @@ DRLS(; solution = default_solution, verbose = false, freq = 10, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -250,6 +252,7 @@ DRLS(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/fast_forward_backward.jl b/src/algorithms/fast_forward_backward.jl index c4ccb3f..c51acb5 100644 --- a/src/algorithms/fast_forward_backward.jl +++ b/src/algorithms/fast_forward_backward.jl @@ -150,8 +150,13 @@ default_stopping_criterion( state::FastForwardBackwardState, ) = norm(state.res, Inf) / state.gamma <= tol default_solution(::FastForwardBackwardIteration, state::FastForwardBackwardState) = state.z -default_display(it, ::FastForwardBackwardIteration, state::FastForwardBackwardState) = - @printf("%5d | %.3e | %.3e\n", it, state.gamma, norm(state.res, Inf) / state.gamma) +default_iteration_summary(it, iter::FastForwardBackwardIteration, state::FastForwardBackwardState) = begin + if iter.adaptive + ("" => it, "f(x)" => state.f_x, "g(z)" => state.g_z, "γ" => state.gamma, "‖x - z‖/γ" => norm(state.res, Inf) / state.gamma) + else + ("" => it, "f(x)" => state.f_x, "g(z)" => state.g_z, "‖x - z‖/γ" => norm(state.res, Inf) / state.gamma) + end +end """ FastForwardBackward(; ) @@ -172,11 +177,12 @@ See also: [`FastForwardBackwardIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=10_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `FastForwardBackwardIteration` constructor upon call # References @@ -190,6 +196,7 @@ FastForwardBackward(; solution = default_solution, verbose = false, freq = 100, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -199,6 +206,7 @@ FastForwardBackward(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/forward_backward.jl b/src/algorithms/forward_backward.jl index 574389b..680a589 100644 --- a/src/algorithms/forward_backward.jl +++ b/src/algorithms/forward_backward.jl @@ -125,8 +125,8 @@ end default_stopping_criterion(tol, ::ForwardBackwardIteration, state::ForwardBackwardState) = norm(state.res, Inf) / state.gamma <= tol default_solution(::ForwardBackwardIteration, state::ForwardBackwardState) = state.z -default_display(it, ::ForwardBackwardIteration, state::ForwardBackwardState) = - @printf("%5d | %.3e | %.3e\n", it, state.gamma, norm(state.res, Inf) / state.gamma) +default_iteration_summary(it, ::ForwardBackwardIteration, state::ForwardBackwardState) = + ("" => it, "γ" => state.gamma, "f(x)" => state.f_x, "g(z)" => state.g_z, "‖x - y‖/γ" => norm(state.res, Inf) / state.gamma) """ ForwardBackward(; ) @@ -147,11 +147,12 @@ See also: [`ForwardBackwardIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=10_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `ForwardBackwardIteration` constructor upon call # References @@ -165,6 +166,7 @@ ForwardBackward(; solution = default_solution, verbose = false, freq = 100, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -174,6 +176,7 @@ ForwardBackward(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/li_lin.jl b/src/algorithms/li_lin.jl index 0889024..01203cc 100644 --- a/src/algorithms/li_lin.jl +++ b/src/algorithms/li_lin.jl @@ -148,8 +148,8 @@ end default_stopping_criterion(tol, ::LiLinIteration, state::LiLinState) = norm(state.res, Inf) / state.gamma <= tol default_solution(::LiLinIteration, state::LiLinState) = state.z -default_display(it, ::LiLinIteration, state::LiLinState) = - @printf("%5d | %.3e | %.3e\n", it, state.gamma, norm(state.res, Inf) / state.gamma) +default_iteration_summary(it, ::LiLinIteration, state::LiLinState) = + ("" => it, "γ" => state.gamma, "f(y)" => state.f_y, "g(z)" => state.g_z, "‖y - z‖/γ" => norm(state.res, Inf) / state.gamma) """ LiLin(; ) @@ -171,11 +171,12 @@ See also: [`LiLinIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=10_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `LiLinIteration` constructor upon call # References @@ -188,6 +189,7 @@ LiLin(; solution = default_solution, verbose = false, freq = 100, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -197,6 +199,7 @@ LiLin(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/panoc.jl b/src/algorithms/panoc.jl index c7f2558..66e8826 100644 --- a/src/algorithms/panoc.jl +++ b/src/algorithms/panoc.jl @@ -257,13 +257,8 @@ end default_stopping_criterion(tol, ::PANOCIteration, state::PANOCState) = norm(state.res, Inf) / state.gamma <= tol default_solution(::PANOCIteration, state::PANOCState) = state.z -default_display(it, ::PANOCIteration, state::PANOCState) = @printf( - "%5d | %.3e | %.3e | %.3e\n", - it, - state.gamma, - norm(state.res, Inf) / state.gamma, - state.tau, -) +default_iteration_summary(it, ::PANOCIteration, state::PANOCState) = + ("" => it, "f(Ax)" => state.f_Ax, "g(z)" => state.g_z, "γ" => state.gamma, "‖x - z‖/γ" => norm(state.res, Inf) / state.gamma) """ PANOC(; ) @@ -284,11 +279,12 @@ See also: [`PANOCIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=1_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=10`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=10`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `PANOCIteration` constructor upon call # References @@ -301,6 +297,7 @@ PANOC(; solution = default_solution, verbose = false, freq = 10, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -310,6 +307,7 @@ PANOC(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/panocplus.jl b/src/algorithms/panocplus.jl index d407039..da65b7a 100644 --- a/src/algorithms/panocplus.jl +++ b/src/algorithms/panocplus.jl @@ -80,7 +80,7 @@ f_model(iter::PANOCplusIteration, state::PANOCplusState) = function Base.iterate(iter::PANOCplusIteration{R}) where {R} x = copy(iter.x0) Ax = iter.A * x - f_Ax, grad_f_Ax = value_and_gradient(iter.f, Ax) + f_Ax, grad_f_Ax = value_and_gradient(iter.f, Ax) gamma = iter.gamma === nothing ? iter.alpha / lower_bound_smoothness_constant(iter.f, iter.A, x, grad_f_Ax) : @@ -242,13 +242,8 @@ end default_stopping_criterion(tol, ::PANOCplusIteration, state::PANOCplusState) = norm((state.res / state.gamma) - state.At_grad_f_Ax + state.At_grad_f_Az, Inf) <= tol default_solution(::PANOCplusIteration, state::PANOCplusState) = state.z -default_display(it, ::PANOCplusIteration, state::PANOCplusState) = @printf( - "%5d | %.3e | %.3e | %.3e\n", - it, - state.gamma, - norm(state.res, Inf) / state.gamma, - state.tau, -) +default_iteration_summary(it, ::PANOCplusIteration, state::PANOCplusState) = + ("" => it, "f(Ax)" => state.f_Ax, "g(z)" => state.g_z, "γ" => state.gamma, "‖x - z‖/γ" => norm(state.res, Inf) / state.gamma) """ PANOCplus(; ) @@ -269,11 +264,12 @@ See also: [`PANOCplusIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=1_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=10`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=10`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `PANOCplusIteration` constructor upon call # References @@ -286,6 +282,7 @@ PANOCplus(; solution = default_solution, verbose = false, freq = 10, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -295,6 +292,7 @@ PANOCplus(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/primal_dual.jl b/src/algorithms/primal_dual.jl index 9077da7..5daefb4 100644 --- a/src/algorithms/primal_dual.jl +++ b/src/algorithms/primal_dual.jl @@ -213,8 +213,8 @@ end default_stopping_criterion(tol, ::AFBAIteration, state::AFBAState) = norm(state.FPR_x, Inf) + norm(state.FPR_y, Inf) <= tol default_solution(::AFBAIteration, state::AFBAState) = (state.xbar, state.ybar) -default_display(it, ::AFBAIteration, state::AFBAState) = - @printf("%6d | %7.4e\n", it, norm(state.FPR_x, Inf) + norm(state.FPR_y, Inf)) +default_iteration_summary(it, ::AFBAIteration, state::AFBAState) = + ("" => it, "‖x̄ - x‖" => norm(state.FPR_x, Inf), "‖ȳ - y‖" => norm(state.FPR_y, Inf)) """ AFBA(; ) @@ -236,11 +236,12 @@ See also: [`AFBAIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=10_000`: maximum number of iteration - `tol::1e-5`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `AFBAIteration` constructor upon call # References @@ -254,6 +255,7 @@ AFBA(; solution = default_solution, verbose = false, freq = 100, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -263,6 +265,7 @@ AFBA(; solution, verbose, freq, + summary, display, kwargs..., ) @@ -287,11 +290,12 @@ See also: [`VuCondatIteration`](@ref), [`AFBAIteration`](@ref), [`IterativeAlgor # Arguments - `maxit::Int=10_000`: maximum number of iteration - `tol::1e-5`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `AFBAIteration` constructor upon call # References @@ -319,11 +323,12 @@ See also: [`ChambollePockIteration`](@ref), [`AFBAIteration`](@ref), [`Iterative # Arguments - `maxit::Int=10_000`: maximum number of iteration - `tol::1e-5`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `AFBAIteration` constructor upon call # References diff --git a/src/algorithms/sfista.jl b/src/algorithms/sfista.jl index 3e9458e..9c92aa6 100644 --- a/src/algorithms/sfista.jl +++ b/src/algorithms/sfista.jl @@ -44,6 +44,7 @@ Base.@kwdef struct SFISTAIteration{R,C<:Union{R,Complex{R}},Tx<:AbstractArray{C} g::Th = Zero() Lf::R mf::R = real(eltype(Lf))(0.0) + termination_type::Symbol = :classic # can be :AIPP or :classic (default) end Base.IteratorSize(::Type{<:SFISTAIteration}) = Base.IsInfinite() @@ -60,6 +61,7 @@ Base.@kwdef mutable struct SFISTAState{R,Tx} APrev::R = real(eltype(yPrev))(1.0) # previous A (helper variable). A::R = real(eltype(yPrev))(0.0) # helper variable (see [3]). gradf_xt::Tx = zero(yPrev) # array containing ∇f(xt). + res_norm::R = real(eltype(yPrev))(0.0) # norm of the residual (for stopping criterion). end function Base.iterate( @@ -88,8 +90,8 @@ function Base.iterate( end # Different stopping conditions (sc). Returns the current residual value and whether or not a stopping condition holds. -function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_type) - if termination_type == "AIPP" +function calc_residual!(state::SFISTAState, iter::SFISTAIteration) + if iter.termination_type == :AIPP # AIPP-style termination [4]. The main inclusion is: r ∈ ∂_η(f + h)(y). r = (iter.y0 - state.x) / state.A η = (norm(iter.y0 - state.y)^2 - norm(state.x - state.y)^2) / (2 * state.A) @@ -101,10 +103,16 @@ function check_sc(state::SFISTAState, iter::SFISTAIteration, tol, termination_ty r = gradf_y - state.gradf_xt + (state.xt - state.y) / λ2 res = norm(r) end - return res, (res <= tol || res ≈ tol) + state.res_norm = res end +default_stopping_criterion(tol, iter::SFISTAIteration, state::SFISTAState) = begin + calc_residual!(state, iter) + state.res_norm <= tol || state.res_norm ≈ tol +end default_solution(::SFISTAIteration, state::SFISTAState) = state.y +default_iteration_summary(it, iter::SFISTAIteration, state::SFISTAState) = + ("" => it, (iter.termination_type == :AIPP ? "‖∂_η(f + h)(y)‖" : "‖∇f(y) + ∂h(y)‖") => state.res_norm) """ SFISTA(; ) @@ -130,11 +138,12 @@ See also: [`SFISTAIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=10_000`: maximum number of iteration - `tol::1e-6`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=100`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `SFISTAIteration` constructor upon call # References @@ -147,13 +156,12 @@ See also: [`SFISTAIteration`](@ref), [`IterativeAlgorithm`](@ref). SFISTA(; maxit = 10_000, tol = 1e-6, - termination_type = "", - stop = (iter, state) -> check_sc(state, iter, tol, termination_type)[2], + stop = (iter, state) -> default_stopping_criterion(tol, iter, state), solution = default_solution, verbose = false, freq = 100, - display = (it, iter, state) -> - @printf("%5d | %.3e\n", it, check_sc(state, iter, tol, termination_type)[1]), + summary = default_iteration_summary, + display = default_display, kwargs..., ) = IterativeAlgorithm( SFISTAIteration; @@ -162,6 +170,7 @@ SFISTA(; solution, verbose, freq, + summary, display, kwargs..., ) diff --git a/src/algorithms/zerofpr.jl b/src/algorithms/zerofpr.jl index 8830969..03d0b0d 100644 --- a/src/algorithms/zerofpr.jl +++ b/src/algorithms/zerofpr.jl @@ -222,13 +222,14 @@ end default_stopping_criterion(tol, ::ZeroFPRIteration, state::ZeroFPRState) = norm(state.res, Inf) / state.gamma <= tol default_solution(::ZeroFPRIteration, state::ZeroFPRState) = state.xbar -default_display(it, ::ZeroFPRIteration, state::ZeroFPRState) = @printf( - "%5d | %.3e | %.3e | %.3e\n", - it, - state.gamma, - norm(state.res, Inf) / state.gamma, - state.tau, -) +default_iteration_summary(it, ::ZeroFPRIteration, state::ZeroFPRState) = + ("" => it, + "f(Ax)" => state.f_Ax, + "g(x̄)" => state.g_xbar, + "γ" => state.gamma, + "‖x - x̄‖/γ" => norm(state.res, Inf) / state.gamma, + "τ" => state.tau, + ) """ ZeroFPR(; ) @@ -249,11 +250,12 @@ See also: [`ZeroFPRIteration`](@ref), [`IterativeAlgorithm`](@ref). # Arguments - `maxit::Int=1_000`: maximum number of iteration - `tol::1e-8`: tolerance for the default stopping criterion -- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration -- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution +- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration +- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution - `verbose::Bool=false`: whether the algorithm state should be displayed -- `freq::Int=10`: every how many iterations to display the algorithm state -- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state +- `freq::Int=10`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed. +- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state +- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state - `kwargs...`: additional keyword arguments to pass on to the `ZeroFPRIteration` constructor upon call # References @@ -266,6 +268,7 @@ ZeroFPR(; solution = default_solution, verbose = false, freq = 10, + summary = default_iteration_summary, display = default_display, kwargs..., ) = IterativeAlgorithm( @@ -275,6 +278,7 @@ ZeroFPR(; solution, verbose, freq, + summary, display, kwargs..., )