Skip to content

Commit 6b6251b

Browse files
committed
minor fixes
1 parent 4090090 commit 6b6251b

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

generate_sequences/generate.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -78,8 +78,8 @@ def sample_next_tokens(self, logits, num_tokens=1, min_tokens_to_keep=2):
7878
if min_tokens_to_keep > 1:
7979
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
8080
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
81-
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
82-
sorted_indices_to_remove[:, 0] = 0
81+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
82+
sorted_indices_to_remove[..., 0] = 0
8383
indices_to_remove = sorted_indices_to_remove.scatter(
8484
1,
8585
sorted_indices,
@@ -90,16 +90,13 @@ def sample_next_tokens(self, logits, num_tokens=1, min_tokens_to_keep=2):
9090
# for i in range(logits.size(0)):
9191
# indices_to_remove = sorted_indices[i, sorted_indices_to_remove[i]]
9292
# logits[i, indices_to_remove] = -float("Inf")
93-
logits = F.log_softmax(logits, dim=-1)
93+
logits = F.log_softmax(logits, dim=-1) / self.temperature
9494
if self.multinomial_sampling:
9595
next_tokens = torch.multinomial(
9696
torch.exp(logits),
9797
num_samples=num_tokens,
9898
)
9999
logits = logits.gather(-1, next_tokens)
100-
# sort the sampled vector to make sure that the first num_beams samples are the best
101-
logits, next_scores_indices = torch.sort(logits, descending=True, dim=1)
102-
next_tokens = torch.gather(next_tokens, -1, next_scores_indices)
103100
else:
104101
logits, next_tokens = torch.topk(logits, num_tokens)
105102
return logits, next_tokens
@@ -125,7 +122,7 @@ def generate(self, inputs: Union[List[torch.Tensor], List[str]]) -> List[torch.T
125122
if finished_mask.all():
126123
break # Stop if all sequences are finished
127124
batch_outputs = self.generation_forward(batch_inputs, decoder_inputs[:, :step])
128-
logits = batch_outputs[:, -1, :] / self.temperature
125+
logits = batch_outputs[:, -1, :]
129126
_, next_tokens = self.sample_next_tokens(logits)
130127
next_tokens = next_tokens.squeeze()
131128
not_finished = ~finished_mask
@@ -242,7 +239,7 @@ def generate(self, inputs: Union[List[torch.Tensor], List[str]]) -> List[torch.T
242239
[sample_best_nodes[k].tokens for sample_best_nodes in batch_best_nodes]
243240
).to(self.device)
244241
batch_outputs = self.generation_forward(batch, decoder_input_ids)
245-
logits = batch_outputs[:, -1, :] / self.temperature
242+
logits = batch_outputs[:, -1, :]
246243
logits, next_tokens = self.sample_next_tokens(
247244
logits, num_tokens=self.beam_width
248245
)

0 commit comments

Comments (0)