@@ -78,8 +78,8 @@ def sample_next_tokens(self, logits, num_tokens=1, min_tokens_to_keep=2):
         if min_tokens_to_keep > 1:
             # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
             sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
-        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
-        sorted_indices_to_remove[:, 0] = 0
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
         indices_to_remove = sorted_indices_to_remove.scatter(
             1,
             sorted_indices,
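Note on the hunk above: switching the shift from `[:, ...]` to `[..., ...]` lets it operate over any number of leading batch dimensions instead of assuming a 2-D tensor. For reference, a minimal, self-contained sketch of the full top-p filter this code implements (the name `top_p_mask` and the 0.9 default are illustrative, not this repo's API):

import torch
import torch.nn.functional as F

def top_p_mask(logits, top_p=0.9, min_tokens_to_keep=2):
    # Sort descending and accumulate probability mass.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    if min_tokens_to_keep > 1:
        # Never mask the highest-probability tokens.
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
    # Shift right so the first token that crosses the threshold is also kept;
    # the ellipsis makes this work for any number of leading batch dimensions.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Map the mask from sorted order back to vocabulary order.
    return sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)

logits = torch.randn(2, 16)
logits = logits.masked_fill(top_p_mask(logits), -float("Inf"))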
@@ -90,16 +90,13 @@ def sample_next_tokens(self, logits, num_tokens=1, min_tokens_to_keep=2):
         # for i in range(logits.size(0)):
         #     indices_to_remove = sorted_indices[i, sorted_indices_to_remove[i]]
         #     logits[i, indices_to_remove] = -float("Inf")
-        logits = F.log_softmax(logits, dim=-1)
+        logits = F.log_softmax(logits, dim=-1) / self.temperature
         if self.multinomial_sampling:
             next_tokens = torch.multinomial(
                 torch.exp(logits),
                 num_samples=num_tokens,
             )
             logits = logits.gather(-1, next_tokens)
-            # sort the sampled vector to make sure that the first num_beams samples are the best
-            logits, next_scores_indices = torch.sort(logits, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, next_scores_indices)
         else:
             logits, next_tokens = torch.topk(logits, num_tokens)
         return logits, next_tokens
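On the temperature move in this hunk: dividing the log-probabilities by `self.temperature` differs from the conventional `softmax(logits / T)` only by a per-row constant, so both branches are unaffected: `torch.multinomial` renormalizes its weights and `torch.topk` depends only on the ordering. A standalone check of that equivalence (all names local to the snippet):

import torch
import torch.nn.functional as F

logits, T = torch.randn(4, 10), 0.7
w = torch.exp(F.log_softmax(logits, dim=-1) / T)  # weights as in the patched code
p = F.softmax(logits / T, dim=-1)                 # conventional temperature scaling
# Identical distributions once multinomial renormalizes the weights:
assert torch.allclose(w / w.sum(-1, keepdim=True), p, atol=1e-6)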
@@ -125,7 +122,7 @@ def generate(self, inputs: Union[List[torch.Tensor], List[str]]) -> List[torch.Tensor]:
             if finished_mask.all():
                 break  # Stop if all sequences are finished
             batch_outputs = self.generation_forward(batch_inputs, decoder_inputs[:, :step])
-            logits = batch_outputs[:, -1, :] / self.temperature
+            logits = batch_outputs[:, -1, :]
             _, next_tokens = self.sample_next_tokens(logits)
             next_tokens = next_tokens.squeeze()
             not_finished = ~finished_mask
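The greedy loop now passes raw logits down and lets `sample_next_tokens` apply the temperature. A toy, self-contained version of the loop with the finished-mask early stop; `dummy_forward`, `eos_id`, and the shapes are stand-ins rather than this repo's code:

import torch

def dummy_forward(dec_inputs, vocab=16):
    # Fake decoder standing in for self.generation_forward: random logits per position.
    return torch.randn(dec_inputs.size(0), dec_inputs.size(1), vocab)

batch, max_len, eos_id = 4, 12, 0
decoder_inputs = torch.ones(batch, max_len, dtype=torch.long)  # 1 plays the BOS role
finished_mask = torch.zeros(batch, dtype=torch.bool)
for step in range(1, max_len):
    if finished_mask.all():
        break  # stop once every sequence has emitted EOS
    logits = dummy_forward(decoder_inputs[:, :step])[:, -1, :]  # raw logits, no temperature
    next_tokens = logits.argmax(dim=-1)  # stand-in for self.sample_next_tokens
    not_finished = ~finished_mask
    decoder_inputs[not_finished, step] = next_tokens[not_finished]
    finished_mask |= next_tokens == eos_id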
@@ -242,7 +239,7 @@ def generate(self, inputs: Union[List[torch.Tensor], List[str]]) -> List[torch.Tensor]:
                 [sample_best_nodes[k].tokens for sample_best_nodes in batch_best_nodes]
             ).to(self.device)
             batch_outputs = self.generation_forward(batch, decoder_input_ids)
-            logits = batch_outputs[:, -1, :] / self.temperature
+            logits = batch_outputs[:, -1, :]
             logits, next_tokens = self.sample_next_tokens(
                 logits, num_tokens=self.beam_width
             )
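In this beam-search step, each live hypothesis requests `beam_width` candidates from `sample_next_tokens`. Note that after the sort removal in the earlier hunk, the multinomial path no longer returns candidates in descending score order (the deleted comment said the sort made sure "the first num_beams samples are the best"), whereas `torch.topk` still sorts its output. A small standalone illustration, assuming 2 samples with 3 live beams each:

import torch
import torch.nn.functional as F

beam_width, vocab = 3, 16
log_probs = F.log_softmax(torch.randn(2 * beam_width, vocab), dim=-1)
# top-k path: scores arrive sorted in descending order per row
topk_scores, topk_tokens = torch.topk(log_probs, beam_width)
# sampling path: tokens are drawn without replacement, in draw order, not score order
sampled_tokens = torch.multinomial(log_probs.exp(), num_samples=beam_width)
sampled_scores = log_probs.gather(-1, sampled_tokens)  # no longer re-sorted after this commit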