@@ -366,6 +366,47 @@ def test_generation_pre_attn_layer_norm(self):
 
         self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
 
+    def test_batch_generation(self):
+        model_id = "facebook/opt-350m"
+
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+        model = OPTForCausalLM.from_pretrained(model_id)
+        model.to(torch_device)
+
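+        # decoder-only models like OPT must be left-padded: generation continues from
+        # the last token, so padding has to sit before the prompt, not after it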
+        tokenizer.padding_side = "left"
+
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
+
+        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+        input_ids = inputs["input_ids"].to(torch_device)
+
+        outputs = model.generate(
+            input_ids=input_ids,
+            attention_mask=inputs["attention_mask"].to(torch_device),
+        )
+
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
+
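+        # the shorter sentence was left-padded in the batch; shrink max_length by that
+        # pad count so the standalone run generates the same number of new tokens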
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+        expected_output_sentence = [
+            "Hello, my dog is a little bit of a dork.\nI'm a little bit",
+            "Today, I was in the middle of a conversation with a friend about the",
+        ]
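+        # padding must not change the text: the batched output has to match the two solo runs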
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertListEqual(batch_out_sentence, [non_padded_sentence, padded_sentence])
+
     def test_generation_post_attn_layer_norm(self):
         model_id = "facebook/opt-350m"
 