@@ -652,7 +652,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         tau = distribution.entropy() * self.top_h
 
         # grow the kept set until the stopping rule triggers
-        cumulative_entropy = -distribution.probs[torch.tensor([0], device=top_probs.device)] * distribution.log_prob(torch.tensor([0], device=top_probs.device))  # -top_probs[0] * torch.log2(top_probs[0])
+        cumulative_entropy = -distribution.probs[
+            torch.tensor([0], device=top_probs.device)
+        ] * distribution.log_prob(
+            torch.tensor([0], device=top_probs.device)
+        )  # -top_probs[0] * torch.log2(top_probs[0])
         chosen = []
         ind = 0
         for idx, p in zip(top_idx, top_probs):
@@ -661,7 +665,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
             if ind == len(top_probs):
                 break
             # update running sums for current prefix
-            cumulative_entropy = cumulative_entropy - distribution.probs[torch.tensor([ind], device=top_probs.device)] * distribution.log_prob(torch.tensor([ind], device=top_probs.device))
+            cumulative_entropy = cumulative_entropy - distribution.probs[
+                torch.tensor([ind], device=top_probs.device)
+            ] * distribution.log_prob(torch.tensor([ind], device=top_probs.device))
 
             # entropy difference term
             if cumulative_entropy > tau:
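
Both hunks are pure line reflows; the stopping rule itself is unchanged: tau is set to the fraction top_h of the full distribution's entropy, and tokens are kept in descending-probability order while the running sum of their -p * log(p) contributions stays below tau. Below is a minimal standalone sketch of that rule, assuming the probabilities arrive as a plain tensor rather than a torch.distributions.Categorical; the function name top_h_keep_set and its signature are hypothetical, and the sketch assumes the branch after the threshold check (not shown in the diff) simply breaks out of the loop.

```python
import torch

def top_h_keep_set(probs: torch.Tensor, top_h: float) -> torch.Tensor:
    """Hypothetical helper: indices of the smallest descending-probability
    prefix whose summed entropy contributions -p * log(p) exceed
    tau = top_h * H(probs)."""
    # entropy of the full distribution and the resulting budget tau
    entropy = -(probs * probs.log()).sum()
    tau = top_h * entropy

    # highest-probability tokens first
    top_probs, top_idx = probs.sort(descending=True)

    # accumulate -p * log(p) over the sorted prefix; stop once it crosses tau
    cumulative_entropy = torch.zeros((), dtype=probs.dtype)
    chosen = []
    for idx, p in zip(top_idx, top_probs):
        chosen.append(idx)
        cumulative_entropy = cumulative_entropy - p * p.log()
        if cumulative_entropy > tau:
            break
    return torch.stack(chosen)

# usage on a toy distribution
probs = torch.softmax(torch.tensor([3.0, 2.0, 1.0, 0.5, 0.1]), dim=-1)
print(top_h_keep_set(probs, top_h=0.8))  # indices of the kept tokens
```

Note that as long as tau and the running sum use the same logarithm base, the base cancels out of the comparison, so the log2 mentioned in the trailing comment and the natural log used by distribution.log_prob yield the same kept set.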