Merged

Changes from 10 commits (40 commits total)
0b1845e
#1913 Add any parameter scheduling to enlarge the actual scope of Pa…
fco-dv Jun 28, 2021
7f99f47
code style fixes
fco-dv Jun 29, 2021
b7ff68c
update doc
fco-dv Jun 29, 2021
b73e61a
fix docstring warnings
fco-dv Jun 30, 2021
ff40ab8
Merge branch 'master' into any_parameter_scheduler
fco-dv Jun 30, 2021
3cc0fa3
Merge branch 'master' into any_parameter_scheduler
fco-dv Jul 3, 2021
bf5893c
Rename Any* to State* classes
fco-dv Jul 7, 2021
ca1829f
rm useless import
fco-dv Jul 7, 2021
216a667
fix docstring
fco-dv Jul 7, 2021
ba2e61e
Merge branch 'master' into any_parameter_scheduler
fco-dv Jul 8, 2021
c5e0ba7
Introduce BaseParamScheduler class for State* Optimizer* parameters s…
fco-dv Jul 14, 2021
72d7ac3
Merge branch 'master' into any_parameter_scheduler
fco-dv Jul 15, 2021
b9dc427
Merge branch 'master' into any_parameter_scheduler
fco-dv Jul 19, 2021
5ad920b
Naming changes.
fco-dv Jul 20, 2021
d0d1863
fix flake8 errors
fco-dv Jul 20, 2021
d9c81e5
fix docstring / parametrize tests
fco-dv Jul 21, 2021
590832c
naming changes
fco-dv Jul 23, 2021
fa2d759
parametrize tests
fco-dv Jul 29, 2021
61a6ce0
fix flake8
fco-dv Jul 29, 2021
c09d903
try to remove lines in pytest configs
fco-dv Jul 30, 2021
928ca43
Update ignite/handlers/state_param_scheduler.py
fco-dv Jul 30, 2021
e58b9dc
LinearState to PwLinearState ( implemented from PiecewiseLinear Param…
fco-dv Jul 31, 2021
3fb0ee1
Update ignite/handlers/state_param_scheduler.py
fco-dv Aug 1, 2021
8f93b70
Update ignite/handlers/state_param_scheduler.py
fco-dv Aug 1, 2021
48aee27
Re-naming : PwLinearStateScheduler to PiecewiseLinearStateScheduler
fco-dv Aug 1, 2021
1e797be
add test for state_dict and docstring examples
fco-dv Aug 2, 2021
29da974
Merge remote-tracking branch 'upstream/master' into any_parameter_sch…
fco-dv Aug 13, 2021
796d6cd
Merge remote-tracking branch 'upstream/master' into any_parameter_sch…
fco-dv Sep 27, 2021
eac90fd
Merge branch 'master' into any_parameter_scheduler
sdesrozis Sep 27, 2021
0bcd9bd
Update ignite/handlers/state_param_scheduler.py
fco-dv Sep 28, 2021
3dafb74
improve docstring / change lambda_fn to lambda_obj for LambdaStateSch…
fco-dv Sep 30, 2021
8549811
rm duplicated test
fco-dv Sep 30, 2021
154c6b0
fix code fmt
fco-dv Sep 30, 2021
49c1e13
Merge remote-tracking branch 'upstream/master' into any_parameter_sch…
fco-dv Oct 1, 2021
0b9a1d8
add test LambdaState object must be callable
fco-dv Oct 1, 2021
349b9ce
add test on asserts
fco-dv Oct 1, 2021
218b6c1
Apply suggestions from code review
vfdev-5 Oct 11, 2021
6ba9b38
Merge branch 'master' into any_parameter_scheduler
vfdev-5 Oct 11, 2021
330be74
autopep8 fix
vfdev-5 Oct 11, 2021
4152f02
Apply suggestions from code review
vfdev-5 Oct 11, 2021
6 changes: 6 additions & 0 deletions docs/source/handlers.rst
@@ -40,6 +40,12 @@ Parameter scheduler
:nosignatures:
:toctree: generated

StateParameterScheduler
LambdaStateParameterScheduler
LinearStateParameterScheduler
ExponentialStateParameterScheduler
StepStateParameterScheduler
MultiStepStateParameterScheduler
ConcatScheduler
CosineAnnealingScheduler
CyclicalScheduler
263 changes: 262 additions & 1 deletion ignite/handlers/param_scheduler.py
@@ -3,10 +3,11 @@
import numbers
import tempfile
from abc import ABCMeta, abstractmethod
from bisect import bisect_right
from collections import OrderedDict
from copy import copy
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast

import torch
from torch.optim.lr_scheduler import _LRScheduler
@@ -226,6 +227,256 @@ def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any:
return ax


class StateParameterScheduler(ParamScheduler):
"""An abstract class for updating a state parameter during
training not belonging to an Optimizer parameter group.

Args:
parameter_setter: function that sets the required parameter value.
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).

.. versionadded:: 0.6.0
"""

def __init__(
self, parameter_setter: Callable, param_name: str, save_history: bool = False,
):
self.param_name = param_name
self.event_index = 0
self._save_history = save_history
self.parameter_setter = parameter_setter
self._state_attrs = ["event_index", "param_name", "save_history", "parameter_setter"]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
self.event_index += 1
value = self.get_param()
self.parameter_setter(value)
if self.save_history and engine:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
setattr(engine.state, "param_history", {})
engine.state.param_history.setdefault(self.param_name, []) # type: ignore[attr-defined]
engine.state.param_history[self.param_name].append(value) # type: ignore[attr-defined]

@classmethod
def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[int]]:
"""Method to simulate scheduled values during `num_events` events.

Args:
num_events: number of events during the simulation.
scheduler_kwargs: parameter scheduler configuration kwargs.

Returns:
event_index, value

Examples:

.. code-block:: python

param_values = np.array(StepStateParameterScheduler.simulate_values(num_events=50, param_name='param',
initial_value=10, gamma=0.99,
step_size=5))

plt.plot(param_values[:, 0], param_values[:, 1], label="param")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()

"""
keys_to_remove = ["parameter_setter", "save_history"]
for key in keys_to_remove:
if key in scheduler_kwargs:
del scheduler_kwargs[key]
values = []
scheduler = cls(parameter_setter=_get_fake_param_setter(), save_history=False, **scheduler_kwargs)
for i in range(num_events):
scheduler(engine=None)
values.append([i, scheduler.parameter_setter.__closure__[0].cell_contents]) # type: ignore[attr-defined]
return values
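
# Illustrative sketch (not part of this PR's diff): the abstract API above only requires
# subclasses to implement ``get_param``. The class name and decay formula below are
# assumptions made up for the example.

class InverseDecayStateParameterScheduler(StateParameterScheduler):
    """Hypothetical scheduler that decays a parameter as initial_value / event_index."""

    def __init__(self, parameter_setter: Callable, initial_value: float, param_name: str, save_history: bool = False):
        super().__init__(parameter_setter, param_name, save_history)
        self.initial_value = initial_value
        self._state_attrs += ["initial_value"]

    def get_param(self) -> Union[List[float], float]:
        # event_index is incremented by __call__ before get_param is invoked, so it starts at 1 here
        return self.initial_value / self.event_index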


class LambdaStateParameterScheduler(StateParameterScheduler):
"""Update a parameter during training by using a user defined function.
User defined function is taking an event index as input and returns a float value.

Args:
parameter_setter: function that sets the required parameter value.
lambda_fn: user defined parameter update function.
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).

.. versionadded:: 0.6.0
"""

def __init__(self, parameter_setter: Callable, lambda_fn: Callable, param_name: str, save_history: bool = False):
super(LambdaStateParameterScheduler, self).__init__(parameter_setter, param_name, save_history)
self.lambda_fn = lambda_fn
self._state_attrs += ["lambda_fn"]

def get_param(self) -> Union[List[float], float]:
return self.lambda_fn(self.event_index)
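
# Illustrative sketch (not part of this PR's diff): driving a hypothetical ``warmup["scale"]``
# value with a cosine warmup expressed as a plain Python function of the event index.

import math

warmup = {"scale": 0.0}

warmup_scheduler = LambdaStateParameterScheduler(
    parameter_setter=lambda value: warmup.update(scale=value),
    lambda_fn=lambda event_index: 0.5 * (1.0 - math.cos(math.pi * min(event_index, 100) / 100)),
    param_name="scale",
)

for _ in range(3):
    warmup_scheduler(engine=None)  # each call advances event_index and applies lambda_fn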


class LinearStateParameterScheduler(StateParameterScheduler):
"""Update a parameter during training by using linear function.
The function keeps the parameter value to zero until step_zero steps passed and then linearly increases it to 1
until an additional step_one steps passed. Continues the trend until it reaches max_value.

Args:
initial_value : starting value of the parameter.
step_constant : step index until parameter's value is kept constant.
step_max_value : step index at which parameter's value becomes max_value.
max_value : max parameter value.
parameter_setter: function that sets the required parameter value.
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).

.. versionadded:: 0.6.0
"""

def __init__(
self,
initial_value: float,
step_constant: int,
step_max_value: int,
max_value: float,
parameter_setter: Callable,
param_name: str,
save_history: bool = False,
):
super(LinearStateParameterScheduler, self).__init__(parameter_setter, param_name, save_history)
self.initial_value = initial_value
self.step_constant = step_constant
self.step_max_value = step_max_value
self.max_value = max_value
self._state_attrs += ["initial_value", "step_constant", "step_max_value", "max_value"]

def get_param(self) -> Union[List[float], float]:
if self.event_index <= self.step_constant:
delta = 0.0
elif self.event_index > self.step_max_value:
delta = self.max_value - self.initial_value
else:
delta = (
(self.max_value - self.initial_value)
/ (self.step_max_value - self.step_constant)
* (self.event_index - self.step_constant)
)

return self.initial_value + delta
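
# Illustrative sketch (not part of this PR's diff): simulating a scheduler that keeps a
# hypothetical "teacher_forcing" parameter at 0.0 for 5 events, then ramps it to 1.0 by event 10.

linear_values = LinearStateParameterScheduler.simulate_values(
    num_events=12,
    param_name="teacher_forcing",
    initial_value=0.0,
    step_constant=5,
    step_max_value=10,
    max_value=1.0,
)
# pairs of [simulation step, value]: held at 0.0 up to event 5, ramped linearly to 1.0 at event 10,
# then kept constant at 1.0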


class ExponentialStateParameterScheduler(StateParameterScheduler):
"""Update a parameter during training by using exponential function.
The function decays the parameter value by gamma every step.
Based on the closed form of ExponentialLR from Pytorch
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L457

Args:
parameter_setter: function that sets the required parameter value.
initial_value: Starting value of the parameter.
gamma: Multiplicative factor of parameter value decay.
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).

.. versionadded:: 0.6.0
"""

def __init__(
self,
parameter_setter: Callable,
initial_value: float,
gamma: float,
param_name: str,
save_history: bool = False,
):
super(ExponentialStateParameterScheduler, self).__init__(parameter_setter, param_name, save_history)
self.initial_value = initial_value
self.gamma = gamma
self._state_attrs += ["initial_value", "gamma"]

def get_param(self) -> Union[List[float], float]:
return self.initial_value * self.gamma ** self.event_index
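
# Illustrative sketch (not part of this PR's diff): decaying a hypothetical "temperature"
# parameter by 10% per event and checking the result with simulate_values.

exp_values = ExponentialStateParameterScheduler.simulate_values(
    num_events=5, param_name="temperature", initial_value=1.0, gamma=0.9,
)
# exp_values[i][1] == 1.0 * 0.9 ** (i + 1), since __call__ increments event_index before get_param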


class StepStateParameterScheduler(StateParameterScheduler):
"""Update a parameter during training by using a step function.
This function decays the parameter value by gamma every step_size.
Based on StepLR from Pytorch.
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L377

Args:
parameter_setter: function that sets the required parameter value.
initial_value: Starting value of the parameter.
gamma: Multiplicative factor of parameter value decay.
step_size: Period of parameter value decay.
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).

.. versionadded:: 0.6.0
"""

def __init__(
self,
parameter_setter: Callable,
initial_value: float,
gamma: float,
step_size: int,
param_name: str,
save_history: bool = False,
):
super(StepStateParameterScheduler, self).__init__(parameter_setter, param_name, save_history)
self.initial_value = initial_value
self.gamma = gamma
self.step_size = step_size
self._state_attrs += ["initial_value", "gamma", "step_size"]

def get_param(self) -> Union[List[float], float]:
return self.initial_value * self.gamma ** (self.event_index // self.step_size)
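
# Illustrative sketch (not part of this PR's diff): halving a hypothetical "loss_weight"
# parameter every 3 events.

step_values = StepStateParameterScheduler.simulate_values(
    num_events=10, param_name="loss_weight", initial_value=1.0, gamma=0.5, step_size=3,
)
# at simulation step i the value is 1.0 * 0.5 ** ((i + 1) // 3)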


class MultiStepStateParameterScheduler(StateParameterScheduler):
"""Update a parameter during training by using a multi step function.
The function decays the parameter value by gamma once the number of steps reaches one of the milestones.
Based on MultiStepLR from Pytorch.
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L424

Args:
parameter_setter: function that sets the required parameter value.
initial_value: Starting value of the parameter.
gamma: Multiplicative factor of parameter value decay.
milestones: List of step indices. Must be increasing.
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).

.. versionadded:: 0.6.0
"""

def __init__(
self,
parameter_setter: Callable,
initial_value: float,
gamma: float,
milestones: List[int],
param_name: str,
save_history: bool = False,
):
super(MultiStepStateParameterScheduler, self).__init__(parameter_setter, param_name, save_history)
self.initial_value = initial_value
self.gamma = gamma
self.milestones = milestones
self._state_attrs += ["initial_value", "gamma", "milestones"]

def get_param(self) -> Union[List[float], float]:
return self.initial_value * self.gamma ** bisect_right(self.milestones, self.event_index)
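
# Illustrative sketch (not part of this PR's diff): multiplying a hypothetical "kl_weight"
# parameter by gamma when the event index reaches each milestone.

multi_step_values = MultiStepStateParameterScheduler.simulate_values(
    num_events=10, param_name="kl_weight", initial_value=1.0, gamma=0.1, milestones=[5, 8],
)
# the value drops to 0.1 once the event index reaches 5 and to 0.01 once it reaches 8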


class CyclicalScheduler(ParamScheduler):
"""An abstract class for updating an optimizer's parameter value over a
cycle of some size.
@@ -1135,3 +1386,13 @@ def _get_fake_optimizer(
optimizer_cls = torch.optim.SGD
kwargs["lr"] = 0.01
return optimizer_cls([t], **kwargs)


def _get_fake_param_setter() -> Callable:
inner_param_value = 0.0

def param_setter(param_value: float) -> None:
nonlocal inner_param_value
inner_param_value = param_value

return param_setter
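
# Illustrative end-to-end sketch (not part of this PR's diff): attaching a scheduler to an
# Engine and reading back the recorded history. The no-op Engine and the ``ema`` dict are
# assumptions made up for the example.

from ignite.engine import Engine, Events

ema = {"momentum": 1.0}

ema_scheduler = ExponentialStateParameterScheduler(
    parameter_setter=lambda value: ema.update(momentum=value),
    initial_value=1.0,
    gamma=0.99,
    param_name="ema_momentum",
    save_history=True,
)

trainer = Engine(lambda engine, batch: None)  # minimal no-op engine for the sketch
trainer.add_event_handler(Events.EPOCH_COMPLETED, ema_scheduler)
trainer.run([0] * 4, max_epochs=3)

print(ema["momentum"])                              # last scheduled value
print(trainer.state.param_history["ema_momentum"])  # one entry per completed epoch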