
warning about weight_g/weight_v missing on WeightNorm on PyTorch#32194

Closed
kamilakesbi wants to merge 4 commits into huggingface:main from kamilakesbi:fix_wav2vec2_weight_g_weight_v

Conversation

@kamilakesbi
Contributor

@kamilakesbi kamilakesbi commented Jul 24, 2024

This PR addresses issue #26796:

  • On models using WeightNorm (such as Hubert or Wav2Vec2), a warning message appears indicating that some weights (the weight_v and weight_g parameters) have not been properly initialised.

This is because PyTorch migrated the weight_g and weight_v parameters of WeightNorm to original0 and original1 in the new parametrization API.

The state dict is converted correctly thanks to PR #24030, but we still get a warning message when loading the model.

We can solve this by replacing the weight_g and weight_v keys with the corresponding updated keys in _load_pretrained_model, as done in this PR.
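A minimal sketch of that key remapping, assuming PyTorch's parametrized naming (the helper name is hypothetical; the actual change lives inside _load_pretrained_model):

```python
def rename_weight_norm_keys(state_dict):
    """Map legacy WeightNorm keys to the new parametrization names.

    In PyTorch's parametrized weight_norm, the magnitude (weight_g) is
    stored as parametrizations.weight.original0 and the direction
    (weight_v) as parametrizations.weight.original1.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = key.replace("weight_g", "parametrizations.weight.original0")
        new_key = new_key.replace("weight_v", "parametrizations.weight.original1")
        renamed[new_key] = value
    return renamed
```

With keys remapped like this before the missing/unexpected-key check, the loader no longer reports weight_g/weight_v as uninitialised.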

Who can review:

cc @sanchit-gandhi

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kamilakesbi kamilakesbi changed the title [WIP] - warning about weight_g/weight_v missing on WeightNorm on PyTorch warning about weight_g/weight_v missing on WeightNorm on PyTorch Jul 25, 2024
Contributor

@sanchit-gandhi sanchit-gandhi left a comment


I like the elegance of this fix @kamilakesbi! My only concern is that it assumes that all models are using the new parametrizations format, e.g. the one used in Wav2Vec2:

```python
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
    weight_norm = nn.utils.parametrizations.weight_norm
```

Doing a quick search of the codebase, this is not the case for all models, e.g. for EnCodec:

```python
self.conv = nn.utils.weight_norm(self.conv)
```

=> Have we checked that the weight norm params are loaded without a warning when we use the legacy implementation? Ideally, we would only "fix" these keys if we know we're loading a legacy weight norm param into a parametrized one.
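A possible guard along these lines, sketched with assumed names (this is illustrative, not the PR's actual code): only remap legacy keys when the target module really uses the new parametrization API.

```python
def uses_parametrized_weight_norm(module):
    """Return True if the module's weight is handled by the new
    parametrization API (i.e. legacy weight_g/weight_v keys would
    need remapping when loading into it)."""
    # Parametrized modules expose a `parametrizations` container
    # keyed by the parametrized attribute name.
    parametrizations = getattr(module, "parametrizations", None)
    return parametrizations is not None and "weight" in parametrizations
```

Modules created with the legacy nn.utils.weight_norm keep their weight_g/weight_v attributes and would fail this check, so their keys would be left untouched.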

@kamilakesbi kamilakesbi force-pushed the fix_wav2vec2_weight_g_weight_v branch from 642251d to 4a16f9b on July 26, 2024 at 10:23
@kamilakesbi
Contributor Author

@sanchit-gandhi there are indeed a few models for which the new parametrizations format is missing:

  • Univnet
  • sew_d
  • sew
  • Seamless_m4t
  • Seamless_m4t v2
  • fast_speech2_conformer
  • Encodec

I've updated the corresponding modeling files.

The problem may arise again with future model integrations; I'm thinking of the recent dac integration, for example.

We should probably add a warning suggesting that users implement the new parametrization format with weight_norm when adding a new model!

@amyeroberts
Contributor

cc @ylacombe

@ylacombe
Contributor

Superseded by #33275

@ylacombe ylacombe closed this Sep 17, 2024