@@ -841,6 +841,29 @@ def compute_transition_beam_scores(
 
         return transition_scores
 
+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
+
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.forward).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
     @torch.no_grad()
     def generate(
         self,
@@ -1120,6 +1143,9 @@ def generate(
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
+        # 0. Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # 1. Set generation parameters if not already defined
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         num_beams = num_beams if num_beams is not None else self.config.num_beams
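
For context, here is a minimal sketch of what this validation means for callers of `generate`. The checkpoint name and the deliberately misspelled `max_lenght` keyword are illustrative; any generation-capable model would trigger the same `ValueError`.

```python
# Minimal sketch, assuming a standard seq2seq checkpoint ("t5-small" is illustrative).
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: How are you?", return_tensors="pt")

# Before this change, an unknown keyword such as the misspelled `max_lenght`
# was silently swallowed by `model_kwargs`; now `_validate_model_kwargs` raises.
try:
    model.generate(inputs.input_ids, max_lenght=20)  # deliberate typo of `max_length`
except ValueError as err:
    print(err)
    # The following `model_kwargs` are not used by the model: ['max_lenght'] (note:
    # typos in the generate arguments will also show up in this list)
```

Because `prepare_inputs_for_generation` commonly accepts `**kwargs`, the check also unions in the parameters of `forward`, so legitimate optional inputs like `attention_mask` pass through while genuinely unknown keys are flagged.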