warning about weight_g/weight_v missing on WeightNorm on PyTorch #32194

kamilakesbi wants to merge 4 commits into huggingface:main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
sanchit-gandhi left a comment
I like the elegance of this fix @kamilakesbi! My only concern is that it assumes that all models are using the new parametrizations format, e.g. the one used in Wav2Vec2:
transformers/src/transformers/models/wav2vec2/modeling_wav2vec2.py
Lines 377 to 379 in 9d6c064
Doing a quick search of the codebase, this is not the case for all models, e.g. for EnCodec:
=> have we checked that the weight norm params are loaded without warnings when we use the legacy implementation? Ideally, we would only "fix" these keys if we know we're loading a legacy weight norm param into a parametrized one.
Force-pushed from 642251d to 4a16f9b
@sanchit-gandhi there are indeed a few models for which the new parametrizations format is missing:

I've updated the corresponding modeling files. The problem may arise with future model integrations! I'm thinking of the recent DAC, for example. We should probably add a warning suggesting that users implement the new parametrization format with weight_norm when adding a new model!
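For new model integrations, the version-guarded pattern already used in `modeling_wav2vec2.py` avoids the problem: prefer the new parametrizations API when the installed PyTorch has it, and fall back to the legacy one otherwise. A rough sketch of that pattern (the conv shapes here are illustrative, not from any model):

```python
import torch.nn as nn

# Prefer the new parametrizations API when available (PyTorch 2.1+),
# fall back to the legacy nn.utils.weight_norm otherwise.
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils, "parametrizations") and hasattr(nn.utils.parametrizations, "weight_norm"):
    weight_norm = nn.utils.parametrizations.weight_norm

conv = weight_norm(nn.Conv1d(8, 8, kernel_size=3), name="weight", dim=2)
```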
cc @ylacombe
Superseded by #33275
This PR addresses issue #26796: when loading a model that uses weight norm, we get a warning about missing `weight_v` and `weight_g` parameters. This is because PyTorch migrates the `weight_v` and `weight_g` params of `WeightNorm` to `original0` and `original1`. The state dict is converted correctly thanks to PR #24030, but we still get a warning message when loading the model.

We can solve this by replacing the `weight_g` and `weight_v` keys with the corresponding updated keys in `_load_pretrained_model`, as done in this PR.

Who can review:
cc @sanchit-gandhi
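The key replacement described above can be sketched as a simple state-dict rename. This is a hypothetical simplification, not the actual code added to `_load_pretrained_model`; `fix_weight_norm_keys` is an illustrative helper name:

```python
def fix_weight_norm_keys(state_dict):
    """Rename legacy WeightNorm keys to the new parametrizations format."""
    # Hypothetical mapping from legacy key suffixes to the new ones.
    mapping = {
        "weight_g": "parametrizations.weight.original0",
        "weight_v": "parametrizations.weight.original1",
    }
    fixed = {}
    for key, value in state_dict.items():
        for old, new in mapping.items():
            if key.endswith(old):
                key = key[: -len(old)] + new
                break
        fixed[key] = value
    return fixed


print(fix_weight_norm_keys({"conv.weight_g": 1, "conv.weight_v": 2, "conv.bias": 3}))
# {'conv.parametrizations.weight.original0': 1, 'conv.parametrizations.weight.original1': 2, 'conv.bias': 3}
```

Applying such a rename before the missing/unexpected-key check is what would silence the warning while leaving the actual weights untouched.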