Skip to content

Commit ed1924e

Browse files
authored
Generate: validate model_kwargs (and catch typos in generate arguments) (#18261)
* validate generate model_kwargs * generate tests -- not all models have an attn mask
1 parent 2156619 commit ed1924e

File tree

2 files changed

+91
-48
lines changed

2 files changed

+91
-48
lines changed

src/transformers/generation_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,29 @@ def compute_transition_beam_scores(
841841

842842
return transition_scores
843843

844+
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
845+
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
846+
# Excludes arguments that are handled before calling any model function
847+
if self.config.is_encoder_decoder:
848+
for key in ["decoder_input_ids"]:
849+
model_kwargs.pop(key, None)
850+
851+
unused_model_args = []
852+
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
853+
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
854+
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
855+
if "kwargs" in model_args:
856+
model_args |= set(inspect.signature(self.forward).parameters)
857+
for key, value in model_kwargs.items():
858+
if value is not None and key not in model_args:
859+
unused_model_args.append(key)
860+
861+
if unused_model_args:
862+
raise ValueError(
863+
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
864+
" generate arguments will also show up in this list)"
865+
)
866+
844867
@torch.no_grad()
845868
def generate(
846869
self,
@@ -1120,6 +1143,9 @@ def generate(
11201143
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
11211144
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
11221145
```"""
1146+
# 0. Validate model kwargs
1147+
self._validate_model_kwargs(model_kwargs.copy())
1148+
11231149
# 1. Set generation parameters if not already defined
11241150
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
11251151
num_beams = num_beams if num_beams is not None else self.config.num_beams

0 commit comments

Comments
 (0)