
Commit 349f1c8

Rewrite TensorFlow train_step and test_step (huggingface#17057)
* Initial commit
* Better label renaming
* Remove breakpoint before pushing (this is your job)
* Test a lot more in the Keras fit() test
* make fixup
* Clarify the case where we flatten y dicts into tensors
* Clarify the case where we flatten y dicts into tensors
* Extract label name remapping to a method
1 parent 651e48e commit 349f1c8
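
The rewritten steps are easiest to see from the user's side. A minimal sketch of the fit() workflow this commit targets, assuming a standard checkpoint and toy data (the checkpoint name, sentences and label values here are illustrative, not taken from the commit):

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Illustrative checkpoint; any TF sequence-classification model would do
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

batch = tokenizer(["a good movie", "a bad movie"], padding=True, return_tensors="np")
labels = tf.constant([1, 0])

# Compiling without a loss argument selects the internal "dummy" loss, which
# simply reads the loss output head computed inside the model's forward pass.
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5))

# Labels can be passed the normal Keras way...
model.fit(dict(batch), labels, epochs=1)
# ...or folded into the input dict under the model's label argument name;
# the rewritten train_step/test_step copy keys to the right place either way.
model.fit({**batch, "labels": labels}, epochs=1)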

2 files changed: 149 additions & 48 deletions

src/transformers/modeling_tf_utils.py

Lines changed: 122 additions & 47 deletions
@@ -723,6 +723,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
     main_input_name = "input_ids"
     _auto_class = None
     _using_dummy_loss = None
+    _label_to_output_map = None
 
     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
@@ -907,24 +908,18 @@ def compile(
         function themselves.
         """
         if loss == "passthrough":
-            if metrics is not None:
-                raise ValueError(
-                    "Passing metrics as a dict is not supported when using the internal loss! "
-                    "Please either compile the model with a loss, or remove the metrics argument. "
-                    "Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
-                    "loss."
-                )
             logger.warning(
                 "No loss specified in compile() - the model's internal loss computation will be used as the "
                 "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
-                "To disable this behaviour, please pass a loss argument, or explicitly pass "
+                "To disable this behaviour please pass a loss argument, or explicitly pass "
                 "`loss=None` if you do not want your model to compute a loss."
             )
             loss = dummy_loss
             self._using_dummy_loss = True
         else:
             self._using_dummy_loss = False
         parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
+        # This argument got renamed, we need to support both versions
         if "steps_per_execution" in parent_args:
             super().compile(
                 optimizer=optimizer,
@@ -962,27 +957,42 @@ def compute_loss(self, *args, **kwargs):
         )
         return self.hf_compute_loss(*args, **kwargs)
 
+    def get_label_to_output_name_mapping(self):
+        arg_names = list(dict(inspect.signature(self.call).parameters).keys())
+        if self._label_to_output_map is not None:
+            return self._label_to_output_map
+        elif "start_positions" in arg_names:
+            return {"start_positions": "start_logits", "end_positions": "end_logits"}
+        elif "sentence_order_label" in arg_names:
+            return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
+        elif "next_sentence_label" in arg_names:
+            return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
+        elif "mc_labels" in arg_names:
+            return {"labels": "logits", "mc_labels": "mc_logits"}
+        else:
+            return dict()
+
     def train_step(self, data):
         """
-        A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
-        a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
-
-        When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
-        loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
-        as keys in the input dictionary, or as normal Keras labels.
+        A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
+        and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
+        labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
+        that they are available to the model during the forward pass.
         """
 
-        # These are the only transformations `Model.fit` applies to user-input
-        # data when a `tf.data.Dataset` is provided.
+        # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
+        arg_names = list(dict(inspect.signature(self.call).parameters).keys())
+        label_kwargs = find_labels(self.__class__)
+        label_to_output = self.get_label_to_output_name_mapping()
+        output_to_label = {val: key for key, val in label_to_output.items()}
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
 
         # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
         # if those keys are not already present in the input dict
         if self._using_dummy_loss and y is not None:
-            arg_names = list(dict(inspect.signature(self.call).parameters).keys())
-            label_kwargs = find_labels(self.__class__)
+
             # If y is a tensor and the model only has one label-like input, map y to that input
             if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
                 if isinstance(x, tf.Tensor):
@@ -997,22 +1007,59 @@ def train_step(self, data):
                 for key, val in y.items():
                     if key in arg_names and key not in x:
                         x[key] = val
+                    elif output_to_label.get(key, None) in arg_names and key not in x:
+                        x[output_to_label[key]] = val
+        if y is None:
+            y = {key: val for key, val in x.items() if key in label_kwargs}
+            if not y and not self._using_dummy_loss:
+                raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
+
+        if isinstance(y, dict):
+            # Rename labels at this point to match output heads
+            y = {label_to_output.get(key, key): val for key, val in y.items()}
 
         # Run forward pass.
         with tf.GradientTape() as tape:
             y_pred = self(x, training=True)
             if self._using_dummy_loss:
                 loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
             else:
+                loss = None
+
+            # This next block matches outputs to label keys. Tensorflow's standard method for doing this
+            # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
+            if isinstance(y, dict) and len(y) == 1:
+                if list(y.keys())[0] in y_pred.keys():
+                    y_pred = y_pred[list(y.keys())[0]]
+                elif list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred[1]
+                else:
+                    y_pred = y_pred[0]
+                _, y = y.popitem()
+            elif isinstance(y, dict):
+                # If the labels are a dict, match keys from the output by name
+                y_pred = {key: val for key, val in y_pred.items() if key in y}
+            elif isinstance(y, tuple) or isinstance(y, list):
+                # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
+                if list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred.to_tuple()[1:]
+                else:
+                    y_pred = y_pred.to_tuple()
+                y_pred = y_pred[: len(y)]  # Remove unused fields in case those cause problems
+            else:
+                # If the labels are a single tensor, match them to the first non-loss tensor in the output
+                if list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred[1]
+                else:
+                    y_pred = y_pred[0]
+
+            if loss is None:
                 loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+
         # Run backwards pass.
         self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
 
-        # When using the dummy_loss we know metrics are not present, so we can skip a lot of this
-        if self._using_dummy_loss:
-            self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
-        else:
-            self.compiled_metrics.update_state(y, y_pred, sample_weight)
+        self.compiled_metrics.update_state(y, y_pred, sample_weight)
         # Collect metrics to return
         return_metrics = {}
         for metric in self.metrics:
@@ -1021,23 +1068,20 @@ def train_step(self, data):
                 return_metrics.update(result)
             else:
                 return_metrics[metric.name] = result
-        # These next two lines are also not in the base method - they correct the displayed metrics
-        # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
-        if "loss" in return_metrics and "loss_loss" in return_metrics:
-            del return_metrics["loss_loss"]
         return return_metrics
 
     def test_step(self, data):
         """
-        A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
-        user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
-
-        When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
-        loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
-        as keys in the input dictionary, or as normal Keras labels.
+        A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
+        and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
+        labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
+        that they are available to the model during the forward pass.
         """
-        # These are the only transformations `Model.fit` applies to user-input
-        # data when a `tf.data.Dataset` is provided.
+        # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
+        arg_names = list(dict(inspect.signature(self.call).parameters).keys())
+        label_kwargs = find_labels(self.__class__)
+        label_to_output = self.get_label_to_output_name_mapping()
+        output_to_label = {val: key for key, val in label_to_output.items()}
         if not self._using_dummy_loss:
             data = data_adapter.expand_1d(data)
         x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
@@ -1046,7 +1090,6 @@ def test_step(self, data):
         # if those keys are not already present in the input dict
         if self._using_dummy_loss and y is not None:
             arg_names = list(dict(inspect.signature(self.call).parameters).keys())
-            label_kwargs = find_labels(self.__class__)
             # If y is a tensor and the model only has one label-like input, map y to that input
             if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
                 if isinstance(x, tf.Tensor):
@@ -1061,19 +1104,55 @@ def test_step(self, data):
                 for key, val in y.items():
                     if key in arg_names and key not in x:
                         x[key] = val
+                    elif output_to_label.get(key, None) in arg_names and key not in x:
+                        x[output_to_label[key]] = val
+        if y is None:
+            y = {key: val for key, val in x.items() if key in label_kwargs}
+            if not y and not self._using_dummy_loss:
+                raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
+
+        if isinstance(y, dict):
+            # Rename labels at this point to match output heads
+            y = {label_to_output.get(key, key): val for key, val in y.items()}
 
         # Run forward pass.
         y_pred = self(x, training=False)
         if self._using_dummy_loss:
-            self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
+            loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
         else:
-            self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
-
-        # When using the dummy_loss we know metrics are not present, so we can skip a lot of this
-        if self._using_dummy_loss:
-            self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
+            loss = None
+
+        # This next block matches outputs to label keys. Tensorflow's standard method for doing this
+        # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
+        if isinstance(y, dict) and len(y) == 1:
+            if list(y.keys())[0] in y_pred.keys():
+                y_pred = y_pred[list(y.keys())[0]]
+            elif list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred[1]
+            else:
+                y_pred = y_pred[0]
+            _, y = y.popitem()
+        elif isinstance(y, dict):
+            # If the labels are a dict, match keys from the output by name
+            y_pred = {key: val for key, val in y_pred.items() if key in y}
+        elif isinstance(y, tuple) or isinstance(y, list):
+            # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
+            if list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred.to_tuple()[1:]
+            else:
+                y_pred = y_pred.to_tuple()
+            y_pred = y_pred[: len(y)]  # Remove unused fields in case those cause problems
         else:
-            self.compiled_metrics.update_state(y, y_pred, sample_weight)
+            # If the labels are a single tensor, match them to the first non-loss tensor in the output
+            if list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred[1]
+            else:
+                y_pred = y_pred[0]
+
+        if loss is None:
+            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+
+        self.compiled_metrics.update_state(y, y_pred, sample_weight)
         # Collect metrics to return
         return_metrics = {}
         for metric in self.metrics:
@@ -1082,10 +1161,6 @@ def test_step(self, data):
                 return_metrics.update(result)
             else:
                 return_metrics[metric.name] = result
-        # These next two lines are also not in the base method - they correct the displayed metrics
-        # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
-        if "loss" in return_metrics and "loss_loss" in return_metrics:
-            del return_metrics["loss_loss"]
         return return_metrics
 
     def create_model_card(
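
Before moving to the test changes, note that the output-matching logic above reduces to two small dict remaps. A standalone sketch of the renaming step with toy question-answering values (the values are illustrative, not taken from the commit):

# Mirrors the `isinstance(y, dict)` branch of train_step/test_step: label keys
# are renamed to the output-head names from get_label_to_output_name_mapping()
label_to_output = {"start_positions": "start_logits", "end_positions": "end_logits"}
output_to_label = {val: key for key, val in label_to_output.items()}

y = {"start_positions": [3], "end_positions": [7]}
y = {label_to_output.get(key, key): val for key, val in y.items()}
assert y == {"start_logits": [3], "end_logits": [7]}

# The reverse map is what lets a label passed under an output-head name be
# copied back into the input-dict key (model argument) that expects it
assert output_to_label["start_logits"] == "start_positions"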

tests/test_modeling_tf_common.py

Lines changed: 27 additions & 1 deletion
@@ -1355,7 +1355,25 @@ def test_keras_fit(self):
            labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
            inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
            self.assertGreater(len(inputs_minus_labels), 0)
-            model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
+            accuracy_classes = [
+                "ForPreTraining",
+                "ForCausalLM",
+                "ForMaskedLM",
+                "ForQuestionAnswering",
+                "ForMultipleChoice",
+                "ForSequenceClassification",
+                "ForTokenClassification",
+                "ForNextSentencePrediction",
+                "LMHeadModel",
+            ]
+            for accuracy_class in accuracy_classes:
+                if model.__class__.__name__.endswith(accuracy_class):
+                    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
+                    break
+            else:
+                metrics = []
+
+            model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True, metrics=metrics)
            # Make sure the model fits without crashing regardless of where we pass the labels
            history1 = model.fit(
                prepared_for_class,
@@ -1365,6 +1383,7 @@ def test_keras_fit(self):
                shuffle=False,
            )
            val_loss1 = history1.history["val_loss"][0]
+            accuracy1 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
            history2 = model.fit(
                inputs_minus_labels,
                labels,
@@ -1374,7 +1393,14 @@ def test_keras_fit(self):
                shuffle=False,
            )
            val_loss2 = history2.history["val_loss"][0]
+            accuracy2 = {key: val[0] for key, val in history1.history.items() if key.endswith("accuracy")}
            self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
+            self.assertEqual(history1.history.keys(), history2.history.keys())
+            for key in history1.history.keys():
+                if not key.startswith("val_"):
+                    self.assertTrue("val_" + key in history1.history.keys(), "Outputs differ in train/test step!")
+            if metrics:
+                self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
 
    def test_int64_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
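
The metric-selection code added to the test relies on Python's for/else construct: the else branch runs only when the loop finishes without hitting break. A tiny standalone illustration (the class name and suffix list here are invented for the example):

suffixes = ["ForMaskedLM", "ForSequenceClassification"]

name = "TFBertModel"  # matches no suffix, so the else branch runs
for suffix in suffixes:
    if name.endswith(suffix):
        metrics = ["accuracy"]
        break
else:
    # Reached only because no `break` fired
    metrics = []
assert metrics == []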
