Fixed parameter scheduler bug with CosineAnnealingWarmRestarts (#2938)
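For context, below is a minimal sketch (not taken from this PR) of the usage pattern the new test exercises: wrapping a stock PyTorch CosineAnnealingWarmRestarts scheduler with ignite's create_lr_scheduler_with_warmup and driving it from an Engine. The parameter-free optimizer, the dummy train_step, and the hyperparameter values are illustrative assumptions, not code from the PR.

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.engine import Engine, Events
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup

# Hypothetical optimizer over a single dummy parameter (illustration only).
param = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.2)

# Wrap a cosine-with-warm-restarts schedule with a linear warmup phase.
cosine = CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=3)
scheduler = create_lr_scheduler_with_warmup(
    cosine, warmup_start_value=0.023, warmup_end_value=0.23, warmup_duration=50
)

def train_step(engine, batch):
    # Dummy step; a real step would compute a loss and call optimizer.step().
    return batch

trainer = Engine(train_step)
# Advance the learning-rate schedule once per iteration.
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run(range(10), max_epochs=2)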
Changes from 24 commits
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 import torch
-from torch.optim.lr_scheduler import ExponentialLR, StepLR
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR, StepLR

 from ignite.engine import Engine, Events
 from ignite.handlers.param_scheduler import (
@@ -1362,3 +1362,44 @@ def test_reduce_lr_on_plateau_scheduler_asserts():
     with pytest.raises(ValueError, match=r"Length of argument metric_values should be equal to num_events."):
         metric_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
         ReduceLROnPlateauScheduler.simulate_values(5, metric_values, 0.01)
+
+
+@pytest.mark.parametrize("warmup_end_value", [0.23, None])
+def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value):
+    lr = 0.2
+    steps = 100
+    warm_steps = 50
+    warm_start = 0.023
+
+    def get_optim():
+        t1 = torch.zeros([1], requires_grad=True)
+        return torch.optim.SGD([t1], lr=lr)
+
+    def get_cos_shed():
+        return CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=3, verbose=False)
+
+    optimizer = get_optim()
+    scheduler = get_cos_shed()
+    cosine_lrs = []
+    for i in range(steps):
+        cosine_lrs.append(optimizer.param_groups[0]["lr"])
+        scheduler.step()
+
+    optimizer = get_optim()
+    scheduler = create_lr_scheduler_with_warmup(
+        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
+    )
+
+    warm_lrs = []
+    for epoch in range(warm_steps + steps):
+        scheduler(None)
+        warm_lrs.append(optimizer.param_groups[0]["lr"])
+
+    if warmup_end_value is not None:
+        assert (
+            np.linspace(warm_start, warmup_end_value, warm_steps).round(3) == np.array(warm_lrs[:warm_steps]).round(3)
+        ).all()
+        assert warm_lrs[warm_steps:] == cosine_lrs
+    else:
+        assert (np.linspace(warm_start, lr, warm_steps).round(3) == np.array(warm_lrs[:warm_steps]).round(3)).all()
+        assert warm_lrs[warm_steps - 1 : -1] == cosine_lrs
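The two assertion branches encode the expected behavior of the warmup wrapper. With an explicit warmup_end_value, the recorded learning rates follow the linear ramp from warm_start to warmup_end_value for the first warm_steps events and then reproduce the standalone CosineAnnealingWarmRestarts values exactly. With warmup_end_value=None, the ramp instead ends at the optimizer's base lr, so the cosine values line up starting one event earlier, which is why the comparison uses the shifted slice warm_lrs[warm_steps - 1 : -1].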