Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,6 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
def _load_state_dict_into_meta_model(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As explained here, the issue doesn't appear when doing regular loading of the state dict, but only when doing metaloading!

model,
state_dict,
loaded_state_dict_keys, # left for now but could be removed, see below
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I removed loaded_state_dict_keys from _load_state_dict_into_meta_model, because according to the following snippet, it was not actually used before:

        # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
        if param_name not in loaded_state_dict_keys or param_name not in expected_keys:

I might have overlooked some downside effects, especially with quantization and/or training frameworks. WDYT @ArthurZucker and @LysandreJik ? Who should I tag for more info?

Also happy to change back to the original behaviour

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it doesn't break any tests, let's remove it and keep an eye out for eventual breakage

start_prefix,
expected_keys,
device_map=None,
Expand Down Expand Up @@ -847,8 +846,6 @@ def _load_state_dict_into_meta_model(
# - deepspeed zero 3 support
# - need to copy metadata if any - see _load_state_dict_into_model
# - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
# they won't get loaded.

error_msgs = []

Expand All @@ -868,6 +865,18 @@ def _load_state_dict_into_meta_model(
# We add only the first key as an example
new_key = key.replace("beta", "bias")
renamed_beta[key] = new_key if not renamed_beta else renamed_beta

# To reproduce `_load_state_dict_into_model` behaviour, we need to manually rename parametrized weight norm, if necessary.
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
new_key = key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
new_key = key.replace("weight_v", "parametrizations.weight.original1")
else:
if "parametrizations.weight.original0" in key:
new_key = key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
new_key = key.replace("parametrizations.weight.original1", "weight_v")
if new_key:
old_keys.append(key)
new_keys.append(new_key)
Expand All @@ -884,8 +893,7 @@ def _load_state_dict_into_meta_model(
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

for param_name, param in state_dict.items():
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
if param_name not in expected_keys:
continue

if param_name.startswith(start_prefix):
Expand Down Expand Up @@ -4128,6 +4136,18 @@ def _fix_key(key):
return key.replace("beta", "bias")
if "gamma" in key:
return key.replace("gamma", "weight")

# to avoid logging parametrized weight norm renaming
if hasattr(nn.utils.parametrizations, "weight_norm"):
if "weight_g" in key:
return key.replace("weight_g", "parametrizations.weight.original0")
if "weight_v" in key:
return key.replace("weight_v", "parametrizations.weight.original1")
else:
if "parametrizations.weight.original0" in key:
return key.replace("parametrizations.weight.original0", "weight_g")
if "parametrizations.weight.original1" in key:
return key.replace("parametrizations.weight.original1", "weight_v")
return key

original_loaded_keys = loaded_keys
Expand Down Expand Up @@ -4372,7 +4392,6 @@ def _find_mismatched_keys(
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
Expand Down Expand Up @@ -4449,7 +4468,6 @@ def _find_mismatched_keys(
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
Expand Down Expand Up @@ -4605,7 +4623,6 @@ def _load_pretrained_model_low_mem(
error_msgs = _load_state_dict_into_meta_model(
model,
state_dict,
loaded_state_dict_keys,
start_prefix,
expected_keys=expected_keys,
hf_quantizer=hf_quantizer,
Expand Down
44 changes: 24 additions & 20 deletions src/transformers/models/dac/modeling_dac.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,33 +494,37 @@ def _init_weights(self, module):
nn.init.constant_(module.bias, 0)

def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

for layer in self.quantizer.quantizers:
nn.utils.weight_norm(layer.in_proj)
nn.utils.weight_norm(layer.out_proj)
weight_norm(layer.in_proj)
weight_norm(layer.out_proj)

nn.utils.weight_norm(self.encoder.conv1)
nn.utils.weight_norm(self.encoder.conv2)
weight_norm(self.encoder.conv1)
weight_norm(self.encoder.conv2)

for layer in self.encoder.block:
nn.utils.weight_norm(layer.conv1)
nn.utils.weight_norm(layer.res_unit1.conv1)
nn.utils.weight_norm(layer.res_unit1.conv2)
nn.utils.weight_norm(layer.res_unit2.conv1)
nn.utils.weight_norm(layer.res_unit2.conv2)
nn.utils.weight_norm(layer.res_unit3.conv1)
nn.utils.weight_norm(layer.res_unit3.conv2)
weight_norm(layer.conv1)
weight_norm(layer.res_unit1.conv1)
weight_norm(layer.res_unit1.conv2)
weight_norm(layer.res_unit2.conv1)
weight_norm(layer.res_unit2.conv2)
weight_norm(layer.res_unit3.conv1)
weight_norm(layer.res_unit3.conv2)

nn.utils.weight_norm(self.decoder.conv1)
nn.utils.weight_norm(self.decoder.conv2)
weight_norm(self.decoder.conv1)
weight_norm(self.decoder.conv2)

for layer in self.decoder.block:
nn.utils.weight_norm(layer.conv_t1)
nn.utils.weight_norm(layer.res_unit1.conv1)
nn.utils.weight_norm(layer.res_unit1.conv2)
nn.utils.weight_norm(layer.res_unit2.conv1)
nn.utils.weight_norm(layer.res_unit2.conv2)
nn.utils.weight_norm(layer.res_unit3.conv1)
nn.utils.weight_norm(layer.res_unit3.conv2)
weight_norm(layer.conv_t1)
weight_norm(layer.res_unit1.conv1)
weight_norm(layer.res_unit1.conv2)
weight_norm(layer.res_unit2.conv1)
weight_norm(layer.res_unit2.conv2)
weight_norm(layer.res_unit3.conv1)
weight_norm(layer.res_unit3.conv2)

def remove_weight_norm(self):
for layer in self.quantizer.quantizers:
Expand Down
13 changes: 11 additions & 2 deletions src/transformers/models/encodec/modeling_encodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ def __init__(
)

self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

if self.norm_type == "weight_norm":
self.conv = nn.utils.weight_norm(self.conv)
self.conv = weight_norm(self.conv)
elif self.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels)

Expand Down Expand Up @@ -186,8 +190,13 @@ def __init__(self, config, in_channels: int, out_channels: int, kernel_size: int
)

self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)

weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

if config.norm_type == "weight_norm":
self.conv = nn.utils.weight_norm(self.conv)
self.conv = weight_norm(self.conv)
elif config.norm_type == "time_group_norm":
self.norm = nn.GroupNorm(1, out_channels)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1416,10 +1416,14 @@ def get_padding(self, kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2

def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)

def remove_weight_norm(self):
for layer in self.convs1:
Expand Down Expand Up @@ -1493,12 +1497,16 @@ def _init_weights(self, module):
module.bias.data.zero_()

def apply_weight_norm(self):
nn.utils.weight_norm(self.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

weight_norm(self.conv_pre)
for layer in self.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.conv_post)
weight_norm(self.conv_post)

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)
Expand Down
18 changes: 13 additions & 5 deletions src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,10 +2361,14 @@ def get_padding(self, kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2

def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)

def remove_weight_norm(self):
for layer in self.convs1:
Expand Down Expand Up @@ -2633,12 +2637,16 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()

def apply_weight_norm(self):
nn.utils.weight_norm(self.hifi_gan.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

weight_norm(self.hifi_gan.conv_pre)
for layer in self.hifi_gan.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.hifi_gan.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.hifi_gan.conv_post)
weight_norm(self.hifi_gan.conv_post)

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.hifi_gan.conv_pre)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2608,10 +2608,14 @@ def get_padding(self, kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2

def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)

def remove_weight_norm(self):
for layer in self.convs1:
Expand Down Expand Up @@ -2889,12 +2893,16 @@ def _init_weights(self, module):

# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.apply_weight_norm
def apply_weight_norm(self):
nn.utils.weight_norm(self.hifi_gan.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

weight_norm(self.hifi_gan.conv_pre)
for layer in self.hifi_gan.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.hifi_gan.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.hifi_gan.conv_post)
weight_norm(self.hifi_gan.conv_post)

# Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.remove_weight_norm
def remove_weight_norm(self):
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/models/sew/modeling_sew.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,15 @@ def __init__(self, config):
stride=config.squeeze_factor,
)

weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
Expand All @@ -288,7 +292,7 @@ def __init__(self, config):
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)

self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
Expand Down
8 changes: 6 additions & 2 deletions src/transformers/models/sew_d/modeling_sew_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,15 @@ def __init__(self, config):
stride=config.squeeze_factor,
)

weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)
if hasattr(self.conv, "parametrizations"):
weight_g = self.conv.parametrizations.weight.original0
weight_v = self.conv.parametrizations.weight.original1
Expand All @@ -363,7 +367,7 @@ def __init__(self, config):
deepspeed.zero.register_external_parameter(self, weight_v)
deepspeed.zero.register_external_parameter(self, weight_g)
else:
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
self.conv = weight_norm(self.conv, name="weight", dim=2)

