Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions src/ProximalAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using ADTypes: ADTypes
using DifferentiationInterface: DifferentiationInterface
using ProximalCore
using ProximalCore: prox, prox!
using Printf

const RealOrComplex{R} = Union{R,Complex{R}}
const Maybe{T} = Union{T,Nothing}
Expand Down Expand Up @@ -55,18 +56,19 @@ include("accel/noaccel.jl")

# algorithm interface

struct IterativeAlgorithm{IteratorType,H,S,D,K}
struct IterativeAlgorithm{IteratorType,H,S,I,D,K}
maxit::Int
stop::H
solution::S
verbose::Bool
freq::Int
summary::I
display::D
kwargs::K
end

"""
IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...)
IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...)

Wrapper for an iterator type `T`, adding termination and verbosity options on top of it.

Expand All @@ -75,7 +77,7 @@ The resulting "algorithm" object `alg` can be called on a set of keyword argumen
to `kwargs` and passed on to `T` to construct an iterator which will be looped over.
Specifically, if an algorithm is constructed as

alg = IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...)
alg = IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...)

then calling it with

Expand All @@ -88,7 +90,7 @@ will internally loop over an iterator constructed as
# Note
This constructor is not meant to be used directly: instead, algorithm-specific constructors
should be defined on top of it and exposed to the user, that set appropriate default functions
for `stop`, `solution`, `display`.
for `stop`, `solution`, `summary`, `display`.

# Arguments
* `T::Type`: iterator type to use
Expand All @@ -97,28 +99,78 @@ for `stop`, `solution`, `display`.
* `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
* `verbose::Bool`: whether the algorithm state should be displayed
* `freq::Int`: every how many iterations to display the algorithm state
* `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
* `summary::Function`: function returning a summary of the iteration state, `summary(k::Int, iter::T, state)` should return a vector of pairs `(name, value)`
* `display::Function`: display function, `display(k::Int, alg, iter::T, state)` should display a summary of the iteration state
* `kwargs...`: keyword arguments to pass on to `T` when constructing the iterator
"""
IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, display, kwargs...) =
IterativeAlgorithm{T,typeof(stop),typeof(solution),typeof(display),typeof(kwargs)}(
IterativeAlgorithm(T; maxit, stop, solution, verbose, freq, summary, display, kwargs...) =
IterativeAlgorithm{T,typeof(stop),typeof(solution),typeof(summary),typeof(display),typeof(kwargs)}(
maxit,
stop,
solution,
verbose,
freq,
summary,
display,
kwargs,
)

function default_display(k, alg, iter, state, printfunc=println)
if alg.freq > 0
summary = alg.summary(k, iter, state)
column_widths = map(pair -> max(length(pair.first), pair.second isa Integer ? 5 : 9), summary)
if k == 0
keys = map(first, summary)
first_line = [_get_centered_text(key, width) for (width, key) in zip(column_widths, keys)]
printfunc(join(first_line, " | "))
second_line = [repeat('-', width) for width in column_widths]
printfunc(join(second_line, "-|-"), "-")
else
values = map(last, summary)
parts = [_format_value(value, width) for (width, value) in zip(column_widths, values)]
printfunc(join(parts, " | "))
end
else
summary = alg.summary(k, iter, state)
if summary[1].first == ""
summary = ("total iterations" => k, summary[2:end]...)
end
items = map(pair -> @sprintf("%s=%s", pair.first, _format_value(pair.second, 0)), summary)
printfunc(join(items, ", "))
end
end

function _get_centered_text(text, width)
l = length(text)
if l >= width
return text
end
left_padding = div(width - l, 2)
right_padding = width - l - left_padding
return repeat(" ", left_padding) * text * repeat(" ", right_padding)
end

function _format_value(value, width)
if value isa Integer
return @sprintf("%*d", width, value)
elseif value isa Float64 || value isa Float32
return @sprintf("%*.3e", width, value)
else
return @sprintf("%*s", width, string(value))
end
end

function (alg::IterativeAlgorithm{IteratorType})(; kwargs...) where {IteratorType}
iter = IteratorType(; alg.kwargs..., kwargs...)
for (k, state) in enumerate(iter)
if k == 1 && alg.verbose && alg.freq > 0
alg.display(0, alg, iter, state)
end
if k >= alg.maxit || alg.stop(iter, state)
alg.verbose && alg.display(k, iter, state)
alg.verbose && alg.display(k, alg, iter, state)
return (alg.solution(iter, state), k)
end
alg.verbose && mod(k, alg.freq) == 0 && alg.display(k, iter, state)
alg.verbose && alg.freq > 0 && mod(k, alg.freq) == 0 && alg.display(k, alg, iter, state)
end
end

Expand Down
23 changes: 14 additions & 9 deletions src/algorithms/davis_yin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ end

Base.IteratorSize(::Type{<:DavisYinIteration}) = Base.IsInfinite()

struct DavisYinState{T}
struct DavisYinState{T,R}
z::T
xg::T
f_xg::R
grad_f_xg::T
z_half::T
xh::T
g_xh::R
res::T
end

Expand All @@ -58,10 +60,10 @@ function Base.iterate(iter::DavisYinIteration)
xg, = prox(iter.g, z, iter.gamma)
f_xg, grad_f_xg = value_and_gradient(iter.f, xg)
z_half = 2 .* xg .- z .- iter.gamma .* grad_f_xg
xh, = prox(iter.h, z_half, iter.gamma)
xh, g_xh = prox(iter.h, z_half, iter.gamma)
res = xh - xg
z .+= iter.lambda .* res
state = DavisYinState(z, xg, grad_f_xg, z_half, xh, res)
state = DavisYinState(z, xg, f_xg, grad_f_xg, z_half, xh, g_xh, res)
return state, state
end

Expand All @@ -79,8 +81,8 @@ end
default_stopping_criterion(tol, ::DavisYinIteration, state::DavisYinState) =
norm(state.res, Inf) <= tol
default_solution(::DavisYinIteration, state::DavisYinState) = state.xh
default_display(it, ::DavisYinIteration, state::DavisYinState) =
@printf("%5d | %.3e\n", it, norm(state.res, Inf))
default_iteration_summary(it, ::DavisYinIteration, state::DavisYinState) =
("" => it, "f(xg)" => state.f_xg, "g(xh)" => state.g_xh, "‖xg - xh‖" => norm(state.res, Inf))

"""
DavisYin(; <keyword-arguments>)
Expand All @@ -101,11 +103,12 @@ See also: [`DavisYinIteration`](@ref), [`IterativeAlgorithm`](@ref).
# Arguments
- `maxit::Int=10_000`: maximum number of iteration
- `tol::1e-8`: tolerance for the default stopping criterion
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution
- `verbose::Bool=false`: whether the algorithm state should be displayed
- `freq::Int=100`: every how many iterations to display the algorithm state
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed.
- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state
- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `kwargs...`: additional keyword arguments to pass on to the `DavisYinIteration` constructor upon call

# References
Expand All @@ -118,6 +121,7 @@ DavisYin(;
solution = default_solution,
verbose = false,
freq = 100,
summary=default_iteration_summary,
display = default_display,
kwargs...,
) = IterativeAlgorithm(
Expand All @@ -127,6 +131,7 @@ DavisYin(;
solution,
verbose,
freq,
summary,
display,
kwargs...,
)
24 changes: 15 additions & 9 deletions src/algorithms/douglas_rachford.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,26 @@ end

Base.IteratorSize(::Type{<:DouglasRachfordIteration}) = Base.IsInfinite()

Base.@kwdef struct DouglasRachfordState{Tx}
Base.@kwdef struct DouglasRachfordState{Tx,R}
x::Tx
y::Tx = similar(x)
f_y::R = real(eltype(x))(0)
r::Tx = similar(x)
z::Tx = similar(x)
g_z::R = real(eltype(x))(0)
res::Tx = similar(x)
end

function Base.iterate(
iter::DouglasRachfordIteration,
state::DouglasRachfordState = DouglasRachfordState(x = copy(iter.x0)),
)
prox!(state.y, iter.f, state.x, iter.gamma)
f_y = prox!(state.y, iter.f, state.x, iter.gamma)
state.r .= 2 .* state.y .- state.x
prox!(state.z, iter.g, state.r, iter.gamma)
g_z = prox!(state.z, iter.g, state.r, iter.gamma)
state.res .= state.y .- state.z
state.x .-= state.res
state = DouglasRachfordState(state.x, state.y, f_y, state.r, state.z, g_z, state.res)
return state, state
end

Expand All @@ -68,8 +71,8 @@ default_stopping_criterion(
state::DouglasRachfordState,
) = norm(state.res, Inf) / iter.gamma <= tol
default_solution(::DouglasRachfordIteration, state::DouglasRachfordState) = state.y
default_display(it, iter::DouglasRachfordIteration, state::DouglasRachfordState) =
@printf("%5d | %.3e\n", it, norm(state.res, Inf) / iter.gamma)
default_iteration_summary(it, iter::DouglasRachfordIteration, state::DouglasRachfordState) =
("" => it, "f(y)" => state.f_y, "g(z)" => state.g_z, "‖y - z‖" => norm(state.res, Inf) / iter.gamma)

"""
DouglasRachford(; <keyword-arguments>)
Expand All @@ -88,11 +91,12 @@ See also: [`DouglasRachfordIteration`](@ref), [`IterativeAlgorithm`](@ref).
# Arguments
- `maxit::Int=1_000`: maximum number of iteration
- `tol::1e-8`: tolerance for the default stopping criterion
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution
- `verbose::Bool=false`: whether the algorithm state should be displayed
- `freq::Int=100`: every how many iterations to display the algorithm state
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed.
- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state
- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `kwargs...`: additional keyword arguments to pass on to the `DouglasRachfordIteration` constructor upon call

# References
Expand All @@ -105,6 +109,7 @@ DouglasRachford(;
solution = default_solution,
verbose = false,
freq = 100,
summary = default_iteration_summary,
display = default_display,
kwargs...,
) = IterativeAlgorithm(
Expand All @@ -114,6 +119,7 @@ DouglasRachford(;
solution,
verbose,
freq,
summary,
display,
kwargs...,
)
25 changes: 14 additions & 11 deletions src/algorithms/drls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,13 @@ end
default_stopping_criterion(tol, ::DRLSIteration, state::DRLSState) =
norm(state.res, Inf) / state.gamma <= tol
default_solution(::DRLSIteration, state::DRLSState) = state.v
default_display(it, ::DRLSIteration, state::DRLSState) = @printf(
"%5d | %.3e | %.3e | %.3e\n",
it,
state.gamma,
norm(state.res, Inf) / state.gamma,
state.tau,
)
default_iteration_summary(it, ::DRLSIteration, state::DRLSState) =
("" => it,
"f(u)" => state.f_u,
"g(v)" => state.g_v,
"γ" => state.gamma,
"‖u - v‖/γ" => norm(state.res, Inf) / state.gamma,
"τ" => state.tau)

"""
DRLS(; <keyword-arguments>)
Expand All @@ -224,11 +224,12 @@ See also: [`DRLSIteration`](@ref), [`IterativeAlgorithm`](@ref).
# Arguments
- `maxit::Int=1_000`: maximum number of iteration
- `tol::1e-8`: tolerance for the default stopping criterion
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution
- `verbose::Bool=false`: whether the algorithm state should be displayed
- `freq::Int=10`: every how many iterations to display the algorithm state
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `freq::Int=10`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed.
- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state
- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `kwargs...`: additional keyword arguments to pass on to the `DRLSIteration` constructor upon call

# References
Expand All @@ -241,6 +242,7 @@ DRLS(;
solution = default_solution,
verbose = false,
freq = 10,
summary = default_iteration_summary,
display = default_display,
kwargs...,
) = IterativeAlgorithm(
Expand All @@ -250,6 +252,7 @@ DRLS(;
solution,
verbose,
freq,
summary,
display,
kwargs...,
)
20 changes: 14 additions & 6 deletions src/algorithms/fast_forward_backward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,13 @@ default_stopping_criterion(
state::FastForwardBackwardState,
) = norm(state.res, Inf) / state.gamma <= tol
default_solution(::FastForwardBackwardIteration, state::FastForwardBackwardState) = state.z
default_display(it, ::FastForwardBackwardIteration, state::FastForwardBackwardState) =
@printf("%5d | %.3e | %.3e\n", it, state.gamma, norm(state.res, Inf) / state.gamma)
default_iteration_summary(it, iter::FastForwardBackwardIteration, state::FastForwardBackwardState) = begin
if iter.adaptive
("" => it, "f(x)" => state.f_x, "g(z)" => state.g_z, "γ" => state.gamma, "‖x - z‖/γ" => norm(state.res, Inf) / state.gamma)
else
("" => it, "f(x)" => state.f_x, "g(z)" => state.g_z, "‖x - z‖/γ" => norm(state.res, Inf) / state.gamma)
end
end

"""
FastForwardBackward(; <keyword-arguments>)
Expand All @@ -172,11 +177,12 @@ See also: [`FastForwardBackwardIteration`](@ref), [`IterativeAlgorithm`](@ref).
# Arguments
- `maxit::Int=10_000`: maximum number of iteration
- `tol::1e-8`: tolerance for the default stopping criterion
- `stop::Function`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
- `solution::Function`: solution mapping, `solution(::T, state)` should return the identified solution
- `stop::Function=(iter, state) -> default_stopping_criterion(tol, iter, state)`: termination condition, `stop(::T, state)` should return `true` when to stop the iteration
- `solution::Function=default_solution`: solution mapping, `solution(::T, state)` should return the identified solution
- `verbose::Bool=false`: whether the algorithm state should be displayed
- `freq::Int=100`: every how many iterations to display the algorithm state
- `display::Function`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `freq::Int=100`: every how many iterations to display the algorithm state. If `freq <= 0`, only the final iteration is displayed.
- `summary::Function=default_iteration_summary`: function to generate iteration summaries, `summary(::Int, iter::T, state)` should return a summary of the iteration state
- `display::Function=default_display`: display function, `display(::Int, ::T, state)` should display a summary of the iteration state
- `kwargs...`: additional keyword arguments to pass on to the `FastForwardBackwardIteration` constructor upon call

# References
Expand All @@ -190,6 +196,7 @@ FastForwardBackward(;
solution = default_solution,
verbose = false,
freq = 100,
summary = default_iteration_summary,
display = default_display,
kwargs...,
) = IterativeAlgorithm(
Expand All @@ -199,6 +206,7 @@ FastForwardBackward(;
solution,
verbose,
freq,
summary,
display,
kwargs...,
)
Expand Down
Loading