-
Notifications
You must be signed in to change notification settings - Fork 32.7k
Fix parametrization-based weight norm #33275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b139abe
42a4568
51edcb5
9256965
46e9008
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix): | |
| def _load_state_dict_into_meta_model( | ||
| model, | ||
| state_dict, | ||
| loaded_state_dict_keys, # left for now but could be removed, see below | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Here, I removed:

    # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
    if param_name not in loaded_state_dict_keys or param_name not in expected_keys:

I might have overlooked some downside effects, especially with quantization and/or training frameworks. WDYT @ArthurZucker and @LysandreJik? Who should I tag for more info? Also happy to change back to the original behaviour.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

If it doesn't break any tests, let's remove it and keep an eye out for eventual breakage.
||
| start_prefix, | ||
| expected_keys, | ||
| device_map=None, | ||
|
|
@@ -847,8 +846,6 @@ def _load_state_dict_into_meta_model( | |
| # - deepspeed zero 3 support | ||
| # - need to copy metadata if any - see _load_state_dict_into_model | ||
| # - handling error_msgs - mimicking the error handling in module._load_from_state_dict() | ||
| # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case | ||
| # they won't get loaded. | ||
|
|
||
| error_msgs = [] | ||
|
|
||
|
|
@@ -868,6 +865,18 @@ def _load_state_dict_into_meta_model( | |
| # We add only the first key as an example | ||
| new_key = key.replace("beta", "bias") | ||
| renamed_beta[key] = new_key if not renamed_beta else renamed_beta | ||
|
|
||
| # To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weigth norm, if necessary. | ||
| if hasattr(nn.utils.parametrizations, "weight_norm"): | ||
| if "weight_g" in key: | ||
| new_key = key.replace("weight_g", "parametrizations.weight.original0") | ||
| if "weight_v" in key: | ||
| new_key = key.replace("weight_v", "parametrizations.weight.original1") | ||
| else: | ||
| if "parametrizations.weight.original0" in key: | ||
| new_key = key.replace("parametrizations.weight.original0", "weight_g") | ||
| if "parametrizations.weight.original1" in key: | ||
| new_key = key.replace("parametrizations.weight.original1", "weight_v") | ||
| if new_key: | ||
| old_keys.append(key) | ||
| new_keys.append(new_key) | ||
|
|
@@ -884,8 +893,7 @@ def _load_state_dict_into_meta_model( | |
| is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") | ||
|
|
||
| for param_name, param in state_dict.items(): | ||
| # First part of the test is always true as load_state_dict_keys always contains state_dict keys. | ||
| if param_name not in loaded_state_dict_keys or param_name not in expected_keys: | ||
| if param_name not in expected_keys: | ||
| continue | ||
|
|
||
| if param_name.startswith(start_prefix): | ||
|
|
@@ -4128,6 +4136,18 @@ def _fix_key(key): | |
| return key.replace("beta", "bias") | ||
| if "gamma" in key: | ||
| return key.replace("gamma", "weight") | ||
|
|
||
| # to avoid logging parametrized weight norm renaming | ||
| if hasattr(nn.utils.parametrizations, "weight_norm"): | ||
| if "weight_g" in key: | ||
| return key.replace("weight_g", "parametrizations.weight.original0") | ||
| if "weight_v" in key: | ||
| return key.replace("weight_v", "parametrizations.weight.original1") | ||
| else: | ||
| if "parametrizations.weight.original0" in key: | ||
| return key.replace("parametrizations.weight.original0", "weight_g") | ||
| if "parametrizations.weight.original1" in key: | ||
| return key.replace("parametrizations.weight.original1", "weight_v") | ||
| return key | ||
|
|
||
| original_loaded_keys = loaded_keys | ||
|
|
@@ -4372,7 +4392,6 @@ def _find_mismatched_keys( | |
| error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( | ||
| model_to_load, | ||
| state_dict, | ||
| loaded_keys, | ||
| start_prefix, | ||
| expected_keys, | ||
| device_map=device_map, | ||
|
|
@@ -4449,7 +4468,6 @@ def _find_mismatched_keys( | |
| new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( | ||
| model_to_load, | ||
| state_dict, | ||
| loaded_keys, | ||
| start_prefix, | ||
| expected_keys, | ||
| device_map=device_map, | ||
|
|
@@ -4605,7 +4623,6 @@ def _load_pretrained_model_low_mem( | |
| error_msgs = _load_state_dict_into_meta_model( | ||
| model, | ||
| state_dict, | ||
| loaded_state_dict_keys, | ||
| start_prefix, | ||
| expected_keys=expected_keys, | ||
| hf_quantizer=hf_quantizer, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As explained here, the issue doesn't appear when doing regular loading of the state dict, but only when doing metaloading!