Skip to content
67 changes: 46 additions & 21 deletions ignite/handlers/state_param_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numbers
import warnings
from bisect import bisect_right
from typing import Any, List, Sequence, Tuple, Union

Expand All @@ -11,8 +12,9 @@ class StateParamScheduler(BaseParamScheduler):

Args:
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
save_history: whether to log the parameter values to ``engine.state.param_history``, (default=False).
create_new: whether to create ``param_name`` on ``engine.state`` taking into account whether ``param_name``
attribute already exists or not. Overrides existing attribute by default, (default=False).

Note:
Parameter scheduler works independently of the internal state of the attached engine.
Expand All @@ -23,10 +25,9 @@ class StateParamScheduler(BaseParamScheduler):

"""

def __init__(
self, param_name: str, save_history: bool = False,
):
def __init__(self, param_name: str, save_history: bool = False, create_new: bool = False):
super(StateParamScheduler, self).__init__(param_name, save_history)
self.create_new = create_new

def attach(
self,
Expand All @@ -43,16 +44,24 @@ def attach(

"""
if hasattr(engine.state, self.param_name):
raise ValueError(
f"Attribute: '{self.param_name}' is already defined in the Engine.state."
f"This may be a conflict between multiple StateParameterScheduler handlers."
f"Please choose another name."
)
if self.create_new:
raise ValueError(
f"Attribute '{self.param_name}' already exists in the engine.state. "
f"This may be a conflict between multiple handlers. "
f"Please choose another name."
)
else:
if not self.create_new:
warnings.warn(
f"Attribute '{self.param_name}' is not defined in the engine.state. "
f"{type(self).__name__} will create it. Remove this warning by setting create_new=True."
)
setattr(engine.state, self.param_name, None)

if self.save_history:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
setattr(engine.state, "param_history", {})
engine.state.param_history.setdefault(self.param_name, []) # type: ignore[attr-defined]
engine.state.param_history.setdefault(self.param_name, []) # type: ignore[attr-defined]

engine.add_event_handler(event, self)

Expand Down Expand Up @@ -147,8 +156,8 @@ def __call__(self, event_index):

"""

def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False):
super(LambdaStateScheduler, self).__init__(param_name, save_history)
def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False):
super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)

if not callable(lambda_obj):
raise ValueError("Expected lambda_obj to be callable.")
Expand Down Expand Up @@ -199,9 +208,13 @@ class PiecewiseLinearStateScheduler(StateParamScheduler):
"""

def __init__(
self, milestones_values: List[Tuple[int, float]], param_name: str, save_history: bool = False,
self,
milestones_values: List[Tuple[int, float]],
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history)
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history, create_new)

if not isinstance(milestones_values, Sequence):
raise TypeError(
Expand Down Expand Up @@ -289,9 +302,9 @@ class ExpStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False,
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False, create_new: bool = False,
):
super(ExpStateScheduler, self).__init__(param_name, save_history)
super(ExpStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self._state_attrs += ["initial_value", "gamma"]
Expand Down Expand Up @@ -337,9 +350,15 @@ class StepStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, step_size: int, param_name: str, save_history: bool = False,
self,
initial_value: float,
gamma: float,
step_size: int,
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(StepStateScheduler, self).__init__(param_name, save_history)
super(StepStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self.step_size = step_size
Expand Down Expand Up @@ -386,9 +405,15 @@ class MultiStepStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, milestones: List[int], param_name: str, save_history: bool = False,
self,
initial_value: float,
gamma: float,
milestones: List[int],
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(MultiStepStateScheduler, self).__init__(param_name, save_history)
super(MultiStepStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self.milestones = milestones
Expand Down
80 changes: 55 additions & 25 deletions tests/ignite/handlers/test_state_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,31 +276,6 @@ def _test(scheduler_cls, scheduler_kwargs):
_test(scheduler_cls, scheduler_kwargs)


@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config3, config4, config5, config6])
def test_state_param_asserts(scheduler_cls, scheduler_kwargs):
import re

def _test(scheduler_cls, scheduler_kwargs):
scheduler = scheduler_cls(**scheduler_kwargs)
with pytest.raises(
ValueError,
match=r"Attribute: '"
+ re.escape(scheduler_kwargs["param_name"])
+ "' is already defined in the Engine.state.This may be a conflict between multiple StateParameterScheduler"
+ " handlers.Please choose another name.",
):

trainer = Engine(lambda engine, batch: None)
event = Events.EPOCH_COMPLETED
max_epochs = 2
data = [0] * 10
scheduler.attach(trainer, event)
trainer.run(data, max_epochs=max_epochs)
scheduler.attach(trainer, event)

_test(scheduler_cls, scheduler_kwargs)


def test_torch_save_load():

lambda_state_parameter_scheduler = LambdaStateScheduler(
Expand Down Expand Up @@ -414,3 +389,58 @@ def __call__(self, event_index):
param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

engine.run([0] * 8, max_epochs=10)


def test_param_scheduler_with_ema_handler_attach_exception():
import torch.nn as nn

from ignite.handlers import EMAHandler

data = torch.rand(100, 2)
model = nn.Linear(2, 1)
trainer = Engine(lambda e, b: model(b))
param_name = "ema_decay"
save_history = True
create_new = True

ema_handler = EMAHandler(model)
ema_handler.attach(trainer, name=param_name, event=Events.ITERATION_COMPLETED)
ema_decay_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name,
milestones_values=[(0, 0.0), (10 * len(data), 0.999)],
save_history=save_history,
create_new=create_new,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you do not need to add EMAHandler to check ValueError.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a working test with EMAHandler and PiecewiseLinearStateScheduler like in the issue ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sure , thanks for your comments @vfdev-5 !


with pytest.raises(
ValueError,
match=r"Attribute 'ema_decay' already exists in the engine.state. "
r"This may be a conflict between multiple handlers. "
r"Please choose another name.",
):
ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)


def test_param_scheduler_attach_warning():
import torch.nn as nn

data = torch.rand(100, 2)
model = nn.Linear(2, 1)
trainer = Engine(lambda e, b: model(b))
param_name = "ema_decay"
save_history = True
create_new = False

ema_decay_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name,
milestones_values=[(0, 0.0), (10 * len(data), 0.999)],
save_history=save_history,
create_new=create_new,
)

with pytest.warns(
UserWarning,
match=r"Attribute 'ema_decay' is not defined in the engine.state. "
r"PiecewiseLinearStateScheduler will create it. Remove this warning by setting create_new=True.",
):
ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)