
Commit 9196f48

Generate: validate model_kwargs on TF (and catch typos in generate arguments) (#18651)
1 parent c5be7ca commit 9196f48

File tree: 4 files changed (+214, −139)
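In practice, the change means a mistyped `generate` argument on a TF model now fails loudly instead of being silently ignored. A minimal sketch of the new behavior, mirroring the test added in this commit (checkpoint name taken from the tests):

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
input_ids = tokenizer("Hello world", return_tensors="tf").input_ids

# `do_samples` is a typo for `do_sample`: after this commit, generate raises a
# ValueError listing the unused kwargs instead of silently ignoring them
model.generate(input_ids, do_samples=True)
```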

src/transformers/generation_tf_utils.py (+29, −1)
```diff
@@ -579,6 +579,7 @@ def generate(
         do_sample = do_sample if do_sample is not None else self.config.do_sample
 
         if do_sample is False or num_beams == 1:
+            seed = model_kwargs.pop("seed", None)
             return self._generate(
                 input_ids=input_ids,
                 max_length=max_length,
@@ -601,13 +602,14 @@ def generate(
                 attention_mask=attention_mask,
                 decoder_start_token_id=decoder_start_token_id,
                 use_cache=use_cache,
-                seed=model_kwargs.pop("seed", None),
+                seed=seed,
                 output_scores=output_scores,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict_in_generate=return_dict_in_generate,
                 forced_bos_token_id=forced_bos_token_id,
                 forced_eos_token_id=forced_eos_token_id,
+                **model_kwargs,
             )
 
         # We cannot generate if the model does not have a LM head
```
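Note the refactor in these two hunks: `seed` is popped from `model_kwargs` up front, so the leftover kwargs can be forwarded to `_generate` through `**model_kwargs` and reach the new validation, without `seed` being passed twice or flagged as unused.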
```diff
@@ -1288,6 +1290,29 @@ def adjust_logits_during_generation(
         else:
             return logits
 
+    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+        # Excludes arguments that are handled before calling any model function
+        if self.config.is_encoder_decoder:
+            for key in ["decoder_input_ids"]:
+                model_kwargs.pop(key, None)
+
+        unused_model_args = []
+        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
+        # `kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
+        # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
+        if "kwargs" in model_args:
+            model_args |= set(inspect.signature(self.call).parameters)
+        for key, value in model_kwargs.items():
+            if value is not None and key not in model_args:
+                unused_model_args.append(key)
+
+        if unused_model_args:
+            raise ValueError(
+                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
+                " generate arguments will also show up in this list)"
+            )
+
     def _generate(
         self,
         input_ids=None,
```
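The new `_validate_model_kwargs` leans on `inspect.signature` to collect every argument name the model can accept, widening that set with the parameters of `self.call` whenever `prepare_inputs_for_generation` takes a catch-all `kwargs`. A minimal standalone sketch of the same signature-inspection pattern (the `unused_kwargs` helper and `forward` function below are hypothetical, not part of the commit):

```python
import inspect
from typing import Any, Dict, List


def unused_kwargs(func, kwargs: Dict[str, Any]) -> List[str]:
    """Return the keys in `kwargs` that `func` cannot accept by name."""
    accepted = set(inspect.signature(func).parameters)
    return [key for key, value in kwargs.items() if value is not None and key not in accepted]


def forward(input_ids=None, attention_mask=None):
    pass


# `do_samples` is a typo for `do_sample`, so it is reported as unused
print(unused_kwargs(forward, {"attention_mask": [[1, 1]], "do_samples": True}))
# -> ['do_samples']
```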
```diff
@@ -1483,6 +1508,9 @@ def _generate(
         # generate sequences without allowing bad_words to be generated
         outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
         ```"""
+        # 0. Validate model kwargs
+        self._validate_model_kwargs(model_kwargs.copy())
+
         # 1. Set generation parameters if not already defined
         length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
```
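Note the `.copy()` in the call above: `_validate_model_kwargs` pops entries such as `decoder_input_ids` from the dict it is handed, so validating a copy leaves the caller's `model_kwargs` untouched for the rest of `_generate`.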
tests/generation/test_generation_tf_utils.py (+183, −0)
```diff
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+
+from transformers import is_tf_available
+from transformers.testing_utils import require_tf, slow
+
+
+if is_tf_available():
+    import tensorflow as tf
+
+    from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering
+
+
+@require_tf
+class UtilsFunctionsTest(unittest.TestCase):
+
+    # tests whether the top_k_top_p_filtering function behaves as expected
+    def test_top_k_top_p_filtering(self):
+        logits = tf.convert_to_tensor(
+            [
+                [
+                    8.2220991,  # 3rd highest value; idx. 0
+                    -0.5620044,
+                    5.23229752,
+                    4.0386393,
+                    -6.8798378,
+                    -0.54785802,
+                    -3.2012153,
+                    2.92777176,
+                    1.88171953,
+                    7.35341276,  # 5th highest value; idx. 9
+                    8.43207833,  # 2nd highest value; idx. 10
+                    -9.85711836,
+                    -5.96209236,
+                    -1.13039161,
+                    -7.1115294,
+                    -0.8369633,
+                    -5.3186408,
+                    7.06427407,
+                    0.81369344,
+                    -0.82023817,
+                    -5.9179796,
+                    0.58813443,
+                    -6.99778438,
+                    4.71551189,
+                    -0.18771637,
+                    7.44020759,  # 4th highest value; idx. 25
+                    9.38450987,  # 1st highest value; idx. 26
+                    2.12662941,
+                    -9.32562038,
+                    2.35652522,
+                ],  # cumulative prob of 5 highest values <= 0.6
+                [
+                    0.58425518,
+                    4.53139238,
+                    -5.57510464,
+                    -6.28030699,
+                    -7.19529503,
+                    -4.02122551,
+                    1.39337037,
+                    -6.06707057,
+                    1.59480517,
+                    -9.643119,
+                    0.03907799,
+                    0.67231762,
+                    -8.88206726,
+                    6.27115922,  # 4th highest value; idx. 13
+                    2.28520723,
+                    4.82767506,
+                    4.30421368,
+                    8.8275313,  # 2nd highest value; idx. 17
+                    5.44029958,  # 5th highest value; idx. 18
+                    -4.4735794,
+                    7.38579536,  # 3rd highest value; idx. 20
+                    -2.91051663,
+                    2.61946077,
+                    -2.5674762,
+                    -9.48959302,
+                    -4.02922645,
+                    -1.35416918,
+                    9.67702323,  # 1st highest value; idx. 27
+                    -5.89478553,
+                    1.85370467,
+                ],  # cumulative prob of 5 highest values <= 0.6
+            ],
+            dtype=tf.float32,
+        )
+
+        non_inf_expected_idx = tf.convert_to_tensor(
+            [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]],
+            dtype=tf.int32,
+        )  # expected non filtered idx as noted above
+
+        non_inf_expected_output = tf.convert_to_tensor(
+            [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023],
+            dtype=tf.float32,
+        )  # expected non filtered values as noted above
+
+        output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4)
+
+        non_inf_output = output[output != -float("inf")]
+        non_inf_idx = tf.cast(
+            tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))),
+            dtype=tf.int32,
+        )
+
+        tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12)
+        tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx)
+
+
+@require_tf
+class TFGenerationIntegrationTests(unittest.TestCase):
+    @slow
+    def test_generate_tf_function_export(self):
+        test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        max_length = 2
+
+        class DummyModel(tf.Module):
+            def __init__(self, model):
+                super(DummyModel, self).__init__()
+                self.model = model
+
+            @tf.function(
+                input_signature=(
+                    tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
+                    tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
+                ),
+                jit_compile=True,
+            )
+            def serving(self, input_ids, attention_mask):
+                outputs = self.model.generate(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=max_length,
+                    return_dict_in_generate=True,
+                )
+                return {"sequences": outputs["sequences"]}
+
+        dummy_input_ids = [[2, 0], [102, 103]]
+        dummy_attention_masks = [[1, 0], [1, 1]]
+        dummy_model = DummyModel(model=test_model)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
+            serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
+            for batch_size in range(1, len(dummy_input_ids) + 1):
+                inputs = {
+                    "input_ids": tf.constant(dummy_input_ids[:batch_size]),
+                    "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
+                }
+                tf_func_outputs = serving_func(**inputs)["sequences"]
+                tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
+                tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
+
+    def test_validate_generation_inputs(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+        model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+        encoder_input_str = "Hello world"
+        input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids
+
+        # typos are quickly detected (the correct argument is `do_sample`)
+        with self.assertRaisesRegex(ValueError, "do_samples"):
+            model.generate(input_ids, do_samples=True)
+
+        # arbitrary arguments that will not be used anywhere are also not accepted
+        with self.assertRaisesRegex(ValueError, "foo"):
+            fake_model_kwargs = {"foo": "bar"}
+            model.generate(input_ids, **fake_model_kwargs)
```
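Aside from validation, the new `test_generate_tf_function_export` pins down that `generate` can be wrapped in a `tf.function` with a fixed `input_signature`, exported with `tf.saved_model.save`, and reloaded to produce the same sequences as a direct `generate` call.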

tests/generation/test_generation_utils.py (+2, −2)
```diff
@@ -2704,8 +2704,8 @@ def test_constrained_beam_search_mixin_type_checks(self):
             model.generate(input_ids, force_words_ids=[[[-1]]])
 
     def test_validate_generation_inputs(self):
-        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
-        model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+        model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
 
         encoder_input_str = "Hello world"
         input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
```