self.padding = SEWDSamePadLayer(config.num_conv_pos_embeddings)
self.activation = ACT2FN[config.feat_extract_activation]
Expand Down
18 changes: 13 additions & 5 deletions src/transformers/models/speecht5/modeling_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -3234,10 +3234,14 @@ def get_padding(self, kernel_size, dilation=1):
return (kernel_size * dilation - dilation) // 2

def apply_weight_norm(self):
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

for layer in self.convs1:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.convs2:
nn.utils.weight_norm(layer)
weight_norm(layer)

def remove_weight_norm(self):
for layer in self.convs1:
Expand Down Expand Up @@ -3310,12 +3314,16 @@ def _init_weights(self, module):
module.bias.data.zero_()

def apply_weight_norm(self):
nn.utils.weight_norm(self.conv_pre)
weight_norm = nn.utils.weight_norm
if hasattr(nn.utils.parametrizations, "weight_norm"):
weight_norm = nn.utils.parametrizations.weight_norm

weight_norm(self.conv_pre)
for layer in self.upsampler:
nn.utils.weight_norm(layer)
weight_norm(layer)
for layer in self.resblocks:
layer.apply_weight_norm()
nn.utils.weight_norm(self.conv_post)
weight_norm(self.conv_post)

def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)
Expand Down
Loading