
Commit 123c201

patrickvonplaten and elusenji authored and committed
[Test OPT] Add batch generation test opt (huggingface#17359)
* up
* up
1 parent da211f8 commit 123c201

File tree

1 file changed: +41 −0 lines changed


tests/models/opt/test_modeling_opt.py

Lines changed: 41 additions & 0 deletions
@@ -366,6 +366,47 @@ def test_generation_pre_attn_layer_norm(self):
 
         self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
 
+    def test_batch_generation(self):
+        model_id = "facebook/opt-350m"
+
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+        model = OPTForCausalLM.from_pretrained(model_id)
+        model.to(torch_device)
+
+        tokenizer.padding_side = "left"
+
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
+
+        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+        input_ids = inputs["input_ids"].to(torch_device)
+
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=inputs["attention_mask"].to(torch_device),
+        )
+
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+        expected_output_sentence = [
+            "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+            "Today, I was in the middle of a conversation with a friend about the",
+        ]
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
+
     def test_generation_post_attn_layer_norm(self):
         model_id = "facebook/opt-350m"
 
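The new test checks that left-padded batched generation produces the same text as generating each prompt on its own. Below is a minimal, self-contained sketch of that pattern (not part of the commit): it assumes transformers and torch are installed, picks a device directly instead of using the test suite's torch_device helper, and downloads the facebook/opt-350m checkpoint from the Hub on first use.

    # Standalone sketch (not part of the commit) of the pattern the test exercises.
    # Assumes `transformers` and `torch` are installed; the facebook/opt-350m
    # weights are downloaded on first use.
    import torch
    from transformers import GPT2Tokenizer, OPTForCausalLM

    model_id = "facebook/opt-350m"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    tokenizer.padding_side = "left"  # decoder-only models are padded on the left
    model = OPTForCausalLM.from_pretrained(model_id).to(device)

    # Prompts of different lengths, so the shorter one needs padding.
    sentences = ["Hello, my dog is a little", "Today, I"]

    # Batched generation: the attention mask tells the model to ignore pad tokens.
    inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(device)
    batch_outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    batch_texts = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)

    # Unbatched generation of the first prompt for comparison; the test asserts
    # that the batched and unbatched outputs match.
    single_ids = tokenizer(sentences[0], return_tensors="pt").input_ids.to(device)
    single_text = tokenizer.decode(model.generate(input_ids=single_ids)[0], skip_special_tokens=True)

    print(batch_texts)
    print(single_text)

Setting tokenizer.padding_side = "left" is the key detail: OPT is a decoder-only model, so right padding would place pad tokens between the prompt and the newly generated tokens, which is why the test pads on the left and passes the attention mask to generate.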
