From 7c4f0588191795cf85f5068ee570a14da761ad3d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 14:58:50 -0700 Subject: [PATCH 01/11] [V1][Spec Decode] Enable spec decode for top-p & top-k sampling Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 76 ++++++++++++++++++++++++----- vllm/v1/spec_decode/utils.py | 5 +- 2 files changed, 64 insertions(+), 17 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 6284ae4b490a..f4b921d89b0e 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.utils import compiled_softmax from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -245,25 +246,74 @@ def compute_probs( return logits num_tokens = logits.shape[0] - batch_size = cu_num_draft_tokens.shape[0] - expanded_temperature = torch.empty( - (num_tokens, 1), - dtype=torch.float32, - device=logits.device, - ) - expand_kernel[(batch_size, )]( - expanded_temperature, + temperature = expand_batch_to_tokens( sampling_metadata.temperature, cu_num_draft_tokens, - GREEDY_TEMPERATURE, # replace_from - 1, # replace_to - MAX_NUM_TOKENS=MAX_SPEC_LEN, - num_warps=1, + num_tokens, + replace_from=GREEDY_TEMPERATURE, + replace_to=1, + ) + # TODO(woosuk): Consider using in-place op to reduce memory usage. + logits = logits / temperature.unsqueeze(-1) + top_p = expand_batch_to_tokens( + sampling_metadata.top_p, + cu_num_draft_tokens, + num_tokens, + ) + top_k = expand_batch_to_tokens( + sampling_metadata.top_k, + cu_num_draft_tokens, + num_tokens, ) - output_prob = compiled_softmax(logits, expanded_temperature) + logits = apply_top_k_top_p(logits, top_k, top_p) + output_prob = compiled_softmax(logits) return output_prob +def expand_batch_to_tokens( + x: Optional[torch.Tensor], # [batch_size] + cu_num_tokens: torch.Tensor, # [batch_size] + num_tokens: int, + replace_from: int = 0, + replace_to: int = 0, +) -> torch.Tensor: + """Expand [batch_size] tensor to [num_tokens] tensor based on the number of + tokens per batch in cu_num_tokens. + + Args: + x: [batch_size] tensor to expand. + cu_num_tokens: [batch_size] tensor containing the cumulative number of + tokens per batch. Each element represents the total number of + tokens up to and including that batch. + num_tokens: Total number of tokens. + replace_from: Value to replace with `replace_to` for tokens that are not + in the batch. + replace_to: Value to replace with `replace_from` for tokens that are in + the batch. + Returns: + expanded_x: [num_tokens] tensor. + """ + if x is None: + return None + batch_size = x.shape[0] + assert cu_num_tokens.shape[0] == batch_size + expanded_x = torch.empty( + (num_tokens, ), + dtype=x.dtype, + device=x.device, + ) + expand_kernel[(batch_size, )]( + expanded_x, + x, + cu_num_tokens, + replace_from, + replace_to, + MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. 
+ num_warps=1, + ) + return expanded_x + + def generate_uniform_probs( num_tokens: int, num_draft_tokens: list[int], diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index d5329ef7b5ab..ce81a40ee3ae 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -3,10 +3,7 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: - if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs: - # Spec decode doesn't support top_p/top_k sampling. - return False - elif req_id in input_batch.min_p_reqs: + if req_id in input_batch.min_p_reqs: # Spec decode doesn't support min_p sampling. return False elif (req_id in input_batch.frequency_penalties_reqs From 4ce33c9e88d2c1d07541e073d46bc2eaec3f9ffb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 15:00:33 -0700 Subject: [PATCH 02/11] comment Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index f4b921d89b0e..59801a517ec4 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -265,6 +265,8 @@ def compute_probs( cu_num_draft_tokens, num_tokens, ) + # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, + # which is slow for large vocab sizes. This may cause performance issues. logits = apply_top_k_top_p(logits, top_k, top_p) output_prob = compiled_softmax(logits) return output_prob From e1d647e4b707ff2e119e32fa5251e67d04a46cc6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 15:04:50 -0700 Subject: [PATCH 03/11] optional Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 59801a517ec4..9cf16ae0c85d 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -253,8 +253,13 @@ def compute_probs( replace_from=GREEDY_TEMPERATURE, replace_to=1, ) + # Because `sampling_metadata.all_greedy` is False, `temperature` is not + # None. This is to satisfy the type checker. + assert temperature is not None + # TODO(woosuk): Consider using in-place op to reduce memory usage. logits = logits / temperature.unsqueeze(-1) + top_p = expand_batch_to_tokens( sampling_metadata.top_p, cu_num_draft_tokens, @@ -278,7 +283,7 @@ def expand_batch_to_tokens( num_tokens: int, replace_from: int = 0, replace_to: int = 0, -) -> torch.Tensor: +) -> Optional[torch.Tensor]: """Expand [batch_size] tensor to [num_tokens] tensor based on the number of tokens per batch in cu_num_tokens. From b1aeb04aa52b457c98b5979e9d40b588bc43fdb1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 15:10:56 -0700 Subject: [PATCH 04/11] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 39 ++++++++++++++++------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 9cf16ae0c85d..9fdd36c6eb6a 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -253,37 +253,42 @@ def compute_probs( replace_from=GREEDY_TEMPERATURE, replace_to=1, ) - # Because `sampling_metadata.all_greedy` is False, `temperature` is not - # None. This is to satisfy the type checker. 
- assert temperature is not None - # TODO(woosuk): Consider using in-place op to reduce memory usage. logits = logits / temperature.unsqueeze(-1) - top_p = expand_batch_to_tokens( - sampling_metadata.top_p, - cu_num_draft_tokens, - num_tokens, - ) - top_k = expand_batch_to_tokens( - sampling_metadata.top_k, - cu_num_draft_tokens, - num_tokens, - ) + # Get expanded top_p and top_k tensors. + if sampling_metadata.top_p is not None: + top_p = expand_batch_to_tokens( + sampling_metadata.top_p, + cu_num_draft_tokens, + num_tokens, + ) + else: + top_p = None + if sampling_metadata.top_k is not None: + top_k = expand_batch_to_tokens( + sampling_metadata.top_k, + cu_num_draft_tokens, + num_tokens, + ) + else: + top_k = None + # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, # which is slow for large vocab sizes. This may cause performance issues. logits = apply_top_k_top_p(logits, top_k, top_p) + output_prob = compiled_softmax(logits) return output_prob def expand_batch_to_tokens( - x: Optional[torch.Tensor], # [batch_size] + x: torch.Tensor, # [batch_size] cu_num_tokens: torch.Tensor, # [batch_size] num_tokens: int, replace_from: int = 0, replace_to: int = 0, -) -> Optional[torch.Tensor]: +) -> torch.Tensor: """Expand [batch_size] tensor to [num_tokens] tensor based on the number of tokens per batch in cu_num_tokens. @@ -300,8 +305,6 @@ def expand_batch_to_tokens( Returns: expanded_x: [num_tokens] tensor. """ - if x is None: - return None batch_size = x.shape[0] assert cu_num_tokens.shape[0] == batch_size expanded_x = torch.empty( From b1416a77c8ea39417a3bb839543687106dabaea6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 15:11:55 -0700 Subject: [PATCH 05/11] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 9fdd36c6eb6a..0f8ae8c6d298 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -256,23 +256,21 @@ def compute_probs( # TODO(woosuk): Consider using in-place op to reduce memory usage. logits = logits / temperature.unsqueeze(-1) - # Get expanded top_p and top_k tensors. - if sampling_metadata.top_p is not None: - top_p = expand_batch_to_tokens( - sampling_metadata.top_p, - cu_num_draft_tokens, - num_tokens, - ) - else: - top_p = None + # Get expanded top_k and top_p tensors. + top_k = None if sampling_metadata.top_k is not None: top_k = expand_batch_to_tokens( sampling_metadata.top_k, cu_num_draft_tokens, num_tokens, ) - else: - top_k = None + top_p = None + if sampling_metadata.top_p is not None: + top_p = expand_batch_to_tokens( + sampling_metadata.top_p, + cu_num_draft_tokens, + num_tokens, + ) # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, # which is slow for large vocab sizes. This may cause performance issues. From 438e39e112a3e118fb061aeb612ab61d605f2e2a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 15:16:14 -0700 Subject: [PATCH 06/11] fix docstring Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 0f8ae8c6d298..e93d225316c0 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -296,10 +296,10 @@ def expand_batch_to_tokens( tokens per batch. 
Each element represents the total number of tokens up to and including that batch. num_tokens: Total number of tokens. - replace_from: Value to replace with `replace_to` for tokens that are not - in the batch. - replace_to: Value to replace with `replace_from` for tokens that are in - the batch. + replace_from: int = 0 + Value to be replaced if it is found in x. + replace_to: int = 0 + Value to replace with when replace_from is found. Returns: expanded_x: [num_tokens] tensor. """ From 0097a7bbf1d91c71f03e35dc6c6fe978b88181f5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 15:19:09 -0700 Subject: [PATCH 07/11] Add example Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index e93d225316c0..48322725464d 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -290,6 +290,9 @@ def expand_batch_to_tokens( """Expand [batch_size] tensor to [num_tokens] tensor based on the number of tokens per batch in cu_num_tokens. + For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then + num_tokens = 6, and expanded_x = [a, a, b, b, b, c]. + Args: x: [batch_size] tensor to expand. cu_num_tokens: [batch_size] tensor containing the cumulative number of From 5439c60e46a516e339d7b77b6afd6a56f6658216 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Mar 2025 13:39:55 -0700 Subject: [PATCH 08/11] Add tests Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 139 +++++++++++++++++++++- 1 file changed, 137 insertions(+), 2 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 8c423e367ef5..aa19a4bd6963 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -36,6 +36,8 @@ def create_logits_tensor(output_token_ids: list[list[int]], def create_sampling_metadata( all_greedy: bool, temperature: Optional[torch.Tensor] = None, + top_k: Optional[torch.Tensor] = None, + top_p: Optional[torch.Tensor] = None, generators: Optional[dict[int, Any]] = None, ) -> SamplingMetadata: """Create a v1 sampling metadata object with all_greedy set @@ -52,8 +54,8 @@ def create_sampling_metadata( temperature=temperature, all_greedy=all_greedy, all_random=not all_greedy, - top_p=None, - top_k=None, + top_p=top_p, + top_k=top_k, min_p=torch.empty(1, ), generators=generators, max_num_logprobs=0, @@ -462,3 +464,136 @@ def estimate_rejection_sampling_pdf( density=True) return hist.hist + + +def _test_masked_logits( + rejection_sampler, + batch_size: int, + num_draft_tokens: int, + vocab_size: int, + unmasked_indices: torch.Tensor, + sampling_metadata: SamplingMetadata, +): + # Set up test parameters + num_tokens = batch_size * num_draft_tokens + + # Create random draft probabilities. 
+ draft_probs = torch.rand((num_tokens, vocab_size), + dtype=torch.float32, + device=DEVICE) + draft_probs = F.softmax(draft_probs, dim=-1) + + # Create logits with a clear distinction between masked and unmasked values + target_logits = torch.full((num_tokens, vocab_size), + -float("inf"), + device=DEVICE) + + # Set high logits for unmasked indices + for i in range(num_tokens): + target_logits[i, unmasked_indices[i]] = 100.0 + + # Randomly sample draft token ids from draft probs + draft_token_ids = torch.multinomial(draft_probs, num_samples=1) + draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens) + draft_token_ids = draft_token_ids.tolist() + + # Bonus tokens not used but required + bonus_token_ids = torch.zeros((batch_size, 1), + dtype=torch.int64, + device=DEVICE) + + # Create spec decode metadata + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids, + device=DEVICE, + ) + + # Run rejection sampling + output_token_ids = rejection_sampler( + spec_decode_metadata, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) + + # Remove bonus tokens and reshape + output_token_ids = output_token_ids[:, :-1].flatten().tolist() + + # Check that all sampled tokens are within the unmasked indices. + for i in range(num_tokens): + token_id = output_token_ids[i] + if token_id == PLACEHOLDER_TOKEN_ID: + continue + assert token_id in unmasked_indices[i] + + +@pytest.mark.parametrize("top_k", [1, 5, 99]) +def test_top_k(rejection_sampler, top_k): + """Test rejection sampling with top-k sampling""" + vocab_size = 100 + batch_size = 10 + num_draft_tokens = 3 + num_tokens = batch_size * num_draft_tokens + + # Randomly create unmasked indices. + unmasked_indices = [ + torch.randperm(vocab_size, device=DEVICE)[:top_k] + for _ in range(num_tokens) + ] + unmasked_indices = torch.stack(unmasked_indices) + + # Create sampling metadata + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) + sampling_metadata = create_sampling_metadata( + all_greedy=False, + temperature=temperature, + top_k=torch.tensor([top_k] * batch_size, + device=DEVICE, + dtype=torch.int64), + ) + + _test_masked_logits( + rejection_sampler, + batch_size=batch_size, + num_draft_tokens=num_draft_tokens, + vocab_size=vocab_size, + unmasked_indices=unmasked_indices, + sampling_metadata=sampling_metadata, + ) + + +@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99]) +def test_top_p(rejection_sampler, top_p): + """Test rejection sampling with top-p sampling""" + vocab_size = 100 + batch_size = 10 + num_draft_tokens = 3 + num_tokens = batch_size * num_draft_tokens + + # Randomly create unmasked indices. 
+ num_top_p_tokens = int(vocab_size * top_p) + unmasked_indices = [ + torch.randperm(vocab_size, device=DEVICE)[:num_top_p_tokens] + for _ in range(num_tokens) + ] + unmasked_indices = torch.stack(unmasked_indices) + + # Create sampling metadata + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) + sampling_metadata = create_sampling_metadata( + all_greedy=False, + temperature=temperature, + top_p=torch.tensor([top_p] * batch_size, + device=DEVICE, + dtype=torch.float32), + ) + + _test_masked_logits( + rejection_sampler, + batch_size=batch_size, + num_draft_tokens=num_draft_tokens, + vocab_size=vocab_size, + unmasked_indices=unmasked_indices, + sampling_metadata=sampling_metadata, + ) From 2d101406087a726e19e2a18793db4f0b93f958f7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Mar 2025 13:45:19 -0700 Subject: [PATCH 09/11] improve Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index aa19a4bd6963..6472df171edb 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -483,14 +483,14 @@ def _test_masked_logits( device=DEVICE) draft_probs = F.softmax(draft_probs, dim=-1) - # Create logits with a clear distinction between masked and unmasked values - target_logits = torch.full((num_tokens, vocab_size), - -float("inf"), - device=DEVICE) + # Create logits with the uniform distribution. + target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE) - # Set high logits for unmasked indices + # Increment the logits for unmasked indices, a little bit more than the + # masked ones. If the masking is effective, the masked indices will never + # be sampled despite the small difference in logits. 
for i in range(num_tokens): - target_logits[i, unmasked_indices[i]] = 100.0 + target_logits[i, unmasked_indices[i]] += 0.1 # Randomly sample draft token ids from draft probs draft_token_ids = torch.multinomial(draft_probs, num_samples=1) @@ -532,7 +532,7 @@ def _test_masked_logits( def test_top_k(rejection_sampler, top_k): """Test rejection sampling with top-k sampling""" vocab_size = 100 - batch_size = 10 + batch_size = 100 num_draft_tokens = 3 num_tokens = batch_size * num_draft_tokens @@ -567,7 +567,7 @@ def test_top_k(rejection_sampler, top_k): def test_top_p(rejection_sampler, top_p): """Test rejection sampling with top-p sampling""" vocab_size = 100 - batch_size = 10 + batch_size = 100 num_draft_tokens = 3 num_tokens = batch_size * num_draft_tokens From e66fab220a7a37667e470828760178e1cdbd9439 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Mar 2025 15:35:01 -0700 Subject: [PATCH 10/11] fix top-p test Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 55 ++++++++++++++--------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 6472df171edb..cbdb0b910d1d 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -471,6 +471,7 @@ def _test_masked_logits( batch_size: int, num_draft_tokens: int, vocab_size: int, + target_logits: torch.Tensor, unmasked_indices: torch.Tensor, sampling_metadata: SamplingMetadata, ): @@ -483,15 +484,6 @@ def _test_masked_logits( device=DEVICE) draft_probs = F.softmax(draft_probs, dim=-1) - # Create logits with the uniform distribution. - target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE) - - # Increment the logits for unmasked indices, a little bit more than the - # masked ones. If the masking is effective, the masked indices will never - # be sampled despite the small difference in logits. - for i in range(num_tokens): - target_logits[i, unmasked_indices[i]] += 0.1 - # Randomly sample draft token ids from draft probs draft_token_ids = torch.multinomial(draft_probs, num_samples=1) draft_token_ids = draft_token_ids.reshape(batch_size, num_draft_tokens) @@ -536,12 +528,21 @@ def test_top_k(rejection_sampler, top_k): num_draft_tokens = 3 num_tokens = batch_size * num_draft_tokens - # Randomly create unmasked indices. - unmasked_indices = [ + # Randomly create top-k indices. + top_k_indices = [ torch.randperm(vocab_size, device=DEVICE)[:top_k] for _ in range(num_tokens) ] - unmasked_indices = torch.stack(unmasked_indices) + top_k_indices = torch.stack(top_k_indices) + + # Create logits with the uniform distribution. + target_logits = torch.zeros((num_tokens, vocab_size), device=DEVICE) + + # Increment the logits for top-k indices, a little bit more than the other + # ones. If the masking is effective, the non-topk indices will never be + # sampled despite the small difference in logits. 
+ for i in range(num_tokens): + target_logits[i, top_k_indices[i]] += 0.1 # Create sampling metadata temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) @@ -558,7 +559,8 @@ def test_top_k(rejection_sampler, top_k): batch_size=batch_size, num_draft_tokens=num_draft_tokens, vocab_size=vocab_size, - unmasked_indices=unmasked_indices, + target_logits=target_logits, + unmasked_indices=top_k_indices, sampling_metadata=sampling_metadata, ) @@ -571,16 +573,24 @@ def test_top_p(rejection_sampler, top_p): num_draft_tokens = 3 num_tokens = batch_size * num_draft_tokens - # Randomly create unmasked indices. - num_top_p_tokens = int(vocab_size * top_p) - unmasked_indices = [ - torch.randperm(vocab_size, device=DEVICE)[:num_top_p_tokens] - for _ in range(num_tokens) - ] - unmasked_indices = torch.stack(unmasked_indices) + # Create logits with the uniform distribution. + target_logits = torch.randn((num_tokens, vocab_size), device=DEVICE) + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) + rescaled_logits = target_logits / temperature + + logits_sort, logits_idx = rescaled_logits.sort(dim=-1, descending=False) + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - top_p + # at least one + top_p_mask[:, -1] = False + + # Get the top-p indices. + top_p_indices = [] + for i in range(num_tokens): + top_p_indices.append(logits_idx[i][~top_p_mask[i]].tolist()) # Create sampling metadata - temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE) sampling_metadata = create_sampling_metadata( all_greedy=False, temperature=temperature, @@ -594,6 +604,7 @@ def test_top_p(rejection_sampler, top_p): batch_size=batch_size, num_draft_tokens=num_draft_tokens, vocab_size=vocab_size, - unmasked_indices=unmasked_indices, + target_logits=target_logits, + unmasked_indices=top_p_indices, sampling_metadata=sampling_metadata, ) From 5dfb42e092b852ce77fd18e8409483637a37c9db Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Mar 2025 16:11:37 -0700 Subject: [PATCH 11/11] comment Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 48322725464d..c8327f36a585 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -308,11 +308,7 @@ def expand_batch_to_tokens( """ batch_size = x.shape[0] assert cu_num_tokens.shape[0] == batch_size - expanded_x = torch.empty( - (num_tokens, ), - dtype=x.dtype, - device=x.device, - ) + expanded_x = x.new_empty(num_tokens) expand_kernel[(batch_size, )]( expanded_x, x,
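
The expansion that the Triton expand_kernel performs inside expand_batch_to_tokens (introduced in patch 01 and refined in the later patches above) can be written as a short pure-PyTorch reference. This is only an illustrative sketch, not the kernel itself: it assumes cu_num_tokens is the inclusive cumulative sum described in the docstring example (x = [a, b, c], cu_num_tokens = [2, 5, 6] gives [a, a, b, b, b, c]), and the name expand_batch_to_tokens_ref is invented here.

import torch

def expand_batch_to_tokens_ref(x: torch.Tensor,
                               cu_num_tokens: torch.Tensor,
                               num_tokens: int,
                               replace_from: int = 0,
                               replace_to: int = 0) -> torch.Tensor:
    # Recover per-request token counts from the inclusive cumulative sum,
    # e.g. cu_num_tokens = [2, 5, 6] -> counts = [2, 3, 1].
    counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
    expanded = torch.repeat_interleave(x, counts)
    assert expanded.shape[0] == num_tokens
    # Mirror the kernel's replace_from -> replace_to substitution; compute_probs
    # uses it to map GREEDY_TEMPERATURE to 1 for greedy requests.
    return torch.where(expanded == replace_from,
                       torch.full_like(expanded, replace_to), expanded)

# x = [a, b, c], cu_num_tokens = [2, 5, 6] -> [a, a, b, b, b, c]
x = torch.tensor([0.7, 1.0, 0.5])
cu = torch.tensor([2, 5, 6])
print(expand_batch_to_tokens_ref(x, cu, num_tokens=6))
# tensor([0.7000, 0.7000, 1.0000, 1.0000, 1.0000, 0.5000])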
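
The top-p masking that the reworked test in patch 10 mirrors (ascending sort, cumulative softmax, drop every token whose cumulative mass is at most 1 - top_p, always keep the highest-probability token) can be checked with a small numeric sketch. The logit values below are invented for illustration; with top_p = 0.5 the surviving set is the smallest suffix of the ascending-sorted tokens whose total probability reaches at least top_p.

import torch

top_p = 0.5
logits = torch.tensor([[0.0, 0.7, 1.2, 1.6]])            # one row, vocab size 4
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)   # ~[0.089, 0.267, 0.561, 1.0]
top_p_mask = probs_sum <= 1 - top_p                      # tokens to drop: [T, T, F, F]
top_p_mask[:, -1] = False                                # always keep at least one token
kept = logits_idx[~top_p_mask]
print(kept)  # tensor([2, 3]); kept mass ~0.73 >= top_p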