vllm/model_executor/layers/sampler.py (27 changes: 17 additions & 10 deletions)
@@ -506,22 +506,23 @@ def _sample(
     #     sampling_tensors)
 
 
-def _get_ranks(x: torch.Tensor, indices: List[int]) -> torch.Tensor:
+def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob tensor.
 
     Args:
         x (torch.Tensor): 2D logprob tensor of shape (N, M)
                         where N is the no. of tokens and M is the vocab dim.
-        indices (List[int]): List of chosen token indices.
+        indices (torch.Tensor): 1D tensor of chosen token indices.
 
     Returns:
         torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                     Each element in the returned tensor represents the rank
                     of the chosen token in the input logprob tensor.
     """
-    vals = x[range(len(x)), indices]
-    return (x > vals[:, None]).long().sum(1) + 1
+    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
+             indices]
+    return (x > vals[:, None]).long().sum(1).add_(1)
 
 
 def _get_logprobs(
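To make the new rank computation concrete, here is a self-contained sketch with toy values (the example tensors and printed result are illustrative, not part of the PR):

```python
import torch

def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # vals[i] = x[i, indices[i]]: each row's chosen-token logprob.
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    # Rank = 1 + number of vocab entries with a strictly higher logprob,
    # so the argmax token gets rank 1.
    return (x > vals[:, None]).long().sum(1).add_(1)

logprobs = torch.log_softmax(
    torch.tensor([[2.0, 1.0, 0.5],
                  [0.1, 0.2, 3.0]]), dim=-1)
chosen = torch.tensor([1, 2])
print(_get_ranks(logprobs, chosen))  # tensor([2, 1])
```

The switch from `range(len(x))` to `torch.arange(..., device=x.device)` keeps the row indices on the same device as `x`, avoiding a host-built index each call when `x` lives on the GPU.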
@@ -561,12 +562,21 @@ def _get_logprobs(
         sample_idx += num_parent_seqs
     assert sample_idx == logprobs.size(0)
 
+    batched_logprobs_query_seq_indices_gpu = torch.tensor(
+        batched_logprobs_query_seq_indices, device=logprobs.device)
+    batched_logprobs_query_token_indices_gpu = torch.tensor(
+        batched_logprobs_query_token_indices, device=logprobs.device)
+
     # Batched query for logprobs of selected token
     batched_logprobs_query_result = logprobs[[
-        batched_logprobs_query_seq_indices,
-        batched_logprobs_query_token_indices
+        batched_logprobs_query_seq_indices_gpu,
+        batched_logprobs_query_token_indices_gpu
     ]]
 
+    batched_ranks_query_result = _get_ranks(
+        logprobs[batched_logprobs_query_seq_indices_gpu],
+        batched_logprobs_query_token_indices_gpu)
+
     # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
         top_logprobs, top_token_ids = torch.topk(logprobs,
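The `_gpu` tensors exist so that both the selected-token gather and the new `_get_ranks` call index `logprobs` with device-resident tensors rather than Python lists. A minimal sketch of the pattern, with made-up shapes and names:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logprobs = torch.randn(4, 8, device=device)  # (num_tokens, vocab_size)

# Host-side index lists, as accumulated while walking sequence groups.
seq_indices = [0, 1, 2, 3]
token_indices = [5, 0, 7, 2]

# Upload each list to the device once and reuse the tensors for every
# query, instead of letting each advanced-indexing op convert the lists.
seq_idx_gpu = torch.tensor(seq_indices, device=logprobs.device)
tok_idx_gpu = torch.tensor(token_indices, device=logprobs.device)

selected = logprobs[seq_idx_gpu, tok_idx_gpu]  # shape (4,), stays on device
```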
@@ -578,10 +588,7 @@ def _get_logprobs(
         top_logprobs, top_token_ids = None, None
 
     batched_logprobs_query_result = batched_logprobs_query_result.cpu()
-
-    batched_ranks_query_result = _get_ranks(
-        logprobs[batched_logprobs_query_seq_indices],
-        batched_logprobs_query_token_indices)
+    batched_ranks_query_result = batched_ranks_query_result.cpu()
 
     # Gather results
     result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
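With the ranks now computed on the device earlier in the function, this last hunk reduces to one explicit device-to-host copy placed next to the existing one. The likely motivation is to replace per-element implicit syncs in the gather loop that follows with a single bulk transfer; a toy illustration of that cost difference, under assumed sizes:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
ranks = torch.randint(1, 32000, (512,), device=device)

# Anti-pattern: every element access on a device tensor from Python
# triggers its own synchronization and tiny copy.
slow = [int(ranks[i]) for i in range(len(ranks))]

# Pattern used here: one bulk copy, then free element access on the host.
ranks_cpu = ranks.cpu()
fast = [int(ranks_cpu[i]) for i in range(len(ranks_cpu))]
assert slow == fast
```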