
Commit a53e587

Reformatting logits_process.py
1 parent 44006fa commit a53e587

File tree

1 file changed: +8 -2 lines changed


src/transformers/generation/logits_process.py

Lines changed: 8 additions & 2 deletions
@@ -652,7 +652,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         tau = distribution.entropy() * self.top_h

         # grow the kept set until the stopping rule triggers
-        cumulative_entropy = - distribution.probs[torch.tensor([0], device=top_probs.device)] * distribution.log_prob(torch.tensor([0], device=top_probs.device)) # -top_probs[0] * torch.log2(top_probs[0])
+        cumulative_entropy = -distribution.probs[
+            torch.tensor([0], device=top_probs.device)
+        ] * distribution.log_prob(
+            torch.tensor([0], device=top_probs.device)
+        )  # -top_probs[0] * torch.log2(top_probs[0])
         chosen = []
         ind = 0
         for idx, p in zip(top_idx, top_probs):
@@ -661,7 +665,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
             if ind == len(top_probs):
                 break
             # update running sums for current prefix
-            cumulative_entropy = cumulative_entropy - distribution.probs[torch.tensor([ind], device=top_probs.device)] * distribution.log_prob(torch.tensor([ind], device=top_probs.device))
+            cumulative_entropy = cumulative_entropy - distribution.probs[
+                torch.tensor([ind], device=top_probs.device)
+            ] * distribution.log_prob(torch.tensor([ind], device=top_probs.device))

             # entropy difference term
             if cumulative_entropy > tau:
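For context, the reformatted lines are the running-entropy update of a top-H style stopping rule: tokens are kept in descending-probability order, each step adds the token's -p * log p term, and the loop stops once the accumulated entropy exceeds tau = H(p) * top_h. Below is a minimal, self-contained sketch of that rule; the helper name top_h_keep_set and the standalone setup are assumptions for illustration, and only top_h, top_probs, top_idx, and the stopping test come from the diff. The boundary handling (whether the crossing token itself is kept) may differ slightly from the file's loop.

import torch

# Hedged sketch of the stopping rule touched by the diff above; the
# function name and standalone setup are assumptions, not the file's API.
def top_h_keep_set(probs: torch.Tensor, top_h: float) -> torch.Tensor:
    """Keep top tokens until the prefix entropy exceeds tau = H(p) * top_h."""
    top_probs, top_idx = probs.sort(descending=True)
    log_p = top_probs.clamp_min(1e-12).log()   # natural log, as log_prob uses
    entropy = -(top_probs * log_p).sum()       # H(p) of the full distribution
    tau = entropy * top_h
    # running sum of -p * log p over the sorted prefix
    cumulative_entropy = (-top_probs * log_p).cumsum(dim=-1)
    # count prefix positions that have not yet crossed tau; always keep top-1
    keep = int((cumulative_entropy <= tau).sum().clamp(min=1))
    return top_idx[:keep]

probs = torch.tensor([0.5, 0.2, 0.15, 0.1, 0.05])
print(top_h_keep_set(probs, top_h=0.8))  # indices of the kept tokens

The cumsum form is just a vectorized view of the same running sum; the file's per-element indexing with torch.tensor([ind], device=top_probs.device) keeps every lookup on the same device as the scores, which is what the commit's line breaks make easier to read.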
