Merged

23 commits
3484a59
Remove unnecessary code in BaseOutputHandler
sadra-barikbin Jan 22, 2022
2c8eed9
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 3, 2022
ccf2364
Add ReduceLROnPlateauScheduler
sadra-barikbin Feb 3, 2022
7f7dae6
Fix indentation issue
sadra-barikbin Feb 3, 2022
896e482
Fix another indentation issue
sadra-barikbin Feb 3, 2022
cbc8d04
Fix PEP8 related issues
sadra-barikbin Feb 3, 2022
47b0622
Fix other PEP8 related issues
sadra-barikbin Feb 3, 2022
91d058e
Fix hopefully the last PEP8 related issue
sadra-barikbin Feb 3, 2022
9fd7d61
Fix hopefully the last PEP8 related issue
sadra-barikbin Feb 3, 2022
b7dc921
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 3, 2022
e0644e3
Merge branch 'master' of https://github.com/sadra-barikbin/ignite
sadra-barikbin Feb 3, 2022
c95a2be
Remove ReduceLROnPlateau's specific params and add link to it
sadra-barikbin Feb 3, 2022
96554d0
Fix state_dict bug and add a test
sadra-barikbin Feb 5, 2022
145dabc
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 9, 2022
0aee28a
Update docs
sadra-barikbin Feb 10, 2022
307803c
Merge branch 'master' into master
vfdev-5 Feb 14, 2022
0129572
Merge branch 'pytorch:master' into master
sadra-barikbin Feb 14, 2022
a17a5b2
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 19, 2022
b3ea962
Add doctest and fix typo
sadra-barikbin Feb 20, 2022
e2e6831
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 20, 2022
b88c9e1
Merge branch 'master' of https://github.com/pytorch/ignite
sadra-barikbin Feb 20, 2022
8d0ae3c
Merge branch 'master' of https://github.com/sadra-barikbin/ignite
sadra-barikbin Feb 20, 2022
408b271
Merge branch 'master' into master
vfdev-5 Feb 20, 2022
2 changes: 1 addition & 1 deletion docs/source/defaults.rst
@@ -30,7 +30,7 @@

# create default trainer for doctests
# as handlers could be attached to the trainer,
# each test must defined his own trainer using `.. testsetup:`
# each test must define his own trainer using `.. testsetup:`

def get_default_trainer():

40 changes: 40 additions & 0 deletions docs/source/handlers.rst
@@ -53,6 +53,7 @@ Parameter scheduler
ParamGroupScheduler
ParamScheduler
PiecewiseLinear
ReduceLROnPlateauScheduler
create_lr_scheduler_with_warmup

State Parameter scheduler
@@ -379,3 +380,42 @@ Concatenate with torch schedulers


.. image:: ./_static/img/schedulers/concat_linear_exp_step_lr.png


Example with :class:`ignite.handlers.param_scheduler.ReduceLROnPlateauScheduler`
`````````````````````````````````````````````````````````````````````````````````````

.. code-block:: python

import matplotlib.pyplot as plt
import numpy as np
from ignite.handlers import ReduceLROnPlateauScheduler

metric_values = [0.7, 0.78, 0.81, 0.82, 0.82, 0.83, 0.80, 0.81, 0.84, 0.78]
num_events = 10
init_lr = 0.1

lr_values = np.array(ReduceLROnPlateauScheduler.simulate_values(
num_events, metric_values, init_lr,
factor=0.5, patience=1, mode='max', threshold=0.01, threshold_mode='abs'
)
)

plt.figure(figsize=(15, 5))
plt.suptitle("ReduceLROnPlateauScheduler")
plt.subplot(121)
plt.plot(lr_values[:, 1], label="learning rate")
plt.xticks(lr_values[:, 0])
plt.xlabel("events")
plt.ylabel("values")
plt.legend()

plt.subplot(122)
plt.plot(metric_values, label="metric")
plt.xticks(lr_values[:, 0])
plt.xlabel("events")
plt.ylabel("values")
plt.legend()


.. image:: ./_static/img/schedulers/reduce_lr_on_plateau_example.png
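
Beyond simulating values, the new scheduler is meant to be attached to a trainer/evaluator pair. The following is a minimal runnable sketch along the lines of the class docstring example and the unit test added in this PR; the no-op engines and the fake metric stream are illustrative stand-ins, not part of the diff.

import torch
from ignite.engine import Engine, Events
from ignite.handlers import ReduceLROnPlateauScheduler

# Dummy setup: a single parameter and no-op engines stand in for a real model.
param = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.1)

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)

# Fake metric stream standing in for a real validation metric.
metric_values = iter([0.7, 0.78, 0.81, 0.82, 0.82, 0.83, 0.80, 0.81, 0.84, 0.78])

@evaluator.on(Events.COMPLETED)
def set_metric(engine):
    engine.state.metrics["accuracy"] = next(metric_values)

scheduler = ReduceLROnPlateauScheduler(
    optimizer, metric_name="accuracy", mode="max",
    factor=0.5, patience=1, threshold=0.01, threshold_mode="abs",
    save_history=True, trainer=trainer,
)
# Registered after set_metric, so the metric is already updated
# when the scheduler steps.
evaluator.add_event_handler(Events.COMPLETED, scheduler)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run([0])

trainer.run([0], max_epochs=10)
print(trainer.state.param_history["lr"])  # one entry per scheduler step
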
4 changes: 3 additions & 1 deletion ignite/handlers/__init__.py
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional

from ignite.engine import Engine
from ignite.engine.events import Events
@@ -17,6 +17,7 @@
ParamGroupScheduler,
ParamScheduler,
PiecewiseLinear,
ReduceLROnPlateauScheduler,
)
from ignite.handlers.state_param_scheduler import (
ExpStateScheduler,
@@ -62,6 +63,7 @@
"ExpStateScheduler",
"StepStateScheduler",
"MultiStepStateScheduler",
"ReduceLROnPlateauScheduler",
]


188 changes: 187 additions & 1 deletion ignite/handlers/param_scheduler.py
@@ -2,14 +2,15 @@
import math
import numbers
import tempfile
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from copy import copy
from pathlib import Path
from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
from torch.optim.optimizer import Optimizer

from ignite.engine import Engine
@@ -1406,6 +1407,191 @@ def simulate_values(cls, num_events: int, schedulers: List[_LRScheduler], **kwar
return values


class ReduceLROnPlateauScheduler(ParamScheduler):
"""Reduce LR when a metric stops improving.
Wrapper of `torch.optim.lr_scheduler.ReduceLROnPlateau
<https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html>`_.

Args:
optimizer: Wrapped optimizer.
metric_name: metric whose improvement is monitored.
Must be computed by the engine this scheduler is attached to
(i.e. present in that engine's `state.metrics`).
trainer: Trainer engine whose `state.param_history`
is used to log LR history. Used only if `save_history`
is true. Default: None.
save_history: Whether to save history or not. If true,
history will be logged in `trainer`'s `state.param_history`.
Default: False.
param_group_index: index of `optimizer`'s parameter group to act on.
Default: None, meaning all of `optimizer`'s parameter groups are used.
**scheduler_kwargs: Keyword arguments to be passed to the wrapped
`ReduceLROnPlateau`.

Examples:

.. code-block:: python

# Metric 'metric-name' should surpass its best value by
# more than 1 unit within at most 2 epochs, otherwise LR
# gets multiplied by 0.5.

scheduler = ReduceLROnPlateauScheduler(
default_optimizer,
metric_name="metric-name", mode="max",
factor=0.5, patience=1, threshold_mode='abs',
threshold=1, trainer=trainer
)

metric = Accuracy()
default_evaluator.attach(metric, "accuracy")

default_evaluator.add_event_handler(Events.COMPLETED, scheduler)

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

default_trainer = get_default_trainer()

# Metric `loss` should decrease by more than
# a tenth of the best loss, in which case the
# best loss gets updated. If it fails to do so
# for more than 3 consecutive iterations
# (patience=3), LR is multiplied by 0.5.

scheduler = ReduceLROnPlateauScheduler(
default_optimizer, "loss",
save_history=True, mode="min",
factor=0.5, patience=3, threshold_mode='rel',
threshold=0.1, trainer=default_trainer
)

metric_values = iter([10, 5, 3, 4, 4, 4, 5, 1])
default_evaluator.state.metrics = {"loss": None}

@default_trainer.on(Events.ITERATION_COMPLETED)
def set_metric_val():
default_evaluator.state.metrics["loss"] = next(metric_values)

default_evaluator.add_event_handler(Events.COMPLETED, scheduler)

@default_trainer.on(Events.ITERATION_COMPLETED)
def trigger_eval():
default_evaluator.run([0.])

default_trainer.run([0.] * 8)

print(default_trainer.state.param_history["lr"])

.. testoutput::

[[0.1], [0.1], [0.1], [0.1], [0.1], [0.1], [0.05], [0.05]]

.. versionadded:: 0.4.8
"""

def __init__(
self,
optimizer: Optimizer,
metric_name: str,
trainer: Optional[Engine] = None,
save_history: bool = False,
param_group_index: Optional[int] = None,
**scheduler_kwargs: Any,
):
super(ReduceLROnPlateauScheduler, self).__init__(
optimizer, "lr", save_history=save_history, param_group_index=param_group_index
)
self.metric_name = metric_name
self.trainer = trainer
self.optimizer = optimizer

if "min_lr" in scheduler_kwargs and param_group_index is not None:
min_lr = scheduler_kwargs["min_lr"]
if not isinstance(min_lr, float):
raise TypeError(f"When param_group_index is given, min_lr should be a float, but given {type(min_lr)}")
_min_lr = min_lr
min_lr = [0] * len(optimizer.param_groups)
min_lr[param_group_index] = _min_lr
else:
min_lr = scheduler_kwargs.get("min_lr", 0)
_scheduler_kwargs = scheduler_kwargs.copy()
_scheduler_kwargs["min_lr"] = min_lr

if "verbose" in _scheduler_kwargs:
warnings.warn(
"Found verbose=True in provided scheduler_kwargs. "
"It would be set to False. Please use save_history instead."
)
_scheduler_kwargs["verbose"] = False

self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
self.scheduler._reduce_lr = self._reduce_lr # type: ignore[attr-defined]

self._state_attrs += ["metric_name", "scheduler"]

def __call__(self, engine: Engine, name: Optional[str] = None) -> None: # type: ignore[override]
if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
raise ValueError(
"Argument engine should have in its 'state', attribute 'metrics' "
f"which itself has the metric {self.metric_name}."
)
self.scheduler.step(engine.state.metrics[self.metric_name])
super().__call__(self.trainer, name)

def get_param(self) -> Union[float, List[float]]:
lrs = [pg["lr"] for pg in self.optimizer_param_groups]
return lrs[0] if len(lrs) == 1 else lrs

def _reduce_lr(self, epoch: int) -> None:
for i, param_group in enumerate(self.optimizer_param_groups):
old_lr = float(param_group["lr"])
new_lr = max(old_lr * self.scheduler.factor, self.scheduler.min_lrs[i]) # type: ignore[attr-defined]
if old_lr - new_lr > self.scheduler.eps: # type: ignore[attr-defined]
param_group["lr"] = new_lr

@classmethod
def simulate_values( # type: ignore[override]
cls, num_events: int, metric_values: List[float], init_lr: float, **scheduler_kwargs: Any
) -> List[List[Union[int, float]]]:
"""Method to simulate scheduled values during num_events events.

Args:
num_events: number of events during the simulation.
metric_values: values to change LR based on.
init_lr: initial LR to start with.
scheduler_kwargs: kwargs passed to construct an instance of
:class:`ignite.handlers.param_scheduler.ReduceLROnPlateauScheduler`.

Returns:
list of [event_index, value] pairs

"""
if len(metric_values) != num_events:
raise ValueError(
"Length of argument metric_values should be equal to num_events. "
f"{len(metric_values)} != {num_events}"
)

keys_to_remove = ["optimizer", "metric_name", "save_history"]
for key in keys_to_remove:
if key in scheduler_kwargs:
del scheduler_kwargs[key]
values = []
scheduler = cls(
optimizer=_get_fake_optimizer(torch.optim.SGD, lr=init_lr),
metric_name="metric",
save_history=False,
**scheduler_kwargs,
)
engine = Engine(lambda _, __: None)
for i in range(num_events):
engine.state.metrics["metric"] = metric_values[i]
scheduler(engine=engine)
values.append([i, scheduler.optimizer_param_groups[0][scheduler.param_name]])
return values


def _get_fake_optimizer(
optimizer_cls: Optional[Union[Type[Optimizer], Type[torch.optim.SGD]]] = None, **kwargs: Any
) -> Union[Optimizer, torch.optim.SGD]:
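
One constructor detail worth a note: when `param_group_index` is given together with a scalar `min_lr`, the wrapper expands it into a per-group list so that the LR floor only applies to the selected group and the others stay untouched. A small sketch of this behaviour; the two-group optimizer and the values are illustrative, not from the diff.

import torch
from ignite.handlers import ReduceLROnPlateauScheduler

w1 = torch.zeros([1], requires_grad=True)
w2 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([{"params": [w1]}, {"params": [w2]}], lr=1.0)

# Only parameter group 0 is scheduled; its LR is floored at 1e-4,
# while group 1 keeps lr=1.0.
scheduler = ReduceLROnPlateauScheduler(
    optimizer, metric_name="loss", mode="min",
    factor=0.1, patience=0, min_lr=1e-4, param_group_index=0,
)
print(scheduler.scheduler.min_lrs)  # [0.0001, 0]
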
86 changes: 86 additions & 0 deletions tests/ignite/handlers/test_param_scheduler.py
@@ -15,6 +15,7 @@
ParamGroupScheduler,
ParamScheduler,
PiecewiseLinear,
ReduceLROnPlateauScheduler,
)
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer

@@ -1302,3 +1303,88 @@ def save_lr(engine):
assert lrs == list(
map(pytest.approx, [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])
)


def test_reduce_lr_on_plateau_scheduler():
tensor1 = torch.zeros([1], requires_grad=True)
tensor2 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([{"params": [tensor1]}, {"params": [tensor2]}], lr=1)

data = [0] * 8
max_epochs = 10

trainer = Engine(lambda engine, batch: None)

@trainer.on(Events.EPOCH_COMPLETED)
def evaluate():
evaluator.run(data)

scheduler = ReduceLROnPlateauScheduler(
optimizer,
metric_name="acc",
mode="max",
factor=0.5,
patience=1,
threshold_mode="abs",
threshold=1.99,
min_lr=1e-7,
save_history=True,
trainer=trainer,
param_group_index=0,
)
evaluator = Engine(lambda engine, batch: None)
evaluator.state.metrics = {"acc": 0.0}
generate_acc = iter([3, 7, 7, 9, 10, 11, 8, 8, 4, 7])

@evaluator.on(Events.COMPLETED)
def set_acc():
evaluator.state.metrics["acc"] = next(generate_acc)

evaluator.add_event_handler(Events.COMPLETED, scheduler)

trainer.run(data, max_epochs=max_epochs)

lrs = [param[0] for param in trainer.state.param_history["lr"]]
assert lrs == list(
map(
pytest.approx,
[1, 1, 1, 1, 1, 1, 1, 0.5, 0.5, 0.25],
)
)
assert optimizer.param_groups[1]["lr"] == 1

values = ReduceLROnPlateauScheduler.simulate_values(
5, [10, 9, 9, 9, 8.1], 1.0, save_history=True, factor=0.5, patience=2, threshold=0.1
)
values = np.array(values)[:, 1].tolist()
assert values == list(
map(
pytest.approx,
[1.0, 1.0, 1.0, 0.5, 0.5],
)
)


def test_reduce_lr_on_plateau_scheduler_asserts():
tensor1 = torch.zeros([1], requires_grad=True)
tensor2 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([{"params": [tensor1]}, {"params": [tensor2]}], lr=1)

with pytest.raises(TypeError, match=r"When param_group_index is given, min_lr should be a float, but given"):
ReduceLROnPlateauScheduler(
optimizer,
metric_name="acc",
min_lr=[1e-7, 1e-8],
param_group_index=0,
)

with pytest.raises(
ValueError, match=r"Argument engine should have in its 'state', attribute 'metrics' which itself has the metric"
):
scheduler = ReduceLROnPlateauScheduler(optimizer, metric_name="acc")
evaluator = Engine(lambda engine, batch: None)
scheduler(evaluator)

with pytest.raises(ValueError, match=r"Length of argument metric_values should be equal to num_events."):
metric_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
ReduceLROnPlateauScheduler.simulate_values(5, metric_values, 0.01)
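
The `simulate_values` call at the end of `test_reduce_lr_on_plateau_scheduler` above also doubles as a worked example of the default (relative) threshold semantics; stripped of the test scaffolding it reads as follows, with the expected output taken from the test assertion.

import numpy as np
from ignite.handlers import ReduceLROnPlateauScheduler

# With the default mode='min' and relative threshold=0.1, events 2-4 do not
# improve on best=10 by more than 10%; once more than patience=2 of them
# accumulate, the LR is halved at event 4.
values = ReduceLROnPlateauScheduler.simulate_values(
    5, [10, 9, 9, 9, 8.1], 1.0, factor=0.5, patience=2, threshold=0.1
)
print(np.array(values)[:, 1].tolist())  # [1.0, 1.0, 1.0, 0.5, 0.5]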