Merged
Changes from 29 commits
Commits
40 commits
0b1845e
#1913 Add any parameter scheduling to enlarge the actual scope of Pa…
fco-dv Jun 28, 2021
7f99f47
code style fixes
fco-dv Jun 29, 2021
b7ff68c
update doc
fco-dv Jun 29, 2021
b73e61a
fix docstring warnings
fco-dv Jun 30, 2021
ff40ab8
Merge branch 'master' into any_parameter_scheduler
fco-dv Jun 30, 2021
3cc0fa3
Merge branch 'master' into any_parameter_scheduler
fco-dv Jul 3, 2021
bf5893c
Rename Any* to State* classes
fco-dv Jul 7, 2021
ca1829f
rm useless import
fco-dv Jul 7, 2021
216a667
fix docstring
fco-dv Jul 7, 2021
ba2e61e
Merge branch 'master' into any_parameter_scheduler
fco-dv Jul 8, 2021
c5e0ba7
Introduce BaseParamScheduler class for State* Optimizer* parameters s…
fco-dv Jul 14, 2021
72d7ac3
Merge branch 'master' into any_parameter_scheduler
fco-dv Jul 15, 2021
b9dc427
Merge branch 'master' into any_parameter_scheduler
fco-dv Jul 19, 2021
5ad920b
Naming changes.
fco-dv Jul 20, 2021
d0d1863
fix flake8 errors
fco-dv Jul 20, 2021
d9c81e5
fix docstring / parametrize tests
fco-dv Jul 21, 2021
590832c
naming changes
fco-dv Jul 23, 2021
fa2d759
parametrize tests
fco-dv Jul 29, 2021
61a6ce0
fix flake8
fco-dv Jul 29, 2021
c09d903
try to remove lines in pytest configs
fco-dv Jul 30, 2021
928ca43
Update ignite/handlers/state_param_scheduler.py
fco-dv Jul 30, 2021
e58b9dc
LinearState to PwLinearState ( implemented from PiecewiseLinear Param…
fco-dv Jul 31, 2021
3fb0ee1
Update ignite/handlers/state_param_scheduler.py
fco-dv Aug 1, 2021
8f93b70
Update ignite/handlers/state_param_scheduler.py
fco-dv Aug 1, 2021
48aee27
Re-naming : PwLinearStateScheduler to PiecewiseLinearStateScheduler
fco-dv Aug 1, 2021
1e797be
add test for state_dict and docstring examples
fco-dv Aug 2, 2021
29da974
Merge remote-tracking branch 'upstream/master' into any_parameter_sch…
fco-dv Aug 13, 2021
796d6cd
Merge remote-tracking branch 'upstream/master' into any_parameter_sch…
fco-dv Sep 27, 2021
eac90fd
Merge branch 'master' into any_parameter_scheduler
sdesrozis Sep 27, 2021
0bcd9bd
Update ignite/handlers/state_param_scheduler.py
fco-dv Sep 28, 2021
3dafb74
improve docstring / change lambda_fn to lambda_obj for LambdaStateSch…
fco-dv Sep 30, 2021
8549811
rm duplicated test
fco-dv Sep 30, 2021
154c6b0
fix code fmt
fco-dv Sep 30, 2021
49c1e13
Merge remote-tracking branch 'upstream/master' into any_parameter_sch…
fco-dv Oct 1, 2021
0b9a1d8
add test LambdaState object must be callable
fco-dv Oct 1, 2021
349b9ce
add test on asserts
fco-dv Oct 1, 2021
218b6c1
Apply suggestions from code review
vfdev-5 Oct 11, 2021
6ba9b38
Merge branch 'master' into any_parameter_scheduler
vfdev-5 Oct 11, 2021
330be74
autopep8 fix
vfdev-5 Oct 11, 2021
4152f02
Apply suggestions from code review
vfdev-5 Oct 11, 2021
18 changes: 18 additions & 0 deletions docs/source/handlers.rst
@@ -31,6 +31,7 @@ Complete list of handlers

checkpoint.BaseSaveHandler
param_scheduler.ParamScheduler
state_param_scheduler.StateParamScheduler

.. _param-scheduler-label:

@@ -43,6 +44,7 @@ Parameter scheduler
:nosignatures:
:toctree: generated

BaseParamScheduler
ConcatScheduler
CosineAnnealingScheduler
CyclicalScheduler
@@ -53,6 +55,22 @@
PiecewiseLinear
create_lr_scheduler_with_warmup

State Parameter scheduler
-------------------------

.. currentmodule:: ignite.handlers.state_param_scheduler

.. autosummary::
:nosignatures:
:toctree: generated

StateParamScheduler
LambdaStateScheduler
PiecewiseLinearStateScheduler
ExpStateScheduler
StepStateScheduler
MultiStepStateScheduler

More on parameter scheduling
----------------------------

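The new state_param_scheduler entries above only list the class names; the module body is not part of this diff excerpt. As a rough usage sketch (the constructor arguments and the attach method below are assumptions modelled on the existing PiecewiseLinear optimizer scheduler, not confirmed by this excerpt), scheduling an arbitrary engine state attribute could look like this:

# Hypothetical sketch: drive an engine state attribute (rather than an
# optimizer parameter) with the new PiecewiseLinearStateScheduler.
# The milestones_values argument and the attach() call are assumptions.
from ignite.engine import Engine, Events
from ignite.handlers import PiecewiseLinearStateScheduler

def train_step(engine, batch):
    # the training step can read the scheduled value from engine.state
    return getattr(engine.state, "aux_loss_weight", None)

trainer = Engine(train_step)

scheduler = PiecewiseLinearStateScheduler(
    param_name="aux_loss_weight",
    milestones_values=[(0, 1.0), (10, 0.5), (20, 0.0)],
    save_history=True,
)
scheduler.attach(trainer, Events.EPOCH_COMPLETED)

trainer.run([0, 1, 2], max_epochs=25)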
16 changes: 16 additions & 0 deletions ignite/handlers/__init__.py
@@ -7,6 +7,7 @@
from ignite.handlers.ema_handler import EMAHandler
from ignite.handlers.lr_finder import FastaiLRFinder
from ignite.handlers.param_scheduler import (
BaseParamScheduler,
ConcatScheduler,
CosineAnnealingScheduler,
CyclicalScheduler,
@@ -17,6 +18,14 @@
PiecewiseLinear,
create_lr_scheduler_with_warmup,
)
from ignite.handlers.state_param_scheduler import (
ExpStateScheduler,
LambdaStateScheduler,
MultiStepStateScheduler,
PiecewiseLinearStateScheduler,
StateParamScheduler,
StepStateScheduler,
)
from ignite.handlers.stores import EpochOutputStore
from ignite.handlers.terminate_on_nan import TerminateOnNan
from ignite.handlers.time_limit import TimeLimit
@@ -46,6 +55,13 @@
"EMAHandler",
"BasicTimeProfiler",
"HandlersTimeProfiler",
"BaseParamScheduler",
"StateParamScheduler",
"LambdaStateScheduler",
"PiecewiseLinearStateScheduler",
"ExpStateScheduler",
"StepStateScheduler",
"MultiStepStateScheduler",
]


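Since ignite/handlers/__init__.py now re-exports the new classes, a quick import check (a sketch that assumes a build of ignite containing this branch) is enough to exercise the additions to __all__ above; the subclass relationship is expected from the BaseParamScheduler refactor in the commit list, not shown in this file.

# Sketch: the new scheduler classes should resolve from the package root,
# matching the names added to __all__ above.
from ignite.handlers import (
    BaseParamScheduler,
    StateParamScheduler,
    LambdaStateScheduler,
    PiecewiseLinearStateScheduler,
    ExpStateScheduler,
    StepStateScheduler,
    MultiStepStateScheduler,
)

# Expected to hold once both scheduler families share the new base class
# (an assumption based on the "Introduce BaseParamScheduler" commit).
print(issubclass(StateParamScheduler, BaseParamScheduler))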
218 changes: 126 additions & 92 deletions ignite/handlers/param_scheduler.py
@@ -15,83 +15,26 @@
from ignite.engine import Engine


class ParamScheduler(metaclass=ABCMeta):
"""An abstract class for updating an optimizer's parameter value during
class BaseParamScheduler(metaclass=ABCMeta):
r"""An abstract class for updating an engine state or optimizer's parameter value during
training.

Args:
optimizer: torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name: name of optimizer's parameter to update.
param_name: name of engine state or optimizer's parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
param_group_index: optimizer's parameters group to use

Note:
Parameter scheduler works independently of the internal state of the attached optimizer.
More precisely, whatever the state of the optimizer (newly created or used by another scheduler) the scheduler
sets defined absolute values.
.. versionadded:: 0.6.0

.. versionadded:: 0.4.5
"""

def __init__(
self,
optimizer: Optimizer,
param_name: str,
save_history: bool = False,
param_group_index: Optional[int] = None,
self, param_name: str, save_history: bool = False,
):

if not (
isinstance(optimizer, Optimizer)
or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
):
raise TypeError(
"Argument optimizer should be torch.optim.Optimizer or has attribute 'param_groups' as list/tuple, "
f"but given {type(optimizer)}"
)

self.optimizer = optimizer
self.param_group_index = param_group_index
self.param_name = param_name
self.event_index = 0
self._save_history = save_history
self._state_attrs = ["event_index", "param_name", "save_history", "param_group_index"]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:

value = self.get_param()

if isinstance(value, list):
if len(value) != len(self.optimizer_param_groups):
raise ValueError(
"size of value is different than optimizer_param_groups "
f"{len(value)} != {len(self.optimizer_param_groups)}"
)

for i, param_group in enumerate(self.optimizer_param_groups):
param_group[self.param_name] = value[i]
else:
for i, param_group in enumerate(self.optimizer_param_groups):
param_group[self.param_name] = value

if name is None:
name = self.param_name

if self.save_history and engine:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
setattr(engine.state, "param_history", {})
engine.state.param_history.setdefault(name, []) # type: ignore[attr-defined]
values = [pg[self.param_name] for pg in self.optimizer_param_groups]
engine.state.param_history[name].append(values) # type: ignore[attr-defined]
self.event_index += 1

@property
def optimizer_param_groups(self) -> List[Dict[str, Any]]:
if self.param_group_index is None:
return self.optimizer.param_groups
return [self.optimizer.param_groups[self.param_group_index]]
self._state_attrs = ["event_index", "param_name", "save_history"]

@property
def save_history(self) -> bool:
@@ -102,11 +45,11 @@ def save_history(self, value: bool) -> None:
self._save_history = value

def state_dict(self) -> Dict[str, Any]:
"""Returns a dictionary containing a whole state of ParamScheduler.
"""Returns a dictionary containing a whole state of BaseParamScheduler.

Returns:
dict:
a dictionary containing a whole state of ParamScheduler
a dictionary containing a whole state of BaseParamScheduler
"""
destination = OrderedDict()
for name in self._state_attrs:
@@ -118,7 +61,7 @@ def state_dict(self) -> Dict[str, Any]:
return destination

def load_state_dict(self, state_dict: Mapping) -> None:
"""Copies parameters from :attr:`state_dict` into this ParamScheduler.
"""Copies parameters from :attr:`state_dict` into this BaseParamScheduler.

Args:
state_dict: a dict containing parameters.
@@ -140,14 +83,15 @@ def load_state_dict(self, state_dict: Mapping) -> None:

@abstractmethod
def get_param(self) -> Union[List[float], float]:
"""Method to get current optimizer's parameter values
"""Method to get current parameter values

Returns:
list of params, or scalar param
"""
pass

@classmethod
@abstractmethod
def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[int]]:
"""Method to simulate scheduled values during `num_events` events.

@@ -157,31 +101,8 @@ def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[

Returns:
event_index, value

Examples:

.. code-block:: python

lr_values = np.array(LinearCyclicalScheduler.simulate_values(num_events=50, param_name='lr',
start_value=1e-1, end_value=1e-3,
cycle_size=10))

plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()

"""
keys_to_remove = ["optimizer", "save_history"]
for key in keys_to_remove:
if key in scheduler_kwargs:
del scheduler_kwargs[key]
values = []
scheduler = cls(optimizer=_get_fake_optimizer(), save_history=False, **scheduler_kwargs)
for i in range(num_events):
scheduler(engine=None)
values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
return values
pass

@classmethod
def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any:
@@ -211,7 +132,7 @@ def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any:
start_value=1e-1, end_value=1e-3, cycle_size=10))
"""
try:
import matplotlib.pylab as plt
import matplotlib.pyplot as plt
except ImportError:
raise RuntimeError(
"This method requires matplotlib to be installed. "
@@ -226,6 +147,119 @@ def plot_values(cls, num_events: int, **scheduler_kwargs: Mapping) -> Any:
return ax


class ParamScheduler(BaseParamScheduler):
"""An abstract class for updating an optimizer's parameter value during
training.

Args:
optimizer: torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name: name of optimizer's parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
param_group_index: optimizer's parameters group to use

Note:
Parameter scheduler works independently of the internal state of the attached optimizer.
More precisely, whatever the state of the optimizer (newly created or used by another scheduler) the scheduler
sets defined absolute values.

.. versionadded:: 0.5.1

"""

def __init__(
self,
optimizer: Optimizer,
param_name: str,
save_history: bool = False,
param_group_index: Optional[int] = None,
):
super(ParamScheduler, self).__init__(param_name, save_history)
if not (
isinstance(optimizer, Optimizer)
or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
):
raise TypeError(
"Argument optimizer should be torch.optim.Optimizer or has attribute 'param_groups' as list/tuple, "
f"but given {type(optimizer)}"
)

self.optimizer = optimizer
self.param_group_index = param_group_index
self._state_attrs += ["param_group_index"]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:

value = self.get_param()

if isinstance(value, list):
if len(value) != len(self.optimizer_param_groups):
raise ValueError(
"size of value is different than optimizer_param_groups "
f"{len(value)} != {len(self.optimizer_param_groups)}"
)

for i, param_group in enumerate(self.optimizer_param_groups):
param_group[self.param_name] = value[i]
else:
for i, param_group in enumerate(self.optimizer_param_groups):
param_group[self.param_name] = value

if name is None:
name = self.param_name

if self.save_history and engine:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
setattr(engine.state, "param_history", {})
engine.state.param_history.setdefault(name, []) # type: ignore[attr-defined]
values = [pg[self.param_name] for pg in self.optimizer_param_groups]
engine.state.param_history[name].append(values) # type: ignore[attr-defined]
self.event_index += 1

@property
def optimizer_param_groups(self) -> List[Dict[str, Any]]:
if self.param_group_index is None:
return self.optimizer.param_groups
return [self.optimizer.param_groups[self.param_group_index]]

@classmethod
def simulate_values(cls, num_events: int, **scheduler_kwargs: Any) -> List[List[int]]:
"""Method to simulate scheduled values during `num_events` events.

Args:
num_events: number of events during the simulation.
scheduler_kwargs: parameter scheduler configuration kwargs.

Returns:
event_index, value

Examples:

.. code-block:: python

lr_values = np.array(LinearCyclicalScheduler.simulate_values(num_events=50, param_name='lr',
start_value=1e-1, end_value=1e-3,
cycle_size=10))

plt.plot(lr_values[:, 0], lr_values[:, 1], label="learning rate")
plt.xlabel("events")
plt.ylabel("values")
plt.legend()

"""
keys_to_remove = ["optimizer", "save_history"]
for key in keys_to_remove:
if key in scheduler_kwargs:
del scheduler_kwargs[key]
values = []
scheduler = cls(optimizer=_get_fake_optimizer(), save_history=False, **scheduler_kwargs)
for i in range(num_events):
scheduler(engine=None)
values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
return values


class CyclicalScheduler(ParamScheduler):
"""An abstract class for updating an optimizer's parameter value over a
cycle of some size.
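The refactor above moves the event_index/param_name/save_history bookkeeping, state_dict/load_state_dict, and plot_values up into BaseParamScheduler, while optimizer handling and simulate_values remain on ParamScheduler. A minimal sketch of the resulting serialization round trip, using the existing LinearCyclicalScheduler (the optimizer and parameter values are illustrative only):

# Sketch: state_dict()/load_state_dict() are now provided by BaseParamScheduler
# and serialize whatever each scheduler registers in _state_attrs
# ("event_index", "param_name", "save_history" from the base class,
# "param_group_index" appended by ParamScheduler; subclasses may add more).
import torch
from ignite.handlers import LinearCyclicalScheduler

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = LinearCyclicalScheduler(
    optimizer, param_name="lr", start_value=1e-1, end_value=1e-3, cycle_size=10
)
for _ in range(7):
    scheduler(engine=None)  # advances event_index and writes the new lr

checkpoint = scheduler.state_dict()

restored = LinearCyclicalScheduler(
    optimizer, param_name="lr", start_value=1e-1, end_value=1e-3, cycle_size=10
)
restored.load_state_dict(checkpoint)
print(restored.event_index)  # 7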