65 changes: 45 additions & 20 deletions ignite/handlers/state_param_scheduler.py
@@ -1,4 +1,5 @@
import numbers
import warnings
from bisect import bisect_right
from typing import Any, List, Sequence, Tuple, Union

@@ -11,8 +12,9 @@ class StateParamScheduler(BaseParamScheduler):

Args:
param_name: name of parameter to update.
save_history: whether to log the parameter values to
`engine.state.param_history`, (default=False).
save_history: whether to log the parameter values to ``engine.state.param_history`` (default=False).
create_new: whether ``param_name`` is expected to be newly created on ``engine.state``. If True, attaching
raises a ``ValueError`` when the attribute already exists; by default, an existing attribute is
overridden and a missing one is created with a warning (default=False).

Note:
Parameter scheduler works independently of the internal state of the attached engine.
@@ -23,10 +25,9 @@ class StateParamScheduler(BaseParamScheduler):

"""

def __init__(
self, param_name: str, save_history: bool = False,
):
def __init__(self, param_name: str, save_history: bool = False, create_new: bool = False):
super(StateParamScheduler, self).__init__(param_name, save_history)
self.create_new = create_new

def attach(
self,
@@ -43,11 +44,19 @@ def attach(

"""
if hasattr(engine.state, self.param_name):
raise ValueError(
f"Attribute: '{self.param_name}' is already defined in the Engine.state."
f"This may be a conflict between multiple StateParameterScheduler handlers."
f"Please choose another name."
)
if self.create_new:
raise ValueError(
f"Attribute '{self.param_name}' already exists in the engine.state. "
f"This may be a conflict between multiple handlers. "
f"Please choose another name."
)
else:
if not self.create_new:
warnings.warn(
f"Attribute '{self.param_name}' is not defined in the engine.state. "
f"{type(self).__name__} will create it. Remove this warning by setting create_new=True."
)
setattr(engine.state, self.param_name, None)

if self.save_history:
if not hasattr(engine.state, "param_history") or engine.state.param_history is None: # type: ignore
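
The two new branches above are easiest to see side by side. A minimal sketch of the attach-time behavior, using `ExpStateScheduler` from this module (the parameter name `"decay"` is an arbitrary choice for illustration):

```python
from ignite.engine import Engine, Events
from ignite.handlers.state_param_scheduler import ExpStateScheduler

trainer = Engine(lambda engine, batch: None)

# Case 1: attribute absent, create_new=False (default) -> UserWarning,
# then the attribute is created on engine.state with value None.
scheduler = ExpStateScheduler(initial_value=1.0, gamma=0.9, param_name="decay")
scheduler.attach(trainer, Events.EPOCH_COMPLETED)  # warns, creates trainer.state.decay

# Case 2: attribute already present, create_new=True -> ValueError on attach.
conflicting = ExpStateScheduler(
    initial_value=1.0, gamma=0.9, param_name="decay", create_new=True
)
# conflicting.attach(trainer, Events.EPOCH_COMPLETED)  # would raise ValueError
```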
@@ -147,8 +156,8 @@ def __call__(self, event_index):

"""

def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False):
super(LambdaStateScheduler, self).__init__(param_name, save_history)
def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False):
super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)

if not callable(lambda_obj):
raise ValueError("Expected lambda_obj to be callable.")
@@ -199,9 +208,13 @@ class PiecewiseLinearStateScheduler(StateParamScheduler):
"""

def __init__(
self, milestones_values: List[Tuple[int, float]], param_name: str, save_history: bool = False,
self,
milestones_values: List[Tuple[int, float]],
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history)
super(PiecewiseLinearStateScheduler, self).__init__(param_name, save_history, create_new)

if not isinstance(milestones_values, Sequence):
raise TypeError(
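
As a worked example of the milestones semantics, assuming the usual piecewise-linear interpolation between consecutive ``(event, value)`` pairs (the parameter name ``"ema_decay"`` is taken from the test below):

```python
from ignite.handlers.state_param_scheduler import PiecewiseLinearStateScheduler

scheduler = PiecewiseLinearStateScheduler(
    milestones_values=[(0, 0.0), (10, 0.999)],
    param_name="ema_decay",
    create_new=True,
)
# Between the two milestones the value is interpolated linearly, e.g. at
# event_index 5: 0.0 + (0.999 - 0.0) * (5 - 0) / (10 - 0) == 0.4995.
# After the last milestone the value is assumed to stay at 0.999.
```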
@@ -289,9 +302,9 @@ class ExpStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False,
self, initial_value: float, gamma: float, param_name: str, save_history: bool = False, create_new: bool = False,
):
super(ExpStateScheduler, self).__init__(param_name, save_history)
super(ExpStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self._state_attrs += ["initial_value", "gamma"]
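
A quick sanity check of the exponential rule these attributes suggest, assuming ``value = initial_value * gamma ** event_index``:

```python
initial_value, gamma = 1.0, 0.5
# index 0 -> 1.0, index 1 -> 0.5, index 2 -> 0.25, index 3 -> 0.125
values = [initial_value * gamma ** i for i in range(4)]
assert values == [1.0, 0.5, 0.25, 0.125]
```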
@@ -337,9 +350,15 @@ class StepStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, step_size: int, param_name: str, save_history: bool = False,
self,
initial_value: float,
gamma: float,
step_size: int,
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(StepStateScheduler, self).__init__(param_name, save_history)
super(StepStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self.step_size = step_size
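
The step variant decays once per ``step_size`` events; a hedged sketch, assuming ``value = initial_value * gamma ** (event_index // step_size)``:

```python
initial_value, gamma, step_size = 1.0, 0.5, 3
# indices 0-2 -> 1.0, indices 3-5 -> 0.5, index 6 -> 0.25, ...
values = [initial_value * gamma ** (i // step_size) for i in range(7)]
assert values == [1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.25]
```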
@@ -386,9 +405,15 @@ class MultiStepStateScheduler(StateParamScheduler):
"""

def __init__(
self, initial_value: float, gamma: float, milestones: List[int], param_name: str, save_history: bool = False,
self,
initial_value: float,
gamma: float,
milestones: List[int],
param_name: str,
save_history: bool = False,
create_new: bool = False,
):
super(MultiStepStateScheduler, self).__init__(param_name, save_history)
super(MultiStepStateScheduler, self).__init__(param_name, save_history, create_new)
self.initial_value = initial_value
self.gamma = gamma
self.milestones = milestones
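
The ``bisect_right`` import added at the top of the file points at the multi-step rule; a sketch of the expected schedule, assuming ``value = initial_value * gamma ** bisect_right(milestones, event_index)``:

```python
from bisect import bisect_right

initial_value, gamma, milestones = 1.0, 0.5, [3, 6]
# indices 0-2 -> 1.0, indices 3-5 -> 0.5, indices >= 6 -> 0.25
values = [initial_value * gamma ** bisect_right(milestones, i) for i in range(8)]
assert values == [1.0, 1.0, 1.0, 0.5, 0.5, 0.5, 0.25, 0.25]
```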
93 changes: 68 additions & 25 deletions tests/ignite/handlers/test_state_param_scheduler.py
@@ -1,7 +1,9 @@
import re
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers.state_param_scheduler import (
@@ -281,31 +283,6 @@ def _test(scheduler_cls, scheduler_kwargs):
_test(scheduler_cls, scheduler_kwargs)


@pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config3, config4, config5, config6])
def test_state_param_asserts(scheduler_cls, scheduler_kwargs):
import re

def _test(scheduler_cls, scheduler_kwargs):
scheduler = scheduler_cls(**scheduler_kwargs)
with pytest.raises(
ValueError,
match=r"Attribute: '"
+ re.escape(scheduler_kwargs["param_name"])
+ "' is already defined in the Engine.state.This may be a conflict between multiple StateParameterScheduler"
+ " handlers.Please choose another name.",
):

trainer = Engine(lambda engine, batch: None)
event = Events.EPOCH_COMPLETED
max_epochs = 2
data = [0] * 10
scheduler.attach(trainer, event)
trainer.run(data, max_epochs=max_epochs)
scheduler.attach(trainer, event)

_test(scheduler_cls, scheduler_kwargs)


def test_torch_save_load():

lambda_state_parameter_scheduler = LambdaStateScheduler(
@@ -441,3 +418,69 @@ def __call__(self, event_index):
param_scheduler.attach(engine, Events.EPOCH_COMPLETED)

engine.run([0] * 8, max_epochs=10)


def test_param_scheduler_attach_exception():
trainer = Engine(lambda e, b: None)
param_name = "state_param"

setattr(trainer.state, param_name, None)

save_history = True
create_new = True

param_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name,
milestones_values=[(0, 0.0), (10, 0.999)],
save_history=save_history,
create_new=create_new,
)

with pytest.raises(
ValueError,
match=r"Attribute '" + re.escape(param_name) + "' already exists in the engine.state. "
r"This may be a conflict between multiple handlers. "
r"Please choose another name.",
):
param_scheduler.attach(trainer, Events.ITERATION_COMPLETED)


def test_param_scheduler_attach_warning():
trainer = Engine(lambda e, b: None)
param_name = "state_param"
save_history = True
create_new = False

param_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name,
milestones_values=[(0, 0.0), (10, 0.999)],
save_history=save_history,
create_new=create_new,
)

with pytest.warns(
UserWarning,
match=r"Attribute '" + re.escape(param_name) + "' is not defined in the engine.state. "
r"PiecewiseLinearStateScheduler will create it. Remove this warning by setting create_new=True.",
):
param_scheduler.attach(trainer, Events.ITERATION_COMPLETED)


def test_param_scheduler_with_ema_handler():

from ignite.handlers import EMAHandler

model = nn.Linear(2, 1)
trainer = Engine(lambda e, b: model(b))
data = torch.rand(100, 2)

param_name = "ema_decay"

ema_handler = EMAHandler(model)
ema_handler.attach(trainer, name=param_name, event=Events.ITERATION_COMPLETED)

ema_decay_scheduler = PiecewiseLinearStateScheduler(
param_name=param_name, milestones_values=[(0, 0.0), (10, 0.999)], save_history=True
)
ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)
trainer.run(data, max_epochs=20)
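
A note on this last test: ``EMAHandler.attach`` appears to create the ``ema_decay`` attribute on the engine state before the scheduler is attached, so attaching the scheduler with the default ``create_new=False`` simply takes over the existing attribute; neither the warning path nor the ``ValueError`` path is exercised here.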