150 changes: 148 additions & 2 deletions tests/v1/sample/test_rejection_sampler.py
@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
def create_sampling_metadata(
all_greedy: bool,
temperature: Optional[torch.Tensor] = None,
top_k: Optional[torch.Tensor] = None,
top_p: Optional[torch.Tensor] = None,
generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set
@@ -52,8 +54,8 @@ def create_sampling_metadata(
temperature=temperature,
all_greedy=all_greedy,
all_random=not all_greedy,
top_p=None,
top_k=None,
top_p=top_p,
top_k=top_k,
min_p=torch.empty(1, ),
generators=generators,
max_num_logprobs=0,
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
density=True)

return hist.hist


def _test_masked_logits(
rejection_sampler,
batch_size: int,
num_draft_tokens: int,
vocab_size: int,
target_logits: torch.Tensor,
unmasked_indices: torch.Tensor,
sampling_metadata: SamplingMetadata,
):
# Set up test parameters
num_tokens = batch_size * num_draft_tokens

# Create random draft probabilities.
draft_probs = torch.rand((num_tokens, vocab_size),
dtype=torch.float32,
device=DEVICE)
draft_probs = F.softmax(draft_probs, dim=-1)

# Randomly sample draft token ids from draft probs
draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
draft_token_ids = draft_token_ids.tolist()

# Bonus tokens are not used by this test, but the sampler interface requires them.
bonus_token_ids = torch.zeros((batch_size, 1),
dtype=torch.int64,
device=DEVICE)

# Create spec decode metadata
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids,
device=DEVICE,
)

# Run rejection sampling
output_token_ids = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata,
)

# Remove bonus tokens and reshape
output_token_ids = output_token_ids[:, :-1].flatten().tolist()

# Check that all sampled tokens are within the unmasked indices.
for i in range(num_tokens):
token_id = output_token_ids[i]
if token_id == PLACEHOLDER_TOKEN_ID:
continue
assert token_id in unmasked_indices[i]


@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(rejection_sampler, top_k):
"""Test rejection sampling with top-k sampling"""
vocab_size = 100
batch_size = 100
num_draft_tokens = 3
num_tokens = batch_size * num_draft_tokens

# Randomly create top-k indices.
top_k_indices = [
torch.randperm(vocab_size, device=DEVICE)[:top_k]
for _ in range(num_tokens)
]
top_k_indices = torch.stack(top_k_indices)

# Create logits with the uniform distribution.
target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)

# Bump the logits at the top-k indices slightly above the rest. If the
# masking is effective, the non-top-k indices should never be sampled
# despite the small difference in logits.
for i in range(num_tokens):
target_logits[i, top_k_indices[i]] += 0.1

# Create sampling metadata
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_k=torch.tensor([top_k] * batch_size,
device=DEVICE,
dtype=torch.int64),
)

_test_masked_logits(
rejection_sampler,
batch_size=batch_size,
num_draft_tokens=num_draft_tokens,
vocab_size=vocab_size,
target_logits=target_logits,
unmasked_indices=top_k_indices,
sampling_metadata=sampling_metadata,
)
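
As a sanity check on the 0.1 bump (illustrative only, and using the first `k` vocab entries instead of random indices): without masking, the ~95 non-top-k tokens still hold most of the probability mass, so any leakage through the mask would show up almost immediately in this test.

```python
import torch

# Without top-k masking: a 100-way softmax where `top_k` entries are bumped by 0.1.
vocab_size, top_k = 100, 5
logits = torch.zeros(vocab_size)
logits[:top_k] += 0.1
probs = torch.softmax(logits, dim=-1)
# Probability mass left on the 95 non-top-k tokens (~0.94 for top_k=5), so
# unmasked sampling would pick a non-top-k token most of the time.
print(probs[top_k:].sum())

# With top-k masking: everything outside the top-k is set to -inf, so the
# non-top-k mass is exactly zero and such tokens can never be sampled.
masked = logits.masked_fill(logits < logits.topk(top_k).values[-1], float("-inf"))
print(torch.softmax(masked, dim=-1)[top_k:].sum())  # tensor(0.)
```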


@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(rejection_sampler, top_p):
"""Test rejection sampling with top-p sampling"""
vocab_size = 100
batch_size = 100
num_draft_tokens = 3
num_tokens = batch_size * num_draft_tokens

# Create random logits from a standard normal distribution.
target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
rescaled_logits = target_logits / temperature

logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - top_p
# Always keep the last (most likely) token, since the sort is ascending.
top_p_mask[:, -1] = False

# Get the top-p indices.
top_p_indices = []
for i in range(num_tokens):
top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist())

# Create sampling metadata
sampling_metadata = create_sampling_metadata(
all_greedy=False,
temperature=temperature,
top_p=torch.tensor([top_p] * batch_size,
device=DEVICE,
dtype=torch.float32),
)

_test_masked_logits(
rejection_sampler,
batch_size=batch_size,
num_draft_tokens=num_draft_tokens,
vocab_size=vocab_size,
target_logits=target_logits,
unmasked_indices=top_p_indices,
sampling_metadata=sampling_metadata,
)
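
The reference mask above relies on an ascending sort: dropping every prefix whose cumulative mass fits inside `1 - top_p` leaves the smallest set of most-likely tokens whose total mass is at least `top_p`. A small numeric sketch of that cutoff (the values are made up for illustration):

```python
import torch

probs = torch.tensor([0.05, 0.10, 0.15, 0.30, 0.40])  # already sorted ascending
top_p = 0.9
cum = probs.cumsum(dim=-1)   # [0.05, 0.15, 0.30, 0.60, 1.00]
drop = cum <= 1 - top_p      # [True, False, False, False, False]
drop[-1] = False             # always keep the most likely token
kept = probs[~drop]
# Kept tokens [0.10, 0.15, 0.30, 0.40] have mass 0.95 >= top_p; dropping one
# more (0.10) would leave only 0.85 < top_p.
print(kept, kept.sum())
```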
83 changes: 70 additions & 13 deletions vllm/v1/sample/rejection_sampler.py
@@ -8,6 +8,7 @@

from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.ops.utils import compiled_softmax
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

@@ -245,25 +246,81 @@ def compute_probs(
return logits

num_tokens = logits.shape[0]
batch_size = cu_num_draft_tokens.shape[0]
expanded_temperature = torch.empty(
(num_tokens, 1),
dtype=torch.float32,
device=logits.device,
)
expand_kernel[(batch_size, )](
expanded_temperature,
sampling_metadata.temperature,
cu_num_draft_tokens,
GREEDY_TEMPERATURE, # replace_from
1, # replace_to
MAX_NUM_TOKENS=MAX_SPEC_LEN,
num_warps=1,
)
output_prob = compiled_softmax(logits, expanded_temperature)
temperature = expand_batch_to_tokens(
sampling_metadata.temperature,
cu_num_draft_tokens,
num_tokens,
replace_from=GREEDY_TEMPERATURE,
replace_to=1,
)
# TODO(woosuk): Consider using in-place op to reduce memory usage.
logits = logits / temperature.unsqueeze(-1)

# Get expanded top_k and top_p tensors.
top_k = None
if sampling_metadata.top_k is not None:
top_k = expand_batch_to_tokens(
sampling_metadata.top_k,
cu_num_draft_tokens,
num_tokens,
)
top_p = None
if sampling_metadata.top_p is not None:
top_p = expand_batch_to_tokens(
sampling_metadata.top_p,
cu_num_draft_tokens,
num_tokens,
)

# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
logits = apply_top_k_top_p(logits, top_k, top_p)

output_prob = compiled_softmax(logits)
return output_prob
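
To put the NOTE above in context, a sort-based mask has to fully sort every row of the `[num_tokens, vocab_size]` logits, which is where the cost for large vocabularies comes from. The following is a generic sketch of that approach, not vLLM's `apply_top_k_top_p`; the helper name and the independent combination of the two masks are illustrative assumptions.

```python
import torch

def sort_based_top_k_top_p(logits: torch.Tensor,
                           top_k: torch.Tensor,   # [num_tokens], int
                           top_p: torch.Tensor    # [num_tokens], float
                           ) -> torch.Tensor:
    """Illustrative sort-based top-k/top-p masking (not vLLM's implementation)."""
    vocab_size = logits.shape[-1]
    # The full per-row sort is the expensive part for large vocabularies.
    vals, idx = logits.sort(dim=-1, descending=False)
    ranks = torch.arange(vocab_size, device=logits.device).expand_as(vals)
    # Top-k: drop everything except the k largest values in each row.
    k_mask = ranks < (vocab_size - top_k.unsqueeze(-1))
    # Top-p: drop the low-probability prefix whose mass fits inside 1 - p.
    probs = vals.softmax(dim=-1)
    p_mask = probs.cumsum(dim=-1) <= (1 - top_p.unsqueeze(-1))
    p_mask[..., -1] = False  # always keep the most likely token
    # For brevity the two masks are computed independently here; a real
    # implementation may apply top-k before computing the top-p cutoff.
    vals = vals.masked_fill(k_mask | p_mask, float("-inf"))
    # Scatter the masked values back to the original token order.
    return torch.empty_like(logits).scatter_(-1, idx, vals)

# Example: 2 tokens, vocab of 6, per-token k and p.
logits = torch.randn(2, 6)
print(sort_based_top_k_top_p(logits,
                             top_k=torch.tensor([2, 3]),
                             top_p=torch.tensor([0.9, 0.5])))
```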


def expand_batch_to_tokens(
x: torch.Tensor, # [batch_size]
cu_num_tokens: torch.Tensor, # [batch_size]
num_tokens: int,
replace_from: int = 0,
replace_to: int = 0,
) -> torch.Tensor:
"""Expand [batch_size] tensor to [num_tokens] tensor based on the number of
tokens per batch in cu_num_tokens.

For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
num_tokens = 6, and expanded_x = [a, a, b, b, b, c].

Args:
x: [batch_size] tensor to expand.
cu_num_tokens: [batch_size] tensor containing the cumulative number of
tokens per batch. Each element represents the total number of
tokens up to and including that batch.
num_tokens: Total number of tokens.
replace_from: Value to be replaced if it is found in x. Defaults to 0.
replace_to: Value to replace with when replace_from is found. Defaults to 0.
Returns:
expanded_x: [num_tokens] tensor.
"""
batch_size = x.shape[0]
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
expand_kernel[(batch_size, )](
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
num_warps=1,
)
return expanded_x
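
To make the docstring example concrete, here is a minimal pure-PyTorch sketch of the same expansion semantics. The helper name is made up, and this is a reference illustration rather than the Triton kernel; it assumes `cu_num_tokens` is the inclusive cumulative sum described above.

```python
import torch

def expand_batch_to_tokens_ref(x: torch.Tensor,
                               cu_num_tokens: torch.Tensor,
                               num_tokens: int,
                               replace_from: int = 0,
                               replace_to: int = 0) -> torch.Tensor:
    # Recover per-batch token counts from the inclusive cumulative sum,
    # e.g. [2, 5, 6] -> [2, 3, 1].
    counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
    expanded = torch.repeat_interleave(x, counts)
    assert expanded.shape[0] == num_tokens
    # Optionally remap a sentinel value (e.g. GREEDY_TEMPERATURE -> 1).
    return torch.where(expanded == replace_from,
                       expanded.new_tensor(replace_to), expanded)

# x = [a, b, c], cu_num_tokens = [2, 5, 6] -> [a, a, b, b, b, c]
x = torch.tensor([10.0, 20.0, 30.0])
cu = torch.tensor([2, 5, 6])
print(expand_batch_to_tokens_ref(x, cu, num_tokens=6))
# tensor([10., 10., 20., 20., 20., 30.])
```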


def generate_uniform_probs(
num_tokens: int,
num_draft_tokens: list[int],
5 changes: 1 addition & 4 deletions vllm/v1/spec_decode/utils.py
@@ -3,10 +3,7 @@


def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs:
# Spec decode doesn't support top_p/top_k sampling.
return False
elif req_id in input_batch.min_p_reqs:
if req_id in input_batch.min_p_reqs:
# Spec decode doesn't support min_p sampling.
return False
elif (req_id in input_batch.frequency_penalties_reqs