@@ -723,6 +723,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
723723 main_input_name = "input_ids"
724724 _auto_class = None
725725 _using_dummy_loss = None
726+ _label_to_output_map = None
726727
727728 # a list of re pattern of tensor names to ignore from the model when loading the model weights
728729 # (and avoid unnecessary warnings).
@@ -907,24 +908,18 @@ def compile(
907908 function themselves.
908909 """
909910 if loss == "passthrough" :
910- if metrics is not None :
911- raise ValueError (
912- "Passing metrics as a dict is not supported when using the internal loss! "
913- "Please either compile the model with a loss, or remove the metrics argument. "
914- "Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
915- "loss."
916- )
917911 logger .warning (
918912 "No loss specified in compile() - the model's internal loss computation will be used as the "
919913 "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
920- "To disable this behaviour, please pass a loss argument, or explicitly pass "
914+ "To disable this behaviour please pass a loss argument, or explicitly pass "
921915 "`loss=None` if you do not want your model to compute a loss."
922916 )
923917 loss = dummy_loss
924918 self ._using_dummy_loss = True
925919 else :
926920 self ._using_dummy_loss = False
927921 parent_args = list (inspect .signature (tf .keras .Model .compile ).parameters .keys ())
922+ # This argument got renamed, we need to support both versions
928923 if "steps_per_execution" in parent_args :
929924 super ().compile (
930925 optimizer = optimizer ,
@@ -962,27 +957,42 @@ def compute_loss(self, *args, **kwargs):
962957 )
963958 return self .hf_compute_loss (* args , ** kwargs )
964959
960+ def get_label_to_output_name_mapping (self ):
961+ arg_names = list (dict (inspect .signature (self .call ).parameters ).keys ())
962+ if self ._label_to_output_map is not None :
963+ return self ._label_to_output_map
964+ elif "start_positions" in arg_names :
965+ return {"start_positions" : "start_logits" , "end_positions" : "end_logits" }
966+ elif "sentence_order_label" in arg_names :
967+ return {"labels" : "prediction_logits" , "sentence_order_label" : "sop_logits" }
968+ elif "next_sentence_label" in arg_names :
969+ return {"labels" : "prediction_logits" , "next_sentence_label" : "seq_relationship_logits" }
970+ elif "mc_labels" in arg_names :
971+ return {"labels" : "logits" , "mc_labels" : "mc_logits" }
972+ else :
973+ return dict ()
974+
965975 def train_step (self , data ):
966976 """
967- A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
968- a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
969-
970- When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
971- loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
972- as keys in the input dictionary, or as normal Keras labels.
977+ A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
978+ and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
979+ labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
980+ that they are available to the model during the forward pass.
973981 """
974982
975- # These are the only transformations `Model.fit` applies to user-input
976- # data when a `tf.data.Dataset` is provided.
983+ # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
984+ arg_names = list (dict (inspect .signature (self .call ).parameters ).keys ())
985+ label_kwargs = find_labels (self .__class__ )
986+ label_to_output = self .get_label_to_output_name_mapping ()
987+ output_to_label = {val : key for key , val in label_to_output .items ()}
977988 if not self ._using_dummy_loss :
978989 data = data_adapter .expand_1d (data )
979990 x , y , sample_weight = data_adapter .unpack_x_y_sample_weight (data )
980991
981992 # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
982993 # if those keys are not already present in the input dict
983994 if self ._using_dummy_loss and y is not None :
984- arg_names = list (dict (inspect .signature (self .call ).parameters ).keys ())
985- label_kwargs = find_labels (self .__class__ )
995+
986996 # If y is a tensor and the model only has one label-like input, map y to that input
987997 if len (label_kwargs ) == 1 and isinstance (y , tf .Tensor ):
988998 if isinstance (x , tf .Tensor ):
@@ -997,22 +1007,59 @@ def train_step(self, data):
9971007 for key , val in y .items ():
9981008 if key in arg_names and key not in x :
9991009 x [key ] = val
1010+ elif output_to_label .get (key , None ) in arg_names and key not in x :
1011+ x [output_to_label [key ]] = val
1012+ if y is None :
1013+ y = {key : val for key , val in x .items () if key in label_kwargs }
1014+ if not y and not self ._using_dummy_loss :
1015+ raise ValueError ("Could not find label column(s) in input dict and no separate labels were provided!" )
1016+
1017+ if isinstance (y , dict ):
1018+ # Rename labels at this point to match output heads
1019+ y = {label_to_output .get (key , key ): val for key , val in y .items ()}
10001020
10011021 # Run forward pass.
10021022 with tf .GradientTape () as tape :
10031023 y_pred = self (x , training = True )
10041024 if self ._using_dummy_loss :
10051025 loss = self .compiled_loss (y_pred .loss , y_pred .loss , sample_weight , regularization_losses = self .losses )
10061026 else :
1027+ loss = None
1028+
1029+ # This next block matches outputs to label keys. Tensorflow's standard method for doing this
1030+ # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
1031+ if isinstance (y , dict ) and len (y ) == 1 :
1032+ if list (y .keys ())[0 ] in y_pred .keys ():
1033+ y_pred = y_pred [list (y .keys ())[0 ]]
1034+ elif list (y_pred .keys ())[0 ] == "loss" :
1035+ y_pred = y_pred [1 ]
1036+ else :
1037+ y_pred = y_pred [0 ]
1038+ _ , y = y .popitem ()
1039+ elif isinstance (y , dict ):
1040+ # If the labels are a dict, match keys from the output by name
1041+ y_pred = {key : val for key , val in y_pred .items () if key in y }
1042+ elif isinstance (y , tuple ) or isinstance (y , list ):
1043+ # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
1044+ if list (y_pred .keys ())[0 ] == "loss" :
1045+ y_pred = y_pred .to_tuple ()[1 :]
1046+ else :
1047+ y_pred = y_pred .to_tuple ()
1048+ y_pred = y_pred [: len (y )] # Remove unused fields in case those cause problems
1049+ else :
1050+ # If the labels are a single tensor, match them to the first non-loss tensor in the output
1051+ if list (y_pred .keys ())[0 ] == "loss" :
1052+ y_pred = y_pred [1 ]
1053+ else :
1054+ y_pred = y_pred [0 ]
1055+
1056+ if loss is None :
10071057 loss = self .compiled_loss (y , y_pred , sample_weight , regularization_losses = self .losses )
1058+
10081059 # Run backwards pass.
10091060 self .optimizer .minimize (loss , self .trainable_variables , tape = tape )
10101061
1011- # When using the dummy_loss we know metrics are not present, so we can skip a lot of this
1012- if self ._using_dummy_loss :
1013- self .compiled_metrics .update_state (y_pred .loss , y_pred .loss , sample_weight )
1014- else :
1015- self .compiled_metrics .update_state (y , y_pred , sample_weight )
1062+ self .compiled_metrics .update_state (y , y_pred , sample_weight )
10161063 # Collect metrics to return
10171064 return_metrics = {}
10181065 for metric in self .metrics :
@@ -1021,23 +1068,20 @@ def train_step(self, data):
10211068 return_metrics .update (result )
10221069 else :
10231070 return_metrics [metric .name ] = result
1024- # These next two lines are also not in the base method - they correct the displayed metrics
1025- # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
1026- if "loss" in return_metrics and "loss_loss" in return_metrics :
1027- del return_metrics ["loss_loss" ]
10281071 return return_metrics
10291072
10301073 def test_step (self , data ):
10311074 """
1032- A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
1033- user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
1034-
1035- When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
1036- loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
1037- as keys in the input dictionary, or as normal Keras labels.
1075+ A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
1076+ and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
1077+ labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
1078+ that they are available to the model during the forward pass.
10381079 """
1039- # These are the only transformations `Model.fit` applies to user-input
1040- # data when a `tf.data.Dataset` is provided.
1080+ # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
1081+ arg_names = list (dict (inspect .signature (self .call ).parameters ).keys ())
1082+ label_kwargs = find_labels (self .__class__ )
1083+ label_to_output = self .get_label_to_output_name_mapping ()
1084+ output_to_label = {val : key for key , val in label_to_output .items ()}
10411085 if not self ._using_dummy_loss :
10421086 data = data_adapter .expand_1d (data )
10431087 x , y , sample_weight = data_adapter .unpack_x_y_sample_weight (data )
@@ -1046,7 +1090,6 @@ def test_step(self, data):
10461090 # if those keys are not already present in the input dict
10471091 if self ._using_dummy_loss and y is not None :
10481092 arg_names = list (dict (inspect .signature (self .call ).parameters ).keys ())
1049- label_kwargs = find_labels (self .__class__ )
10501093 # If y is a tensor and the model only has one label-like input, map y to that input
10511094 if len (label_kwargs ) == 1 and isinstance (y , tf .Tensor ):
10521095 if isinstance (x , tf .Tensor ):
@@ -1061,19 +1104,55 @@ def test_step(self, data):
10611104 for key , val in y .items ():
10621105 if key in arg_names and key not in x :
10631106 x [key ] = val
1107+ elif output_to_label .get (key , None ) in arg_names and key not in x :
1108+ x [output_to_label [key ]] = val
1109+ if y is None :
1110+ y = {key : val for key , val in x .items () if key in label_kwargs }
1111+ if not y and not self ._using_dummy_loss :
1112+ raise ValueError ("Could not find label column(s) in input dict and no separate labels were provided!" )
1113+
1114+ if isinstance (y , dict ):
1115+ # Rename labels at this point to match output heads
1116+ y = {label_to_output .get (key , key ): val for key , val in y .items ()}
10641117
10651118 # Run forward pass.
10661119 y_pred = self (x , training = False )
10671120 if self ._using_dummy_loss :
1068- self .compiled_loss (y_pred .loss , y_pred .loss , sample_weight , regularization_losses = self .losses )
1121+ loss = self .compiled_loss (y_pred .loss , y_pred .loss , sample_weight , regularization_losses = self .losses )
10691122 else :
1070- self .compiled_loss (y , y_pred , sample_weight , regularization_losses = self .losses )
1071-
1072- # When using the dummy_loss we know metrics are not present, so we can skip a lot of this
1073- if self ._using_dummy_loss :
1074- self .compiled_metrics .update_state (y_pred .loss , y_pred .loss , sample_weight )
1123+ loss = None
1124+
1125+ # This next block matches outputs to label keys. Tensorflow's standard method for doing this
1126+ # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
1127+ if isinstance (y , dict ) and len (y ) == 1 :
1128+ if list (y .keys ())[0 ] in y_pred .keys ():
1129+ y_pred = y_pred [list (y .keys ())[0 ]]
1130+ elif list (y_pred .keys ())[0 ] == "loss" :
1131+ y_pred = y_pred [1 ]
1132+ else :
1133+ y_pred = y_pred [0 ]
1134+ _ , y = y .popitem ()
1135+ elif isinstance (y , dict ):
1136+ # If the labels are a dict, match keys from the output by name
1137+ y_pred = {key : val for key , val in y_pred .items () if key in y }
1138+ elif isinstance (y , tuple ) or isinstance (y , list ):
1139+ # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
1140+ if list (y_pred .keys ())[0 ] == "loss" :
1141+ y_pred = y_pred .to_tuple ()[1 :]
1142+ else :
1143+ y_pred = y_pred .to_tuple ()
1144+ y_pred = y_pred [: len (y )] # Remove unused fields in case those cause problems
10751145 else :
1076- self .compiled_metrics .update_state (y , y_pred , sample_weight )
1146+ # If the labels are a single tensor, match them to the first non-loss tensor in the output
1147+ if list (y_pred .keys ())[0 ] == "loss" :
1148+ y_pred = y_pred [1 ]
1149+ else :
1150+ y_pred = y_pred [0 ]
1151+
1152+ if loss is None :
1153+ loss = self .compiled_loss (y , y_pred , sample_weight , regularization_losses = self .losses )
1154+
1155+ self .compiled_metrics .update_state (y , y_pred , sample_weight )
10771156 # Collect metrics to return
10781157 return_metrics = {}
10791158 for metric in self .metrics :
@@ -1082,10 +1161,6 @@ def test_step(self, data):
10821161 return_metrics .update (result )
10831162 else :
10841163 return_metrics [metric .name ] = result
1085- # These next two lines are also not in the base method - they correct the displayed metrics
1086- # when we're using a dummy loss, to avoid a bogus "loss_loss" value being shown.
1087- if "loss" in return_metrics and "loss_loss" in return_metrics :
1088- del return_metrics ["loss_loss" ]
10891164 return return_metrics
10901165
10911166 def create_model_card (
0 commit comments