@@ -212,6 +212,7 @@ def process(
         next_indices: torch.LongTensor,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor]:
         cur_len = input_ids.shape[-1]
         batch_size = len(self._beam_hyps)
@@ -256,9 +257,16 @@ def process(
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                     if is_beam_token_worse_than_top_num_beams:
                         continue
+                    if beam_indices is not None:
+                        beam_index = beam_indices[batch_beam_idx]
+                        beam_index = beam_index + (next_index,)
+                    else:
+                        beam_index = None
+
                     beam_hyp.add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
+                        beam_indices=beam_index,
                     )
                 else:
                     # add next predicted token since it is not eos_token
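
A note on the hunk above (not part of the diff): `process` appears to treat each entry of `beam_indices` as a tuple of parent-beam ids that grows by one per decoding step, which is why a finished hypothesis is stored with `beam_index + (next_index,)`. A minimal, self-contained sketch of that bookkeeping, with illustrative names only:

```python
# Minimal sketch of the per-hypothesis provenance tuples that `process`
# extends with `beam_index + (next_index,)`. All names here are
# illustrative; this is not the transformers API.
num_beams = 3

# At step 0 every live beam has an empty provenance tuple.
beam_indices = tuple(() for _ in range(num_beams))

# Each step, new beam i descends from parent beam parents[i], so its
# tuple is the parent's tuple plus the parent's index.
for parents in [(0, 0, 1), (1, 0, 2)]:
    beam_indices = tuple(
        beam_indices[parents[i]] + (parents[i],) for i in range(num_beams)
    )

print(beam_indices)  # ((0, 1), (0, 0), (1, 2))
```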
@@ -299,6 +307,7 @@ def finalize(
         max_length: int,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.LongTensor]:
         batch_size = len(self._beam_hyps)

@@ -313,11 +322,13 @@ def finalize(
                 batch_beam_idx = batch_idx * self.num_beams + beam_id
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
-                beam_hyp.add(final_tokens, final_score)
+                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
         best = []
+        best_indices = []
         best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

         # retrieve best hypotheses
@@ -327,30 +338,50 @@ def finalize(
                 best_hyp_tuple = sorted_hyps.pop()
                 best_score = best_hyp_tuple[0]
                 best_hyp = best_hyp_tuple[1]
+                best_index = best_hyp_tuple[2]
                 sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

-                # append to lists
+                # append hyp to lists
                 best.append(best_hyp)
+
+                # append indices to list
+                best_indices.append(best_index)
+
                 best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

         # prepare for adding eos
         sent_lengths_max = sent_lengths.max().item() + 1
         sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+
+        if len(best_indices) > 0 and best_indices[0] is not None:
+            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        else:
+            indices = None
+
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`pad_token_id` has to be defined"
             decoded.fill_(pad_token_id)
+
+        if indices is not None:
+            indices.fill_(-1)
+
         # fill with hypotheses and eos_token_id if the latter fits in
-        for i, hypo in enumerate(best):
+        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
             decoded[i, : sent_lengths[i]] = hypo
+
+            if indices is not None:
+                indices[i, : len(best_idx)] = torch.tensor(best_idx)
+
             if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id

         return UserDict(
             {
                 "sequences": decoded,
                 "sequence_scores": best_scores,
+                "beam_indices": indices,
             }
         )

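For context on the `-1` fill above: hypotheses shorter than `sent_max_len` leave trailing positions of `indices` unwritten, so the tensor is pre-filled with `-1` as a padding sentinel. One plausible way a caller might read the returned `beam_indices` back, assuming the `UserDict` layout shown in the diff (the tensor values below are invented for illustration):

```python
# Hedged sketch of consuming finalize()'s new "beam_indices" entry.
# The -1 padding convention comes from `indices.fill_(-1)` above; the
# tensor contents are made up for this example.
import torch

beam_indices = torch.tensor([[0, 2, 1, -1, -1]])  # one hypothesis, padded

valid = beam_indices[0] != -1           # mask out the -1 padding
print(beam_indices[0][valid].tolist())  # [0, 2, 1] -> parent beam per step
```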
@@ -789,6 +820,7 @@ def finalize(

         # prepare for adding eos
         sent_lengths_max = sent_lengths.max().item() + 1
+
         sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
@@ -801,6 +833,7 @@ def finalize(
             decoded[i, : sent_lengths[i]] = hypo
             if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id
+
         return UserDict(
             {
                 "sequences": decoded,
@@ -826,15 +859,15 @@ def __len__(self):
         """
         return len(self.beams)

-    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
+    def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
         """
         Add a new hypothesis to the list.
         """
         score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
         if len(self) < self.num_beams or score > self.worst_score:
-            self.beams.append((score, hyp))
+            self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:
-                sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
                 del self.beams[sorted_next_scores[0][1]]
                 self.worst_score = sorted_next_scores[1][0]
             else:
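
A detail worth noting in `add`: the eviction sort stays safe because it sorts freshly built `(score, idx)` pairs rather than the stored 3-tuples, so the tensor and the (possibly `None`) indices tuple never enter the comparison. A stripped-down stand-in, not the transformers class:

```python
# Stand-in for the 3-tuple bookkeeping in BeamHypotheses.add; simplified
# for illustration, not the transformers implementation.
import torch

beams = []      # (score, hyp, beam_indices) 3-tuples
num_beams = 2

def add(hyp, score, beam_indices=None):
    beams.append((score, hyp, beam_indices))
    if len(beams) > num_beams:
        # Sort plain (score, position) pairs: unpacking (s, _, _) keeps
        # tensors and index tuples out of the comparison, as in the diff.
        sorted_next_scores = sorted((s, idx) for idx, (s, _, _) in enumerate(beams))
        del beams[sorted_next_scores[0][1]]

add(torch.tensor([5, 7]), -0.3, beam_indices=(0, 1))
add(torch.tensor([5, 8]), -0.9, beam_indices=(0, 0))
add(torch.tensor([5, 9]), -0.1, beam_indices=(1, 1))  # evicts the -0.9 entry
print([(s, bi) for s, _, bi in beams])  # [(-0.3, (0, 1)), (-0.1, (1, 1))]
```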