From c2010c0b59108376b0d4e895bce13830fd0c750a Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Thu, 12 Nov 2020 11:12:39 +0100 Subject: [PATCH 01/11] New TF loading weights --- src/transformers/modeling_tf_pytorch_utils.py | 2 +- src/transformers/modeling_tf_utils.py | 70 ++++++++----------- 2 files changed, 32 insertions(+), 40 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index adcc19c61be2..3eaea4b0bd2c 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -280,7 +280,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs if tf_inputs is not None: tf_model(tf_inputs, training=False) # Make sure model is built - load_tf_weights(tf_model, tf_checkpoint_path) + _, _ = load_tf_weights(tf_model, tf_checkpoint_path) return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 2de2b1f0eecb..3ec97bc9f4c2 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -236,16 +236,14 @@ def compute_loss(self, labels, logits): return loss_fn(next_sentence_label, next_sentence_reduced_logits) -def detect_tf_missing_unexpected_layers(model, resolved_archive_file): +def load_tf_weights(model, resolved_archive_file): """ Detect missing and unexpected layers. - Args: model (:obj:`tf.keras.models.Model`): The model to load the weights into. resolved_archive_file (:obj:`str`): The location of the H5 file. - Returns: Two lists, one for the missing layers, and another one for the unexpected layers. """ @@ -253,46 +251,32 @@ def detect_tf_missing_unexpected_layers(model, resolved_archive_file): unexpected_layers = [] with h5py.File(resolved_archive_file, "r") as f: - saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) - model_layer_names = set(layer.name for layer in model.layers) - missing_layers = list(model_layer_names - saved_layer_names) - unexpected_layers = list(saved_layer_names - model_layer_names) + saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) + model_layers_name_value = {} for layer in model.layers: - if layer.name in saved_layer_names: - g = f[layer.name] - saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") - saved_weight_names_set = set( - "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names - ) - symbolic_weights = layer.trainable_weights + layer.non_trainable_weights - symbolic_weights_names = set( - "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights - ) - missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) - unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) + name = layer.name + model_layers_name_value[name] = layer - return missing_layers, unexpected_layers + model_layers_name = set(model_layers_name_value.keys()) + renamed_saved_h5_model_layers_names = set() + for layer_name in saved_h5_model_layers_name: + name = layer_name -def load_tf_weights(model, resolved_archive_file): - """ - Load the TF weights from a H5 file. + renamed_saved_h5_model_layers_names.add(name) - Args: - model (:obj:`tf.keras.models.Model`): - The model to load the weights into. - resolved_archive_file (:obj:`str`): - The location of the H5 file. - """ - with h5py.File(resolved_archive_file, "r") as f: - saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) + missing_layers = list(model_layers_name - renamed_saved_h5_model_layers_names) + unexpected_layers = list(renamed_saved_h5_model_layers_names - model_layers_name) + saved_weight_names_set = set() + symbolic_weights_names = set() weight_value_tuples = [] - for layer in model.layers: - if layer.name in saved_layer_names: - g = f[layer.name] + for layer_name in saved_h5_model_layers_name: + if layer_name in model_layers_name: + g = f[layer_name] saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") + layer = model_layers_name_value[layer_name] symbolic_weights = layer.trainable_weights + layer.non_trainable_weights saved_weight_names_values = {} @@ -300,13 +284,16 @@ def load_tf_weights(model, resolved_archive_file): name = "/".join(weight_name.split("/")[1:]) saved_weight_names_values[name] = np.asarray(g[weight_name]) + saved_weight_names_set.add(name) + for symbolic_weight in symbolic_weights: - splited_layers = symbolic_weight.name.split("/")[1:] - symbolic_weight_name = "/".join(splited_layers) + symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:]) + saved_weight_value = None if symbolic_weight_name in saved_weight_names_values: saved_weight_value = saved_weight_names_values[symbolic_weight_name] + if saved_weight_value is not None: if K.int_shape(symbolic_weight) != saved_weight_value.shape: try: array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) @@ -318,8 +305,15 @@ def load_tf_weights(model, resolved_archive_file): weight_value_tuples.append((symbolic_weight, array)) + symbolic_weights_names.add(symbolic_weight_name) + K.batch_set_value(weight_value_tuples) + missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) + unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) + + return missing_layers, unexpected_layers + class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): r""" @@ -728,7 +722,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 try: - load_tf_weights(model, resolved_archive_file) + missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file) except OSError: raise OSError( "Unable to load weights from h5 file. " @@ -737,8 +731,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): model(model.dummy_inputs, training=False) # Make sure restore ops are run - missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file) - if cls.authorized_missing_keys is not None: for pat in cls.authorized_missing_keys: missing_keys = [k for k in missing_keys if re.search(pat, k) is None] From 3d842c325b267950efdbc2f872187f6887e2288a Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Thu, 12 Nov 2020 11:20:51 +0100 Subject: [PATCH 02/11] apply style --- src/transformers/modeling_tf_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 3ec97bc9f4c2..19f694dd18cc 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -238,12 +238,14 @@ def compute_loss(self, labels, logits): def load_tf_weights(model, resolved_archive_file): """ - Detect missing and unexpected layers. + Detect missing and unexpected layers + Args: model (:obj:`tf.keras.models.Model`): The model to load the weights into. resolved_archive_file (:obj:`str`): - The location of the H5 file. + The location of the H5 file + Returns: Two lists, one for the missing layers, and another one for the unexpected layers. """ From b324000ab7bdd7a4203946825f4b89fad02cf550 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Thu, 12 Nov 2020 11:59:05 +0100 Subject: [PATCH 03/11] Better naming --- src/transformers/modeling_tf_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 19f694dd18cc..6b68430b5205 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1028,18 +1028,18 @@ def call(self, inputs, cls_index=None, training=False): return output -def shape_list(x: tf.Tensor) -> List[int]: +def shape_list(tensor: tf.Tensor) -> List[int]: """ Deal with dynamic shape in tensorflow cleanly. Args: - x (:obj:`tf.Tensor`): The tensor we want the shape of. + tensor (:obj:`tf.Tensor`): The tensor we want the shape of. Returns: :obj:`List[int]`: The shape of the tensor as a list. """ - static = x.shape.as_list() - dynamic = tf.shape(x) + static = tensor.shape.as_list() + dynamic = tf.shape(tensor) return [dynamic[i] if s is None else s for i, s in enumerate(static)] From 8a89d8c97e634cbdfa3dd5491a7b68e017aed022 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Sat, 14 Nov 2020 14:14:49 +0100 Subject: [PATCH 04/11] Largely comment the loading method --- src/transformers/modeling_tf_utils.py | 48 ++++++++++++++++++++------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 6b68430b5205..487592372753 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -252,51 +252,74 @@ def load_tf_weights(model, resolved_archive_file): missing_layers = [] unexpected_layers = [] + # Read the H5 file with h5py.File(resolved_archive_file, "r") as f: + # Retrieve the name of each layer from the H5 file saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) model_layers_name_value = {} - for layer in model.layers: - name = layer.name - model_layers_name_value[name] = layer + # Retrieve the name of each layer from the instanciated model + # make it a dict that looks like {"layer_name": Layer object} + model_layers_name_value = {layer.name: layer for layer in model.layers} + # Create a set of unique names of the layers that come from the instanciated model model_layers_name = set(model_layers_name_value.keys()) - renamed_saved_h5_model_layers_names = set() - for layer_name in saved_h5_model_layers_name: - name = layer_name - - renamed_saved_h5_model_layers_names.add(name) + # Find the missing layers from the high level list of layers + missing_layers = list(model_layers_name - saved_h5_model_layers_name) - missing_layers = list(model_layers_name - renamed_saved_h5_model_layers_names) - unexpected_layers = list(renamed_saved_h5_model_layers_names - model_layers_name) + # Find the unexpected layers from the high level list of layers + unexpected_layers = list(saved_h5_model_layers_name - model_layers_name) saved_weight_names_set = set() symbolic_weights_names = set() weight_value_tuples = [] + # Compute missing and unexpected sub layers + # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] for layer_name in saved_h5_model_layers_name: + # if layer_name from the H5 file belongs to the layers from the instanciated model if layer_name in model_layers_name: + # Get layer_name from the H5 file g = f[layer_name] + # Get all the weights that are attach to layer_name saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") + # Get the layer object from the layer_name in the dict that represents the instanciated model layer = model_layers_name_value[layer_name] + # Get all the weights as a list from the layer object symbolic_weights = layer.trainable_weights + layer.non_trainable_weights saved_weight_names_values = {} + # Create a dict from the H5 saved model that looks like {"weight_name": weight_value} + # And a set with only the names for weight_name in saved_weight_names: + # TF names always start with the model name so we ignore it name = "/".join(weight_name.split("/")[1:]) saved_weight_names_values[name] = np.asarray(g[weight_name]) + # Add the updated name to the final list for computing missing/unexpected values saved_weight_names_set.add(name) + # Loop over each weights from the instanciated model and compare with the weights from the H5 file for symbolic_weight in symbolic_weights: + # TF names always start with the model name so we ignore it symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:]) saved_weight_value = None + # Add the updated name to the final list for computing missing/unexpected values + symbolic_weights_names.add(symbolic_weight_name) + + # here we check if the current weight is among the weights from the H5 file + # If yes, get the weight_value of the corresponding weight from the H5 file + # If not, keep the value to None if symbolic_weight_name in saved_weight_names_values: saved_weight_value = saved_weight_names_values[symbolic_weight_name] + # If the current weight is found if saved_weight_value is not None: + # Check if the shape of the current weight and the one from the H5 file are different if K.int_shape(symbolic_weight) != saved_weight_value.shape: + # If yes we reshape the weight from the H5 file accordingly to the current weight + # If the two shapes are not compatible we raise an issue try: array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) except AssertionError as e: @@ -305,12 +328,13 @@ def load_tf_weights(model, resolved_archive_file): else: array = saved_weight_value + # We create the tuple that will be loaded and add it to the final list weight_value_tuples.append((symbolic_weight, array)) - symbolic_weights_names.add(symbolic_weight_name) - + # Load all the weights K.batch_set_value(weight_value_tuples) + # Compute the missing and unexpected layers missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) From e327cfe5d81818f4b40cc491a4644d429d131156 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Sat, 14 Nov 2020 14:20:11 +0100 Subject: [PATCH 05/11] Apply style --- src/transformers/modeling_tf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 487592372753..b19c4d9be98a 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -307,7 +307,7 @@ def load_tf_weights(model, resolved_archive_file): # Add the updated name to the final list for computing missing/unexpected values symbolic_weights_names.add(symbolic_weight_name) - + # here we check if the current weight is among the weights from the H5 file # If yes, get the weight_value of the corresponding weight from the H5 file # If not, keep the value to None From afa0411b85497b850ffa4e89f771a08a47c9ff56 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Mon, 16 Nov 2020 20:01:06 +0100 Subject: [PATCH 06/11] Address Patrick's comments --- src/transformers/modeling_tf_utils.py | 28 +++++++++++++-------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index b19c4d9be98a..3d300f443f28 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -244,7 +244,7 @@ def load_tf_weights(model, resolved_archive_file): model (:obj:`tf.keras.models.Model`): The model to load the weights into. resolved_archive_file (:obj:`str`): - The location of the H5 file + The location of the H5 file. Returns: Two lists, one for the missing layers, and another one for the unexpected layers. @@ -258,11 +258,11 @@ def load_tf_weights(model, resolved_archive_file): saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) model_layers_name_value = {} - # Retrieve the name of each layer from the instanciated model + # Retrieve the name of each layer from the instantiated model # make it a dict that looks like {"layer_name": Layer object} model_layers_name_value = {layer.name: layer for layer in model.layers} - # Create a set of unique names of the layers that come from the instanciated model + # Create a set of unique names of the layers that come from the instantiated model model_layers_name = set(model_layers_name_value.keys()) # Find the missing layers from the high level list of layers @@ -279,22 +279,22 @@ def load_tf_weights(model, resolved_archive_file): for layer_name in saved_h5_model_layers_name: # if layer_name from the H5 file belongs to the layers from the instanciated model if layer_name in model_layers_name: - # Get layer_name from the H5 file - g = f[layer_name] + # Get the H5 layer object from its name + h5_layer_object = f[layer_name] # Get all the weights that are attach to layer_name - saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") + saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names") # Get the layer object from the layer_name in the dict that represents the instanciated model layer = model_layers_name_value[layer_name] # Get all the weights as a list from the layer object symbolic_weights = layer.trainable_weights + layer.non_trainable_weights - saved_weight_names_values = {} + saved_weights = {} # Create a dict from the H5 saved model that looks like {"weight_name": weight_value} # And a set with only the names for weight_name in saved_weight_names: # TF names always start with the model name so we ignore it name = "/".join(weight_name.split("/")[1:]) - saved_weight_names_values[name] = np.asarray(g[weight_name]) + saved_weights[name] = np.asarray(h5_layer_object[weight_name]) # Add the updated name to the final list for computing missing/unexpected values saved_weight_names_set.add(name) @@ -303,16 +303,14 @@ def load_tf_weights(model, resolved_archive_file): for symbolic_weight in symbolic_weights: # TF names always start with the model name so we ignore it symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:]) - saved_weight_value = None - - # Add the updated name to the final list for computing missing/unexpected values - symbolic_weights_names.add(symbolic_weight_name) # here we check if the current weight is among the weights from the H5 file # If yes, get the weight_value of the corresponding weight from the H5 file - # If not, keep the value to None - if symbolic_weight_name in saved_weight_names_values: - saved_weight_value = saved_weight_names_values[symbolic_weight_name] + # If not, make the value to None + saved_weight_value = saved_weights.get(symbolic_weight_name, None) + + # Add the updated name to the final list for computing missing/unexpected values + symbolic_weights_names.add(symbolic_weight_name) # If the current weight is found if saved_weight_value is not None: From 49669ede508ced6faf086280dbc3479d7628f4d9 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Tue, 17 Nov 2020 10:02:49 +0100 Subject: [PATCH 07/11] Remove useless line of code --- src/transformers/modeling_tf_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 3d300f443f28..5eec0fae42d5 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -256,7 +256,6 @@ def load_tf_weights(model, resolved_archive_file): with h5py.File(resolved_archive_file, "r") as f: # Retrieve the name of each layer from the H5 file saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) - model_layers_name_value = {} # Retrieve the name of each layer from the instantiated model # make it a dict that looks like {"layer_name": Layer object} From 656703761a97cc4f81982cf4164d7173e6f2a9fe Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Tue, 17 Nov 2020 10:05:15 +0100 Subject: [PATCH 08/11] Update Docstring --- src/transformers/modeling_tf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 5eec0fae42d5..322409bb5834 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -238,7 +238,7 @@ def compute_loss(self, labels, logits): def load_tf_weights(model, resolved_archive_file): """ - Detect missing and unexpected layers + Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes. Args: model (:obj:`tf.keras.models.Model`): From 08f998c9ded783bf6ba05f1cf6aac231df3d13cd Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Wed, 18 Nov 2020 09:40:07 +0100 Subject: [PATCH 09/11] Address Sylvain's and Lysandre's comments --- src/transformers/modeling_tf_pytorch_utils.py | 2 +- src/transformers/modeling_tf_utils.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index 3eaea4b0bd2c..adcc19c61be2 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -280,7 +280,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs if tf_inputs is not None: tf_model(tf_inputs, training=False) # Make sure model is built - _, _ = load_tf_weights(tf_model, tf_checkpoint_path) + load_tf_weights(tf_model, tf_checkpoint_path) return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 322409bb5834..1054dd26880c 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -257,33 +257,24 @@ def load_tf_weights(model, resolved_archive_file): # Retrieve the name of each layer from the H5 file saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) - # Retrieve the name of each layer from the instantiated model - # make it a dict that looks like {"layer_name": Layer object} - model_layers_name_value = {layer.name: layer for layer in model.layers} - - # Create a set of unique names of the layers that come from the instantiated model - model_layers_name = set(model_layers_name_value.keys()) - # Find the missing layers from the high level list of layers - missing_layers = list(model_layers_name - saved_h5_model_layers_name) + missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name) # Find the unexpected layers from the high level list of layers - unexpected_layers = list(saved_h5_model_layers_name - model_layers_name) + unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers])) saved_weight_names_set = set() symbolic_weights_names = set() weight_value_tuples = [] # Compute missing and unexpected sub layers # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] - for layer_name in saved_h5_model_layers_name: + for layer in model.layers: # if layer_name from the H5 file belongs to the layers from the instanciated model - if layer_name in model_layers_name: + if layer.name in saved_h5_model_layers_name: # Get the H5 layer object from its name - h5_layer_object = f[layer_name] + h5_layer_object = f[layer.name] # Get all the weights that are attach to layer_name saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names") - # Get the layer object from the layer_name in the dict that represents the instanciated model - layer = model_layers_name_value[layer_name] # Get all the weights as a list from the layer object symbolic_weights = layer.trainable_weights + layer.non_trainable_weights saved_weights = {} From ac6785fd6cd1abc9a3f9276fd56f8f14dbbc2cf7 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Wed, 18 Nov 2020 09:48:03 +0100 Subject: [PATCH 10/11] Simplify the names computation --- src/transformers/modeling_tf_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 1054dd26880c..9e1600306604 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -273,15 +273,13 @@ def load_tf_weights(model, resolved_archive_file): if layer.name in saved_h5_model_layers_name: # Get the H5 layer object from its name h5_layer_object = f[layer.name] - # Get all the weights that are attach to layer_name - saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names") # Get all the weights as a list from the layer object symbolic_weights = layer.trainable_weights + layer.non_trainable_weights saved_weights = {} # Create a dict from the H5 saved model that looks like {"weight_name": weight_value} # And a set with only the names - for weight_name in saved_weight_names: + for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"): # TF names always start with the model name so we ignore it name = "/".join(weight_name.split("/")[1:]) saved_weights[name] = np.asarray(h5_layer_object[weight_name]) From 315249e4111886ea97d812d5cbcf7a87f2210310 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Wed, 18 Nov 2020 11:13:20 +0100 Subject: [PATCH 11/11] Typos --- src/transformers/modeling_tf_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 9e1600306604..324f350f85db 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -269,7 +269,7 @@ def load_tf_weights(model, resolved_archive_file): # Compute missing and unexpected sub layers # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] for layer in model.layers: - # if layer_name from the H5 file belongs to the layers from the instanciated model + # if layer_name from the H5 file belongs to the layers from the instantiated model if layer.name in saved_h5_model_layers_name: # Get the H5 layer object from its name h5_layer_object = f[layer.name] @@ -287,7 +287,7 @@ def load_tf_weights(model, resolved_archive_file): # Add the updated name to the final list for computing missing/unexpected values saved_weight_names_set.add(name) - # Loop over each weights from the instanciated model and compare with the weights from the H5 file + # Loop over each weights from the instantiated model and compare with the weights from the H5 file for symbolic_weight in symbolic_weights: # TF names always start with the model name so we ignore it symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])