Merged (changes from 8 commits)
40 changes: 31 additions & 9 deletions ignite/handlers/lr_finder.py
@@ -87,8 +87,8 @@ def _run(
optimizer: Optimizer,
output_transform: Callable,
num_iter: int,
start_lr: float,
end_lr: float,
start_lr: Union[float, List[float]],
end_lr: Union[float, List[float]],
step_mode: str,
smooth_f: float,
diverge_th: float,
@@ -118,13 +118,35 @@ def _run(
)

self.logger.debug(f"Running LR finder for {num_iter} iterations")
if start_lr is None:
start_lr = optimizer.param_groups[0]["lr"]
# Initialize the proper learning rate policy
if step_mode.lower() == "exp":
if start_lr is None:
start_lr_list = [pg["lr"] for pg in optimizer.param_groups]
elif isinstance(start_lr, float):
start_lr_list = [start_lr] * len(optimizer.param_groups)
self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr_list, end_lr, num_iter))
elif isinstance(start_lr, list):
if len(start_lr) != len(optimizer.param_groups):
raise ValueError(
f"Number of values of start_lr should be equal to optimizer values. start_lr values:{len(start_lr)} optimizer values: {len(optimizer.param_groups)}"
)
start_lr_list = start_lr
else:
raise TypeError(f"start_lr should a float or list of floats, but given {type(start_lr)}")
if isinstance(end_lr, float):
end_lr_list = [end_lr] * len(optimizer.param_groups)
elif isinstance(end_lr, list):
if len(end_lr) != len(optimizer.param_groups):
raise ValueError(
f"Values of end_lr should be equal to optimizer values. end_lr values:{len(end_lr)} optimizer values: {len(optimizer.param_groups)}"
)
end_lr_list = end_lr
else:
raise TypeError(f"end_lr should a float or list of floats, but given {type(end_lr)}")

if step_mode.lower() == "exp":
self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, start_lr_list, end_lr_list, num_iter))
else:
if isinstance(start_lr, list) or isinstance(end_lr, list):
raise NotImplementedError()
self._lr_schedule = PiecewiseLinear(
optimizer, param_name="lr", milestones_values=[(0, start_lr), (num_iter, end_lr)]
)
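The branching above normalizes start_lr and end_lr to one learning rate per optimizer param group: None falls back to each group's current lr, a single float is broadcast to every group, and a list must provide exactly one value per group. A minimal standalone sketch of that normalization, assuming a plain torch Optimizer; the helper name to_per_group_lrs is illustrative and not part of ignite's API:

from typing import List, Union

from torch.optim import Optimizer


def to_per_group_lrs(lr: Union[None, float, List[float]], optimizer: Optimizer, name: str) -> List[float]:
    # None -> reuse whatever lr each param group currently has
    if lr is None:
        return [pg["lr"] for pg in optimizer.param_groups]
    n_groups = len(optimizer.param_groups)
    # a single float is broadcast to every param group
    if isinstance(lr, float):
        return [lr] * n_groups
    # a list must provide exactly one value per param group
    if isinstance(lr, list):
        if len(lr) != n_groups:
            raise ValueError(f"{name} has {len(lr)} values but the optimizer has {n_groups} param groups")
        return lr
    raise TypeError(f"{name} should be a float or list of floats, but given {type(lr)}")

With both hyper-parameters reduced to equal-length lists, the exponential scheduler can then be built once for all groups, which is what the diff does inline.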
@@ -367,8 +389,8 @@ def attach(
to_save: Mapping,
output_transform: Callable = lambda output: output,
num_iter: Optional[int] = None,
start_lr: Optional[float] = None,
end_lr: float = 10.0,
start_lr: Optional[Union[float, List[float]]] = None,
end_lr: Union[None, float, List[float]] = 10.0,
step_mode: str = "exp",
smooth_f: float = 0.05,
diverge_th: float = 5.0,
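With the widened attach() signature above, start_lr and end_lr can be given per param group. A hedged usage sketch following ignite's documented FastaiLRFinder workflow; the toy model, optimizer, and data below are made up for illustration:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import create_supervised_trainer
from ignite.handlers import FastaiLRFinder

model = nn.Linear(10, 1)
# two param groups so that each one can get its own lr range
optimizer = torch.optim.SGD([{"params": model.weight, "lr": 1e-4}, {"params": model.bias, "lr": 1e-5}])
trainer = create_supervised_trainer(model, optimizer, nn.MSELoss())
data = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=8)

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}

# one start/end lr per param group; lists are only supported with step_mode="exp"
with lr_finder.attach(trainer, to_save, start_lr=[1e-4, 1e-5], end_lr=[1.0, 0.1]) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(data)

print(lr_finder.get_results()["lr"][:5])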
@@ -498,4 +520,4 @@ def __init__(self, optimizer: Optimizer, start_lr: List[float], end_lr: float, n
def get_lr(self) -> List[float]: # type: ignore
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
return [base_lr * (end_lr / base_lr) ** r for end_lr, base_lr in zip(self.end_lr, self.base_lrs)]
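The updated get_lr above moves each param group along its own geometric ramp, lr_i(t) = start_lr_i * (end_lr_i / start_lr_i) ** (t / num_iter), instead of sharing a single end_lr. A tiny worked sketch with made-up values (results shown up to float rounding):

# Worked example of the per-group exponential schedule (illustrative values only).
start_lrs = [0.01, 0.001]
end_lrs = [1.0, 0.1]
num_iter = 100

for t in (0, 50, 100):
    lrs = [s * (e / s) ** (t / num_iter) for s, e in zip(start_lrs, end_lrs)]
    print(t, lrs)
# 0   -> [0.01, 0.001]   each group starts at its own start_lr
# 50  -> [0.1, 0.01]     geometric midpoint of each range
# 100 -> [1.0, 0.1]      each group ends at its own end_lr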
6 changes: 6 additions & 0 deletions tests/ignite/handlers/test_lr_finder.py
@@ -274,6 +274,12 @@ def test_lr_policy(lr_finder, to_save, dummy_engine, dataloader):
lr = lr_finder.get_results()["lr"]
assert all([lr[i - 1] < lr[i] for i in range(1, len(lr))])

def test_multi_group_optimizer(lr_finder, dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, dataloader):
with lr_finder.attach(dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, start_lr=[0.1, 0.1, 0.01], end_lr=[1.0, 1.0, 0.1]) as trainer_with_lr_finder:
trainer_with_lr_finder.run(dataloader)

lr = lr_finder.get_results()["lr"]
assert all([lr[i - 1] < lr[i] for i in range(1, len(lr))])

def assert_output_sizes(lr_finder, dummy_engine):
iteration = dummy_engine.state.iteration