Merged

Changes from 14 commits
50 changes: 38 additions & 12 deletions ignite/handlers/lr_finder.py
@@ -14,7 +14,7 @@
import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint
-from ignite.handlers.param_scheduler import LRScheduler, PiecewiseLinear
+from ignite.handlers.param_scheduler import LRScheduler, ParamGroupScheduler, PiecewiseLinear


class FastaiLRFinder:
@@ -87,8 +87,8 @@ def _run(
        optimizer: Optimizer,
        output_transform: Callable,
        num_iter: int,
-        start_lr: float,
-        end_lr: float,
+        start_lr: Union[None, float, List[float]],
+        end_lr: Union[float, List[float]],
        step_mode: str,
        smooth_f: float,
        diverge_th: float,
@@ -119,14 +119,41 @@ def _run(

        self.logger.debug(f"Running LR finder for {num_iter} iterations")
        if start_lr is None:
-            start_lr = optimizer.param_groups[0]["lr"]
+            start_lr_list = [pg["lr"] for pg in optimizer.param_groups]
+        elif isinstance(start_lr, float):
+            start_lr_list = [start_lr] * len(optimizer.param_groups)
+        elif isinstance(start_lr, list):
+            if len(start_lr) != len(optimizer.param_groups):
+                start_error_message = "Number of values of start_lr should be equal to optimizer values."
+                value_message = f"start_lr values:{len(start_lr)} optimizer values: {len(optimizer.param_groups)}"
+                raise ValueError(f"{start_error_message} {value_message}")
+            start_lr_list = start_lr
+        else:
+            raise TypeError(f"start_lr should be a float or list of floats, but given {type(start_lr)}")
+        if isinstance(end_lr, float):
+            end_lr_list = [end_lr] * len(optimizer.param_groups)
+        elif isinstance(end_lr, list):
+            if len(end_lr) != len(optimizer.param_groups):
+                end_error_message = "Number of values of end_lr should be equal to optimizer values."
+                value_message = f"end_lr values:{len(end_lr)} optimizer values: {len(optimizer.param_groups)}"
+                raise ValueError(f"{end_error_message} {value_message}")
+            end_lr_list = end_lr
+        else:
+            raise TypeError(f"end_lr should be a float or list of floats, but given {type(end_lr)}")
        # Initialize the proper learning rate policy
        if step_mode.lower() == "exp":
-            start_lr_list = [start_lr] * len(optimizer.param_groups)
-            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr_list, end_lr, num_iter))
+            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr_list, end_lr_list, num_iter))
        else:
-            self._lr_schedule = PiecewiseLinear(
-                optimizer, param_name="lr", milestones_values=[(0, start_lr), (num_iter, end_lr)]
+            self._lr_schedule = ParamGroupScheduler(
+                [
+                    PiecewiseLinear(
+                        optimizer,
+                        param_name="lr",
+                        milestones_values=[(0, start_lr_list[i]), (num_iter, end_lr_list[i])],
+                        param_group_index=i,
+                    )
+                    for i in range(len(optimizer.param_groups))
+                ]
            )
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)
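
Editor's note: the block above normalizes start_lr / end_lr into one value per optimizer parameter group before building the schedule. A minimal sketch of that broadcasting rule, written as a standalone helper purely for illustration (the _broadcast function below is not part of the PR):

from typing import List, Union

def _broadcast(value: Union[None, float, List[float]], current_lrs: List[float]) -> List[float]:
    # None -> keep the optimizer's current per-group lrs
    if value is None:
        return list(current_lrs)
    # a single float applies to every param group
    if isinstance(value, float):
        return [value] * len(current_lrs)
    # a list must provide exactly one value per param group
    if isinstance(value, list):
        if len(value) != len(current_lrs):
            raise ValueError("expected one lr value per optimizer param group")
        return value
    raise TypeError(f"expected a float or a list of floats, got {type(value)}")

# With three param groups currently at [0.1, 0.01, 0.001]:
#   _broadcast(None, [0.1, 0.01, 0.001])             -> [0.1, 0.01, 0.001]
#   _broadcast(0.05, [0.1, 0.01, 0.001])             -> [0.05, 0.05, 0.05]
#   _broadcast([0.1, 0.1, 0.01], [0.1, 0.01, 0.001]) -> [0.1, 0.1, 0.01]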
@@ -251,7 +278,6 @@ def plot(
            )
        if not self._history:
            raise RuntimeError("learning rate finder didn't run yet so results can't be plotted")
-
        if skip_start < 0:
            raise ValueError("skip_start cannot be negative")
        if skip_end < 0:
@@ -367,8 +393,8 @@ def attach(
        to_save: Mapping,
        output_transform: Callable = lambda output: output,
        num_iter: Optional[int] = None,
-        start_lr: Optional[float] = None,
-        end_lr: float = 10.0,
+        start_lr: Optional[Union[float, List[float]]] = None,
+        end_lr: Optional[Union[float, List[float]]] = 10.0,
        step_mode: str = "exp",
        smooth_f: float = 0.05,
        diverge_th: float = 5.0,
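
Editor's note: a minimal usage sketch of the extended attach signature with per-group start_lr / end_lr lists. The model, optimizer, trainer and data below are placeholders invented for this example; only the attach call itself mirrors the new signature.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import create_supervised_trainer
from ignite.handlers import FastaiLRFinder

# toy model with two parameter groups (backbone / head)
backbone, head = nn.Linear(4, 8), nn.Linear(8, 2)
model = nn.Sequential(backbone, head)
optimizer = torch.optim.SGD(
    [
        {"params": backbone.parameters(), "lr": 1e-4},
        {"params": head.parameters(), "lr": 1e-3},
    ]
)
trainer = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss())
dataloader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,))), batch_size=4)

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}

# one start/end lr per parameter group, matching the list support added in this PR
with lr_finder.attach(trainer, to_save, start_lr=[1e-5, 1e-4], end_lr=[1.0, 1.0], step_mode="exp") as finder_trainer:
    finder_trainer.run(dataloader)

print(lr_finder.get_results()["lr"][:3])  # learning rates recorded over the first iterations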
@@ -498,4 +524,4 @@ def __init__(self, optimizer: Optimizer, start_lr: List[float], end_lr: float, n
    def get_lr(self) -> List[float]:  # type: ignore
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.num_iter
-        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
+        return [base_lr * (end_lr / base_lr) ** r for end_lr, base_lr in zip(self.end_lr, self.base_lrs)]
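
Editor's note: a quick numeric check of the per-group exponential interpolation used above, lr_i = base_lr_i * (end_lr_i / base_lr_i) ** (curr_iter / num_iter), with made-up values:

base_lrs = [0.01, 0.001]   # per-group start lrs
end_lrs = [1.0, 0.1]       # per-group end lrs
num_iter = 100

for curr_iter in (0, 50, 100):
    r = curr_iter / num_iter
    print([base * (end / base) ** r for end, base in zip(end_lrs, base_lrs)])
# r = 0.0 -> [0.01, 0.001]  (the start values)
# r = 0.5 -> [0.1, 0.01]    (the geometric midpoints, up to float rounding)
# r = 1.0 -> [1.0, 0.1]     (the end values)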
12 changes: 12 additions & 0 deletions tests/ignite/handlers/test_lr_finder.py
@@ -275,6 +275,18 @@ def test_lr_policy(lr_finder, to_save, dummy_engine, dataloader):
    assert all([lr[i - 1] < lr[i] for i in range(1, len(lr))])


+@pytest.mark.parametrize("step_mode", ["exp", "linear"])
+def test_multi_opt(lr_finder, dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, dataloader, step_mode):
+    start_lr = [0.1, 0.1, 0.01]
+    end_lr = [1.0, 1.0, 1.0]
+    dummy_engine = dummy_engine_mulitple_param_groups
@sadra-barikbin (Collaborator) commented on Oct 15, 2022:
I forgot to tell you that calling ./tests/run_code_style.sh fmt applies formatting rules automatically and there was no need to redefine variables with short names. Applying that command would result in:

def test_multiple_optimizers(
    lr_finder, dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, dataloader, step_mode
):
    start_lr = [0.1, 0.1, 0.01]
    end_lr = [1.0, 1.0, 1.0]

    with lr_finder.attach(
        dummy_engine_mulitple_param_groups,
        to_save_mulitple_param_groups,
        start_lr=start_lr,
        end_lr=end_lr,
        step_mode=step_mode,
    ) as trainer:
        trainer.run(dataloader)

@sadra-barikbin (Collaborator) commented on Oct 15, 2022:
By the way, for more information on Python style guides you could read PEP 8.

+    to_save = to_save_mulitple_param_groups
+    with lr_finder.attach(dummy_engine, to_save, start_lr=start_lr, end_lr=end_lr, step_mode=step_mode) as trainer:
+        trainer.run(dataloader)
+    groups_lrs = lr_finder.get_results()["lr"]
+    assert all(all(group_lrs[i - 1] < group_lrs[i] for i in range(1, len(group_lrs))) for group_lrs in zip(*groups_lrs))


def assert_output_sizes(lr_finder, dummy_engine):
    iteration = dummy_engine.state.iteration
    lr_finder_results = lr_finder.get_results()
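
Editor's note: the dummy_engine_mulitple_param_groups and to_save_mulitple_param_groups fixtures are defined elsewhere in this test module and do not appear in the diff. Purely as an assumption about their shape (a hypothetical sketch, not the repository's actual fixtures), they could look roughly like this:

import pytest
import torch
from torch import nn

from ignite.engine import create_supervised_trainer


@pytest.fixture
def to_save_mulitple_param_groups():
    # hypothetical: three blocks -> three optimizer param groups, matching the lists used in the test
    blocks = [nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 1)]
    model = nn.Sequential(*blocks)
    optimizer = torch.optim.SGD([{"params": b.parameters(), "lr": 1e-3} for b in blocks])
    return {"model": model, "optimizer": optimizer}


@pytest.fixture
def dummy_engine_mulitple_param_groups(to_save_mulitple_param_groups):
    to_save = to_save_mulitple_param_groups
    return create_supervised_trainer(to_save["model"], to_save["optimizer"], nn.MSELoss())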