-
-
Notifications
You must be signed in to change notification settings - Fork 694
Paramscheduler emahandler #2326
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
0ea2c56
de4cbd4
e1f9565
aae8f97
6046ab8
6194733
4a11472
f37866f
c42f016
930b5b7
9ff67b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -276,31 +276,6 @@ def _test(scheduler_cls, scheduler_kwargs): | |
| _test(scheduler_cls, scheduler_kwargs) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("scheduler_cls,scheduler_kwargs", [config3, config4, config5, config6]) | ||
| def test_state_param_asserts(scheduler_cls, scheduler_kwargs): | ||
| import re | ||
|
|
||
| def _test(scheduler_cls, scheduler_kwargs): | ||
| scheduler = scheduler_cls(**scheduler_kwargs) | ||
| with pytest.raises( | ||
| ValueError, | ||
| match=r"Attribute: '" | ||
| + re.escape(scheduler_kwargs["param_name"]) | ||
| + "' is already defined in the Engine.state.This may be a conflict between multiple StateParameterScheduler" | ||
| + " handlers.Please choose another name.", | ||
| ): | ||
|
|
||
| trainer = Engine(lambda engine, batch: None) | ||
| event = Events.EPOCH_COMPLETED | ||
| max_epochs = 2 | ||
| data = [0] * 10 | ||
| scheduler.attach(trainer, event) | ||
| trainer.run(data, max_epochs=max_epochs) | ||
| scheduler.attach(trainer, event) | ||
|
|
||
| _test(scheduler_cls, scheduler_kwargs) | ||
|
|
||
|
|
||
| def test_torch_save_load(): | ||
|
|
||
| lambda_state_parameter_scheduler = LambdaStateScheduler( | ||
|
|
@@ -414,3 +389,58 @@ def __call__(self, event_index): | |
| param_scheduler.attach(engine, Events.EPOCH_COMPLETED) | ||
|
|
||
| engine.run([0] * 8, max_epochs=10) | ||
|
|
||
|
|
||
| def test_param_scheduler_with_ema_handler_attach_exception(): | ||
| import torch.nn as nn | ||
|
|
||
| from ignite.handlers import EMAHandler | ||
|
|
||
| data = torch.rand(100, 2) | ||
| model = nn.Linear(2, 1) | ||
| trainer = Engine(lambda e, b: model(b)) | ||
| param_name = "ema_decay" | ||
| save_history = True | ||
| create_new = True | ||
|
|
||
| ema_handler = EMAHandler(model) | ||
| ema_handler.attach(trainer, name=param_name, event=Events.ITERATION_COMPLETED) | ||
| ema_decay_scheduler = PiecewiseLinearStateScheduler( | ||
| param_name=param_name, | ||
| milestones_values=[(0, 0.0), (10 * len(data), 0.999)], | ||
| save_history=save_history, | ||
| create_new=create_new, | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you do not need to add EMAHandler to check
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a working test with EMAHandler and PiecewiseLinearStateScheduler like in the issue ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes sure , thanks for your comments @vfdev-5 ! |
||
|
|
||
| with pytest.raises( | ||
| ValueError, | ||
| match=r"Attribute: `ema_decay` already exists in the engine.state. " | ||
| r"This may be a conflict between multiple handlers. " | ||
| r"Please choose another name.", | ||
| ): | ||
| ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED) | ||
|
|
||
|
|
||
| def test_param_scheduler_attach_warning(): | ||
| import torch.nn as nn | ||
|
|
||
| data = torch.rand(100, 2) | ||
| model = nn.Linear(2, 1) | ||
| trainer = Engine(lambda e, b: model(b)) | ||
| param_name = "ema_decay" | ||
| save_history = True | ||
| create_new = False | ||
|
|
||
| ema_decay_scheduler = PiecewiseLinearStateScheduler( | ||
| param_name=param_name, | ||
| milestones_values=[(0, 0.0), (10 * len(data), 0.999)], | ||
| save_history=save_history, | ||
| create_new=create_new, | ||
| ) | ||
|
|
||
| with pytest.warns( | ||
| UserWarning, | ||
| match=r"Attribute: `ema_decay` is not defined in the engine.state. " | ||
| r"PiecewiseLinearStateScheduler will create it. Remove this warning by setting `create_new=True`.", | ||
| ): | ||
| ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED) | ||
Uh oh!
There was an error while loading. Please reload this page.