diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py
index 8c423e367ef5..cbdb0b910d1d 100644
--- a/tests/v1/sample/test_rejection_sampler.py
+++ b/tests/v1/sample/test_rejection_sampler.py
@@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]],
 def create_sampling_metadata(
     all_greedy: bool,
     temperature: Optional[torch.Tensor] = None,
+    top_k: Optional[torch.Tensor] = None,
+    top_p: Optional[torch.Tensor] = None,
     generators: Optional[dict[int, Any]] = None,
 ) -> SamplingMetadata:
     """Create a v1 sampling metadata object with all_greedy set
@@ -52,8 +54,8 @@ def create_sampling_metadata(
         temperature=temperature,
         all_greedy=all_greedy,
         all_random=not all_greedy,
-        top_p=None,
-        top_k=None,
+        top_p=top_p,
+        top_k=top_k,
         min_p=torch.empty(1, ),
         generators=generators,
         max_num_logprobs=0,
@@ -462,3 +464,147 @@ def estimate_rejection_sampling_pdf(
                            density=True)
 
     return hist.hist
+
+
+def _test_masked_logits(
+    rejection_sampler,
+    batch_size: int,
+    num_draft_tokens: int,
+    vocab_size: int,
+    target_logits: torch.Tensor,
+    unmasked_indices: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+):
+    # Set up test parameters
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random draft probabilities.
+    draft_probs = torch.rand((num_tokens, vocab_size),
+                             dtype=torch.float32,
+                             device=DEVICE)
+    draft_probs = F.softmax(draft_probs, dim=-1)
+
+    # Randomly sample draft token ids from draft probs
+    draft_token_ids = torch.multinomial(draft_probs, num_samples=1)
+    draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens)
+    draft_token_ids = draft_token_ids.tolist()
+
+    # Bonus tokens are not used here but are required by the sampler.
+    bonus_token_ids = torch.zeros((batch_size, 1),
+                                  dtype=torch.int64,
+                                  device=DEVICE)
+
+    # Create spec decode metadata
+    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
+        draft_token_ids,
+        device=DEVICE,
+    )
+
+    # Run rejection sampling
+    output_token_ids = rejection_sampler(
+        spec_decode_metadata,
+        draft_probs=draft_probs,
+        target_logits=target_logits,
+        bonus_token_ids=bonus_token_ids,
+        sampling_metadata=sampling_metadata,
+    )
+
+    # Remove bonus tokens and flatten.
+    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
+
+    # Check that all sampled tokens are within the unmasked indices.
+    for i in range(num_tokens):
+        token_id = output_token_ids[i]
+        if token_id == PLACEHOLDER_TOKEN_ID:
+            continue
+        assert token_id in unmasked_indices[i]
+
+
+@pytest.mark.parametrize("top_k", [1, 5, 99])
+def test_top_k(rejection_sampler, top_k):
+    """Test rejection sampling with top-k sampling"""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Randomly create top-k indices.
+    top_k_indices = [
+        torch.randperm(vocab_size, device=DEVICE)[:top_k]
+        for _ in range(num_tokens)
+    ]
+    top_k_indices = torch.stack(top_k_indices)
+
+    # Create logits that give a uniform distribution.
+    target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE)
+
+    # Bump the logits at the top-k indices slightly above the rest. If the
+    # masking is effective, the non-top-k indices will never be sampled
+    # despite the small difference in logits.
+    for i in range(num_tokens):
+        target_logits[i, top_k_indices[i]] += 0.1
+
+    # Create sampling metadata
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_k=torch.tensor([top_k] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.int64),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_k_indices,
+        sampling_metadata=sampling_metadata,
+    )
+
+
+@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
+def test_top_p(rejection_sampler, top_p):
+    """Test rejection sampling with top-p sampling"""
+    vocab_size = 100
+    batch_size = 100
+    num_draft_tokens = 3
+    num_tokens = batch_size * num_draft_tokens
+
+    # Create random logits.
+    target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE)
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    rescaled_logits = target_logits / temperature
+
+    logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False)
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - top_p
+    # Keep at least one token.
+    top_p_mask[:, -1] = False
+
+    # Get the top-p indices.
+    top_p_indices = []
+    for i in range(num_tokens):
+        top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist())
+
+    # Create sampling metadata
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False,
+        temperature=temperature,
+        top_p=torch.tensor([top_p] * batch_size,
+                           device=DEVICE,
+                           dtype=torch.float32),
+    )
+
+    _test_masked_logits(
+        rejection_sampler,
+        batch_size=batch_size,
+        num_draft_tokens=num_draft_tokens,
+        vocab_size=vocab_size,
+        target_logits=target_logits,
+        unmasked_indices=top_p_indices,
+        sampling_metadata=sampling_metadata,
+    )
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 6284ae4b490a..c8327f36a585 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -8,6 +8,7 @@
 
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.ops.utils import compiled_softmax
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
@@ -245,25 +246,81 @@ def compute_probs(
         return logits
 
     num_tokens = logits.shape[0]
-    batch_size = cu_num_draft_tokens.shape[0]
-    expanded_temperature = torch.empty(
-        (num_tokens, 1),
-        dtype=torch.float32,
-        device=logits.device,
-    )
-    expand_kernel[(batch_size, )](
-        expanded_temperature,
+    temperature = expand_batch_to_tokens(
         sampling_metadata.temperature,
         cu_num_draft_tokens,
-        GREEDY_TEMPERATURE,  # replace_from
-        1,  # replace_to
-        MAX_NUM_TOKENS=MAX_SPEC_LEN,
-        num_warps=1,
+        num_tokens,
+        replace_from=GREEDY_TEMPERATURE,
+        replace_to=1,
     )
-    output_prob = compiled_softmax(logits, expanded_temperature)
+    # TODO(woosuk): Consider using an in-place op to reduce memory usage.
+    logits = logits / temperature.unsqueeze(-1)
+
+    # Get expanded top_k and top_p tensors.
+    top_k = None
+    if sampling_metadata.top_k is not None:
+        top_k = expand_batch_to_tokens(
+            sampling_metadata.top_k,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+    top_p = None
+    if sampling_metadata.top_p is not None:
+        top_p = expand_batch_to_tokens(
+            sampling_metadata.top_p,
+            cu_num_draft_tokens,
+            num_tokens,
+        )
+
+    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
+    # which is slow for large vocab sizes. This may cause performance issues.
+    logits = apply_top_k_top_p(logits, top_k, top_p)
+
+    output_prob = compiled_softmax(logits)
     return output_prob
 
 
+def expand_batch_to_tokens(
+    x: torch.Tensor,  # [batch_size]
+    cu_num_tokens: torch.Tensor,  # [batch_size]
+    num_tokens: int,
+    replace_from: int = 0,
+    replace_to: int = 0,
+) -> torch.Tensor:
+    """Expand [batch_size] tensor to [num_tokens] tensor based on the number
+    of tokens per batch in cu_num_tokens.
+
+    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
+    num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
+
+    Args:
+        x: [batch_size] tensor to expand.
+        cu_num_tokens: [batch_size] tensor containing the cumulative number of
+            tokens per batch. Each element represents the total number of
+            tokens up to and including that batch.
+        num_tokens: Total number of tokens.
+        replace_from: Value to be replaced if it is found in x.
+        replace_to: Value to replace with when replace_from is found.
+    Returns:
+        expanded_x: [num_tokens] tensor.
+    """
+    batch_size = x.shape[0]
+    assert cu_num_tokens.shape[0] == batch_size
+    expanded_x = x.new_empty(num_tokens)
+    expand_kernel[(batch_size, )](
+        expanded_x,
+        x,
+        cu_num_tokens,
+        replace_from,
+        replace_to,
+        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+        num_warps=1,
+    )
+    return expanded_x
+
+
 def generate_uniform_probs(
     num_tokens: int,
     num_draft_tokens: list[int],
diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py
index d5329ef7b5ab..ce81a40ee3ae 100644
--- a/vllm/v1/spec_decode/utils.py
+++ b/vllm/v1/spec_decode/utils.py
@@ -3,10 +3,7 @@
 
 
 def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
-    if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs:
-        # Spec decode doesn't support top_p/top_k sampling.
-        return False
-    elif req_id in input_batch.min_p_reqs:
+    if req_id in input_batch.min_p_reqs:
         # Spec decode doesn't support min_p sampling.
         return False
     elif (req_id in input_batch.frequency_penalties_reqs
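
A note on `expand_batch_to_tokens`: the docstring example (x = [a, b, c], cu_num_tokens = [2, 5, 6] -> [a, a, b, b, b, c]) can be reproduced with plain PyTorch ops, which may help when reviewing the Triton `expand_kernel` path. This is only an illustrative sketch, not part of the patch: `expand_batch_to_tokens_reference` is a hypothetical helper name, and it assumes `cu_num_tokens` is a 1-D integer tensor of cumulative counts as described in the docstring.

import torch


def expand_batch_to_tokens_reference(
    x: torch.Tensor,  # [batch_size]
    cu_num_tokens: torch.Tensor,  # [batch_size], cumulative token counts
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    # Per-batch counts from the cumulative sums, e.g. [2, 5, 6] -> [2, 3, 1].
    counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
    # Optionally replace a sentinel value (e.g. GREEDY_TEMPERATURE -> 1).
    x = torch.where(x == replace_from, x.new_tensor(replace_to), x)
    # Repeat each batch element by its count: [a, b, c] -> [a, a, b, b, b, c].
    expanded_x = torch.repeat_interleave(x, counts)
    assert expanded_x.shape[0] == num_tokens
    return expanded_x


# Matches the docstring example.
x = torch.tensor([10.0, 20.0, 30.0])
cu_num_tokens = torch.tensor([2, 5, 6])
print(expand_batch_to_tokens_reference(x, cu_num_tokens, num_tokens=6))
# tensor([10., 10., 20., 20., 20., 30.])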
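
The NOTE in compute_probs about sorting refers to the usual sort-and-cumsum way of building the top-k/top-p mask, which is also what test_top_p recomputes to derive its expected unmasked indices. The sketch below spells that idea out in plain PyTorch; it is not the actual implementation of `apply_top_k_top_p` in vllm, and it assumes per-token `top_k` values in [1, vocab_size] and `top_p` values in (0, 1].

from typing import Optional

import torch


def top_k_top_p_mask_sketch(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    top_k: Optional[torch.Tensor],  # [num_tokens], int64, or None
    top_p: Optional[torch.Tensor],  # [num_tokens], float32, or None
) -> torch.Tensor:
    if top_k is None and top_p is None:
        return logits
    # Sort ascending so the largest logit of each row sits in the last column.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    if top_k is not None:
        # Keep only the k largest logits per row (assumes 1 <= k <= vocab_size).
        threshold_idx = (logits_sort.size(-1) - top_k.to(torch.long)).unsqueeze(1)
        threshold = logits_sort.gather(1, threshold_idx)
        logits_sort.masked_fill_(logits_sort < threshold, float("-inf"))
    if top_p is not None:
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        # Mask tokens whose cumulative probability fits within 1 - p,
        # but always keep the most probable token (last column).
        top_p_mask = probs_sum <= 1 - top_p.unsqueeze(1)
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, float("-inf"))
    # Scatter the masked, sorted logits back to their original positions.
    return torch.empty_like(logits_sort).scatter_(1, logits_idx, logits_sort)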