@@ -506,22 +506,23 @@ def _sample(
506506 # sampling_tensors)
507507
508508
509- def _get_ranks (x : torch .Tensor , indices : List [ int ] ) -> torch .Tensor :
509+ def _get_ranks (x : torch .Tensor , indices : torch . Tensor ) -> torch .Tensor :
510510 """
511511 This function calculates the ranks of the chosen tokens in a logprob tensor.
512512
513513 Args:
514514 x (torch.Tensor): 2D logprob tensor of shape (N, M)
515515 where N is the no. of tokens and M is the vocab dim.
516- indices (List[int] ): List of chosen token indices.
516+ indices (torch.Tensor ): Tensor of chosen token indices, one per row of x.
517517
518518 Returns:
519519 torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
520520 Each element in the returned tensor represents the rank
521521 of the chosen token in the input logprob tensor.
522522 """
523- vals = x [range (len (x )), indices ]
524- return (x > vals [:, None ]).long ().sum (1 ) + 1
523+ vals = x [torch .arange (0 , len (x ), device = x .device , dtype = indices .dtype ),
524+ indices ]
525+ return (x > vals [:, None ]).long ().sum (1 ).add_ (1 )
525526
526527
527528def _get_logprobs (
@@ -561,12 +562,21 @@ def _get_logprobs(
561562 sample_idx += num_parent_seqs
562563 assert sample_idx == logprobs .size (0 )
563564
565+ batched_logprobs_query_seq_indices_gpu = torch .tensor (
566+ batched_logprobs_query_seq_indices , device = logprobs .device )
567+ batched_logprobs_query_token_indices_gpu = torch .tensor (
568+ batched_logprobs_query_token_indices , device = logprobs .device )
569+
564570 # Batched query for logprobs of selected token
565571 batched_logprobs_query_result = logprobs [[
566- batched_logprobs_query_seq_indices ,
567- batched_logprobs_query_token_indices
572+ batched_logprobs_query_seq_indices_gpu ,
573+ batched_logprobs_query_token_indices_gpu
568574 ]]
569575
576+ batched_ranks_query_result = _get_ranks (
577+ logprobs [batched_logprobs_query_seq_indices_gpu ],
578+ batched_logprobs_query_token_indices_gpu )
579+
570580 # Batched query for logprobs of topk tokens
571581 if largest_num_logprobs > 0 :
572582 top_logprobs , top_token_ids = torch .topk (logprobs ,
@@ -578,10 +588,7 @@ def _get_logprobs(
578588 top_logprobs , top_token_ids = None , None
579589
580590 batched_logprobs_query_result = batched_logprobs_query_result .cpu ()
581-
582- batched_ranks_query_result = _get_ranks (
583- logprobs [batched_logprobs_query_seq_indices ],
584- batched_logprobs_query_token_indices )
591+ batched_ranks_query_result = batched_ranks_query_result .cpu ()
585592
586593 # Gather results
587594 result_prompt_logprobs : List [Optional [PromptLogprobs ]] = []
0 commit comments