Fix parametrization-based weight norm #33275
Conversation
…oad_state_dict with classic loading
def _load_state_dict_into_meta_model(
    model,
    state_dict,
    loaded_state_dict_keys,  # left for now but could be removed, see below
Here, I removed loaded_state_dict_keys from _load_state_dict_into_meta_model, because according to the following snippet, it was not actually used before:
# First part of the test is always true as loaded_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:

I might have overlooked some side effects, especially with quantization and/or training frameworks. WDYT @ArthurZucker and @LysandreJik? Who should I tag for more info?
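To illustrate why the first clause is redundant, here is a minimal, hypothetical sketch (not the actual transformers code), assuming `loaded_state_dict_keys` is built from the same `state_dict` being iterated over:

```python
# Hypothetical sketch: when loaded_state_dict_keys is derived from state_dict,
# the first clause of the check can never be True, so the whole condition
# reduces to `param_name not in expected_keys`.
def keys_to_skip(state_dict, expected_keys):
    loaded_state_dict_keys = set(state_dict)  # superset of every iterated key
    skipped = []
    for param_name in state_dict:
        # `param_name not in loaded_state_dict_keys` is always False here
        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
            skipped.append(param_name)
    return skipped
```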
Also happy to change back to the original behaviour
If it doesn't break any tests, let's remove it and keep an eye out for eventual breakage
@@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
def _load_state_dict_into_meta_model(
As explained here, the issue doesn't appear when doing regular loading of the state dict, but only when doing meta loading!
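A torch-free sketch of why such a mismatch can surface only on the meta path (hypothetical helper and names; this assumes meta loading assigns each tensor by exact parameter name, with no renaming hooks in between):

```python
# Hypothetical sketch: meta loading matches checkpoint entries to parameters by
# exact name, so old-style weight-norm keys (weight_g/weight_v) leave the
# parametrized weights unfilled instead of being remapped on the fly.
def meta_load(model_param_names, checkpoint):
    loaded = {name: checkpoint[name] for name in model_param_names if name in checkpoint}
    missing = [name for name in model_param_names if name not in checkpoint]
    return loaded, missing
```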
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LysandreJik
left a comment
This looks good to me in practice for the affected models; @ArthurZucker, if you can give it a second look just to confirm or refute
amyeroberts
left a comment
Thanks for fixing @ylacombe!
I'm a bit squeamish about adding remapping in the loading functions, as it caused issues for "gamma" and "beta", but this seems pretty well controlled and only likely to hit some weights very rarely.
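For context, the kind of key remapping at stake can be sketched like this (a hedged illustration, not the PR's actual code; it assumes PyTorch's parametrization naming, where `weight_g`/`weight_v` become `parametrizations.weight.original0`/`original1`):

```python
# Illustrative remap of old-style weight-norm checkpoint keys to the
# parametrization-based names; the real transformers logic may differ.
def remap_weight_norm_keys(state_dict):
    remapped = {}
    for key, value in state_dict.items():
        if key.endswith("weight_g"):
            key = key[: -len("weight_g")] + "parametrizations.weight.original0"
        elif key.endswith("weight_v"):
            key = key[: -len("weight_v")] + "parametrizations.weight.original1"
        remapped[key] = value
    return remapped
```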
* refactor weight_norm + propose unified solution to reconcile meta load_state_dict with classic loading
* make style
* fix sew
* fix sew and sew_d tests
What does this PR do?
Supersedes #32194 and fixes #31970 and #26796!
While #32194 was already great work, it wasn't compatible with versions of Torch that only had nn.utils.weight_norm.

I left a review to explain some choices and to highlight where I'm not quite sure of my solution!
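One torch-free way to sketch that compatibility concern (a hypothetical helper, not necessarily how the PR handles it) is a fallback that prefers the parametrization-based API only when the installed Torch exposes it:

```python
from types import SimpleNamespace

# Hypothetical shim: prefer a parametrizations.weight_norm attribute when the
# given utils namespace has one, otherwise fall back to the classic weight_norm.
def pick_weight_norm(utils_module):
    parametrizations = getattr(utils_module, "parametrizations", None)
    if parametrizations is not None and hasattr(parametrizations, "weight_norm"):
        return parametrizations.weight_norm
    return utils_module.weight_norm

# Torch-free stand-ins for a new and an old torch.nn.utils namespace:
new_utils = SimpleNamespace(
    weight_norm="classic",
    parametrizations=SimpleNamespace(weight_norm="parametrized"),
)
old_utils = SimpleNamespace(weight_norm="classic")
```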
cc @LysandreJik and @ArthurZucker !