From c09ae5ebafeacd273426814fa6fbdd10a7a6ff9c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 14 Mar 2025 20:41:28 -0700 Subject: [PATCH 01/48] tmp Signed-off-by: Woosuk Kwon --- vllm/v1/sample/metadata.py | 2 +- vllm/v1/sample/rejection_sampler.py | 427 ++++++++++++++++------------ vllm/v1/worker/gpu_input_batch.py | 10 +- vllm/v1/worker/gpu_model_runner.py | 5 +- 4 files changed, 262 insertions(+), 182 deletions(-) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index e97e1235fb36..7e339e2a597d 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -9,7 +9,7 @@ @dataclass class SamplingMetadata: - temperature: Optional[torch.Tensor] + temperature: torch.Tensor all_greedy: bool all_random: bool diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index ea7f3353c115..a3a4a5dfc0d0 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -1,201 +1,280 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Optional import torch import torch.nn as nn -from torch.nn.utils.rnn import pad_sequence +import triton +import triton.language as tl -from vllm import envs from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.v1.outputs import SamplerOutput +from vllm.utils import is_pin_memory_available from vllm.v1.sample.metadata import SamplingMetadata -try: - import flashinfer.sampling as fs - is_flashinfer_available = True -except ImportError: - is_flashinfer_available = False - logger = init_logger(__name__) -INVALID_TOKEN_ID = -1 +PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 class RejectionSampler(nn.Module): - def __init__(self): + def __init__(self, max_num_tokens: int = 16 * 1024): super().__init__() - if current_platform.is_cuda(): - if is_flashinfer_available: - if envs.VLLM_USE_FLASHINFER_SAMPLER is not False: - # FIXME(woosuk): Currently, we have errors when using - # FlashInfer for rejection sampling. As a workaround, we - # disable FlashInfer for rejection sampling by default. - logger.info("Currently, FlashInfer rejection sampler is " - "disabled because of a bug. Falling back to " - "the PyTorch-native implementation of " - "rejection sampling.") - self.forward_method = self.forward_native - - # NOTE(woosuk): The V0 sampler doesn't use FlashInfer for - # sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by - # default it is unused). For backward compatibility, we set - # `VLLM_USE_FLASHINFER_SAMPLER` as None by default and - # interpret it differently in V0 and V1 samplers: In V0, - # None means False, while in V1, None means True. This is - # why we use the condition - # `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here. - # logger.info("Using FlashInfer for rejection sampling.") - # self.forward_method = self.flashinfer_sample - else: - logger.warning( - "FlashInfer is available, but it is not enabled. " - "Falling back to the PyTorch-native implementation of " - "rejection sampling. For the best performance, " - "please set VLLM_USE_FLASHINFER_SAMPLER=1.") - self.forward_method = self.forward_native - else: - logger.warning( - "FlashInfer is not available. Falling back to the PyTorch-" - "native implementation of rejection sampling. 
For the " - "best performance, please install FlashInfer.") - self.forward_method = self.forward_native - else: - self.forward_method = self.forward_native - - def forward(self, draft_token_ids: list[list[int]], - target_probs: torch.Tensor, - sampling_metadata: SamplingMetadata) -> SamplerOutput: - if not sampling_metadata.all_greedy: - raise NotImplementedError( - "Currently, only greedy sampling is supported by " - "rejection sampler.") - return self.forward_method(draft_token_ids, target_probs, - sampling_metadata) - - def flashinfer_sample( + self.buffer = torch.empty( + max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=is_pin_memory_available(), + ) + self.buffer_np = self.buffer.numpy() + + def forward( self, draft_token_ids: list[list[int]], + # [batch_size, max_spec_len + 1, vocab_size] target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: - # NOTE: The following input preparationg can be moved - # to the model runner with a persistent manner for better - # performance. - sample_lens = [len(x) + 1 for x in draft_token_ids] - # Convert draft token IDs to a tensor, split by sample_lens, then pad. - draft_token_ids = [ - torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids - ] - draft_token_ids_tensor = pad_sequence(draft_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) + ) -> torch.Tensor: + num_draft_tokens = [len(ids) for ids in draft_token_ids] + batch_size = len(draft_token_ids) + max_spec_len = max(num_draft_tokens) + draft_token_ids_np = self.buffer_np[:batch_size * max_spec_len] + for i, token_ids in enumerate(draft_token_ids): + start = i * max_spec_len + end = start + len(token_ids) + draft_token_ids_np[start:end] = token_ids - if sampling_metadata.all_greedy: - target_token_ids = target_probs.argmax(dim=-1).view(-1) - target_token_ids = target_token_ids.split(sample_lens) - target_token_ids = pad_sequence(target_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) - - vocab_size = target_probs.size(-1) - # NOTE: CPU <-> GPU synchronization happens here. - draft_token_ids_tensor = draft_token_ids_tensor.to( - target_probs.device) - draft_probs = _create_greedy_token_probs(draft_token_ids_tensor, - vocab_size, - target_probs.device) - target_probs = _create_greedy_token_probs(target_token_ids, - vocab_size, - target_probs.device) - uniform_samples = torch.zeros(draft_token_ids_tensor.size(0), - draft_token_ids_tensor.size(1) + 1, - device=target_probs.device) - else: - raise NotImplementedError( - "Currently, only greedy sampling is supported by " - "rejection sampler.") - - sampled_token_ids, _, _ = fs.chain_speculative_sampling( - draft_probs, - draft_token_ids_tensor, - uniform_samples, + draft_token_ids_cpu = self.buffer[:batch_size * max_spec_len] + draft_token_ids_cpu = draft_token_ids_cpu.view(batch_size, + max_spec_len) + draft_token_ids = draft_token_ids_cpu.to(device=target_probs.device, + non_blocking=True) + output_token_ids = rejection_sample( + draft_token_ids, + num_draft_tokens, + None, # draft_probs target_probs, + None, # bonus_token_ids + sampling_metadata, ) - return SamplerOutput(sampled_token_ids=sampled_token_ids, - logprobs_tensors=None) + return output_token_ids - # TODO: The following method can be optimized for better performance. 
- def forward_native( - self, - draft_token_ids: list[list[int]], - target_probs: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> SamplerOutput: - sample_lens = [len(x) + 1 for x in draft_token_ids] - # Convert draft token IDs to a tensor, split by sample_lens, then pad. - draft_token_ids = [ - torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids - ] - draft_token_ids_tensor = pad_sequence(draft_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) - draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device) - # Add 1 to include the 'bonus' token. - if sampling_metadata.all_greedy: - output_token_ids = target_probs.argmax(dim=-1).view(-1) - output_token_ids = output_token_ids.split(sample_lens) - output_token_ids = pad_sequence(output_token_ids, - batch_first=True, - padding_value=INVALID_TOKEN_ID) - # Produce a mask that remains 1 (True) until the first - # mismatch (cumprod turns 0 after a mismatch). - accept_mask = ( - output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod( - dim=1) - else: - raise NotImplementedError( - "Currently, only greedy sampling is supported by " - "rejection sampler.") - # Identify valid positions (non-padding). - valid_mask = output_token_ids != INVALID_TOKEN_ID - # Generate mask with bonus token. - generate_mask = torch.cat([ - accept_mask, - torch.zeros(accept_mask.size(0), 1, device=accept_mask.device) - ], - dim=1).to(torch.bool) & valid_mask - zeros_mask = (generate_mask == 0) - first_zero_idx = zeros_mask.float().argmax(dim=1) - # Figure out which rows actually contain at least one zero. - rows_with_zero = zeros_mask.any(dim=1) - # Use indexing to set the first zero in each of those rows to 1. - generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1 - - output_token_ids[~generate_mask] = INVALID_TOKEN_ID - return SamplerOutput(sampled_token_ids=output_token_ids, - logprobs_tensors=None) - - -def _create_greedy_token_probs( - token_ids: torch.Tensor, - vocab_size: int, - out_device: torch.device, + +def rejection_sample( + # [batch_size, max_spec_len] + draft_token_ids: torch.Tensor, + # [batch_size] + num_draft_tokens: list[int], + # [batch_size, max_spec_len, vocab_size] + draft_probs: Optional[torch.Tensor], + # [batch_size, max_spec_len + 1, vocab_size] + target_probs: torch.Tensor, + # [batch_size] + bonus_token_ids: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - batch_size, num_tokens = token_ids.shape + batch_size = draft_token_ids.shape[0] + max_spec_len = draft_token_ids.shape[1] + vocab_size = target_probs.shape[-1] + device = target_probs.device + assert draft_token_ids.is_contiguous() + assert draft_probs is None or draft_probs.is_contiguous() + assert target_probs.is_contiguous() + assert bonus_token_ids.is_contiguous() + + # Rejection sampling. + output_token_ids = torch.empty( + (batch_size, max_spec_len + 1), + dtype=torch.int64, + device=device, + ) + output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + is_greedy = sampling_metadata.temperature == -1 + if not sampling_metadata.all_random: + # Rejection sampling for greedy sampling requests. + target_argmax = target_probs.argmax(dim=-1) + rejection_greedy_sample_kernel[(batch_size, )]( + output_token_ids, + draft_token_ids, + target_argmax, + bonus_token_ids, + is_greedy, + max_spec_len, + ) + if sampling_metadata.all_greedy: + return output_token_ids + + # Generate uniform probabilities for rejection sampling. 
+ uniform_probs = torch.rand( + (batch_size, max_spec_len), + dtype=torch.float32, + device=device, + ) + for i, generator in sampling_metadata.generators.items(): + num_tokens = num_draft_tokens[i] + if num_tokens > 0: + # NOTE(woosuk): We shouldn't use max_spec_len here because + # max_spec_len is affected by other requests in the batch. + uniform_probs[i][:num_tokens].uniform_(generator=generator) + + # Sample recovered tokens for each position. + # Compute the adjusted probabilities. + is_ngram = draft_probs is None + if is_ngram: + # [batch_size, max_spec_len, vocab_size] + probs = target_probs[:, :-1].clone() + # [batch_size, max_spec_len] + safe_draft_token_ids = torch.where( + draft_token_ids == PLACEHOLDER_TOKEN_ID, 0, draft_token_ids) + # Set probabilities to 0 for draft token positions + probs.scatter_(2, safe_draft_token_ids.unsqueeze(-1), 0) + else: + probs = torch.clamp(target_probs[:, :-1] - draft_probs, min=1e-8) + probs /= probs.sum(dim=-1, keepdim=True) + + # NOTE(woosuk): Create only one distribution for each request. + q = torch.empty( + (batch_size, vocab_size), + dtype=torch.float32, + device=device, + ) + q.exponential_() + for i, generator in sampling_metadata.generators.items(): + if num_draft_tokens[i] > 0: + q[i].exponential_(generator=generator) + q = q.unsqueeze(dim=1) + recovered_token_ids = probs.div_(q).argmax(dim=-1) + recovered_token_ids = recovered_token_ids.view(batch_size, max_spec_len) + + # Rejection sampling for random sampling requests. + rejection_random_sample_kernel[(batch_size, )]( + output_token_ids, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + is_ngram, + ) + return output_token_ids + + +@triton.jit +def rejection_random_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + draft_token_ids_ptr, # [batch_size, max_spec_len] + draft_probs_ptr, # [batch_size, max_spec_len, vocab_size] or None + target_probs_ptr, # [batch_size, max_spec_len + 1, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [batch_size, max_spec_len] + uniform_probs_ptr, # [batch_size, UNIFORM_PROBS_LEN] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + IS_NGRAM: tl.constexpr, +): + seq_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + seq_idx) + if is_greedy: + # Early exit for greedy sampling requests. + return - token_probs = torch.zeros(batch_size, - num_tokens, - vocab_size, - dtype=torch.float, - device=out_device) + rejected = False + finished = False + num_generated = 0 + for pos in range(max_spec_len): + if not finished: + token_id = tl.load(draft_token_ids_ptr + seq_idx * max_spec_len + + pos) + if token_id == PLACEHOLDER_TOKEN_ID: + finished = True + else: + if IS_NGRAM: + draft_prob = 1 + else: + # NOTE(woosuk): Here, we assume that draft_prob is nonzero. + draft_prob = tl.load(draft_probs_ptr + + seq_idx * max_spec_len * vocab_size + + pos * vocab_size + token_id) + target_prob = tl.load(target_probs_ptr + seq_idx * + (max_spec_len + 1) * vocab_size + + pos * vocab_size + token_id) + uniform_prob = tl.load(uniform_probs_ptr + + seq_idx * max_spec_len + pos) + if target_prob / draft_prob >= uniform_prob: + # Accept. + tl.store( + output_token_ids_ptr + seq_idx * (max_spec_len + 1) + + pos, token_id) + num_generated += 1 + else: + # Reject. Use recovered token. 
+ rejected = True + recovered_token_id = tl.load(recovered_token_ids_ptr + + seq_idx * max_spec_len + pos) + tl.store( + output_token_ids_ptr + seq_idx * (max_spec_len + 1) + + pos, recovered_token_id) + num_generated += 1 + finished = True + + if not rejected: + # If all tokens are accepted, append the bonus token. + bonus_token_id = tl.load(bonus_token_ids_ptr + seq_idx) + tl.store( + output_token_ids_ptr + seq_idx * (max_spec_len + 1) + + num_generated, bonus_token_id) - # Ignore INVALID_TOKEN_ID. - valid_mask = (token_ids != INVALID_TOKEN_ID) - valid_indices = token_ids.clone() - valid_indices[~valid_mask] = 0 - token_probs.scatter_(dim=2, - index=valid_indices.unsqueeze(-1), - src=valid_mask.unsqueeze(-1).float()) +@triton.jit +def rejection_greedy_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + draft_token_ids_ptr, # [batch_size, max_spec_len] + target_argmax_ptr, # [batch_size, max_spec_len + 1] + bonus_token_ids_ptr, # [batch_size] + is_greedy_ptr, # [batch_size] + max_spec_len, +): + seq_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + seq_idx) + if not is_greedy: + # Early exit for non-greedy sampling requests. + return + + rejected = False + finished = False + num_generated = 0 + for pos in range(max_spec_len): + if not finished: + token_id = tl.load(draft_token_ids_ptr + seq_idx * max_spec_len + + pos) + if token_id == PLACEHOLDER_TOKEN_ID: + finished = True + else: + draft_token_id = tl.load(draft_token_ids_ptr + + seq_idx * max_spec_len + pos) + target_argmax = tl.load(target_argmax_ptr + seq_idx * + (max_spec_len + 1) + pos) + if draft_token_id == target_argmax: + # Accept. + tl.store( + output_token_ids_ptr + seq_idx * (max_spec_len + 1) + + pos, draft_token_id) + num_generated += 1 + else: + # Reject. + rejected = True + tl.store( + output_token_ids_ptr + seq_idx * (max_spec_len + 1) + + pos, target_argmax) + num_generated += 1 + finished = True - return token_probs + if not rejected: + # If all tokens are accepted, append the bonus token. + bonus_token_id = tl.load(bonus_token_ids_ptr + seq_idx) + tl.store( + output_token_ids_ptr + seq_idx * (max_spec_len + 1) + + num_generated, bonus_token_id) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 9707cb5774cd..971d5d31f9da 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -524,11 +524,11 @@ def refresh_sampling_metadata(self): def _make_sampling_metadata(self) -> SamplingMetadata: num_reqs = self.num_reqs - if not self.all_greedy: - temperature = copy_slice(self.temperature_cpu_tensor, - self.temperature, num_reqs) - else: - temperature = None + # NOTE(woosuk): Even if all requests are greedy, we copy the + # temperature tensor for simplicity. The temperature tensor is used + # for speculative decoding. 
+ temperature = copy_slice(self.temperature_cpu_tensor, self.temperature, + num_reqs) if not self.no_top_p: copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) if not self.no_top_k: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c2a976108e4d..e8075035bf62 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,7 +35,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler +from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, + RejectionSampler) from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -1057,7 +1058,7 @@ def execute_model( valid_sampled_token_ids = sampled_token_ids.tolist() else: # Includes spec decode tokens. - valid_mask = sampled_token_ids != INVALID_TOKEN_ID + valid_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID gen_lens = valid_mask.sum(dim=1).tolist() # TODO(woosuk): Optimize this. valid_sampled_token_ids = [ From e3f351363aa775e63e8e441c6e72b255e6275d8f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 09:05:56 -0700 Subject: [PATCH 02/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index a3a4a5dfc0d0..078410fd8c1f 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -127,7 +127,8 @@ def rejection_sample( # Set probabilities to 0 for draft token positions probs.scatter_(2, safe_draft_token_ids.unsqueeze(-1), 0) else: - probs = torch.clamp(target_probs[:, :-1] - draft_probs, min=1e-8) + probs = torch.clamp(target_probs[:, :-1] - draft_probs, + min=torch.finfo(torch.float32).tiny) probs /= probs.sum(dim=-1, keepdim=True) # NOTE(woosuk): Create only one distribution for each request. 
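
[Reviewer note, not part of the patches: the Triton kernels introduced in PATCH 01/02 implement the standard speculative-decoding accept/reject rule — accept draft token x at position pos when target_prob(x) / draft_prob(x) >= u for u ~ Uniform(0, 1), otherwise sample a "recovered" token from the adjusted distribution max(target - draft, 0) renormalized (or, on the ngram path where draft_probs is None, from the target distribution with the draft token zeroed), then stop; if every draft token is accepted, one extra bonus token is emitted. The sketch below is a minimal single-request pure-PyTorch reference of that rule, written for this review only: the function name is illustrative, torch.multinomial stands in for the kernel's exponential-noise argmax trick (equivalent in distribution), and the clamp uses min=0 where PATCH 02 clamps to torch.finfo(torch.float32).tiny.]

from typing import Optional

import torch


def reference_rejection_sample(
    draft_token_ids: list[int],           # proposed tokens for one request
    draft_probs: Optional[torch.Tensor],  # [num_draft, vocab_size] or None (ngram path)
    target_probs: torch.Tensor,           # [num_draft, vocab_size] from the target model
    bonus_token_id: int,
    generator: Optional[torch.Generator] = None,
) -> list[int]:
    output: list[int] = []
    for pos, token_id in enumerate(draft_token_ids):
        q = 1.0 if draft_probs is None else draft_probs[pos, token_id].item()
        p = target_probs[pos, token_id].item()
        u = torch.rand((), generator=generator).item()
        if q == 0 or p / q >= u:
            # Accept the draft token and move on to the next position.
            output.append(token_id)
            continue
        # Reject: sample a recovered token from the adjusted distribution.
        if draft_probs is None:
            # ngram path: target distribution with the draft token zeroed out.
            adjusted = target_probs[pos].clone()
            adjusted[token_id] = 0.0
        else:
            adjusted = torch.clamp(target_probs[pos] - draft_probs[pos], min=0.0)
        adjusted = adjusted / adjusted.sum()
        output.append(int(torch.multinomial(adjusted, 1, generator=generator).item()))
        return output
    # Every draft token was accepted: also emit the bonus token.
    output.append(bonus_token_id)
    return output

[The batched kernels in the series follow the same recurrence but operate on padded/flattened layouts, stop writing once a position is rejected, and use the greedy fast path (argmax comparison) when temperature == -1; this reference is only meant to make the per-request logic easy to check.]
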
From be535aa887f90142a57c89fb0d50fec462969eca Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 09:39:00 -0700 Subject: [PATCH 03/48] fix shape Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 47 +++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 078410fd8c1f..5deef5d79096 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -16,8 +16,9 @@ class RejectionSampler(nn.Module): - def __init__(self, max_num_tokens: int = 16 * 1024): + def __init__(self, max_num_tokens: int = 32 * 1024): super().__init__() + self.max_num_tokens = max_num_tokens self.buffer = torch.empty( max_num_tokens, dtype=torch.int64, @@ -29,13 +30,33 @@ def __init__(self, max_num_tokens: int = 16 * 1024): def forward( self, draft_token_ids: list[list[int]], - # [batch_size, max_spec_len + 1, vocab_size] - target_probs: torch.Tensor, + # [batch_size, max_spec_len, vocab_size] + target_logits: torch.Tensor, + # [batch_size] + bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: num_draft_tokens = [len(ids) for ids in draft_token_ids] batch_size = len(draft_token_ids) max_spec_len = max(num_draft_tokens) + assert batch_size * max_spec_len <= self.max_num_tokens + + if sampling_metadata.all_greedy: + # [batch_size, max_spec_len, vocab_size] + target_probs = target_logits.contiguous() + else: + # [batch_size, max_spec_len, vocab_size] + target_probs = torch.softmax( + target_logits / sampling_metadata.temperature.unsqueeze(-1), + dim=-1, + dtype=torch.float32, + ) + target_probs = torch.where( + sampling_metadata.temperature == -1, + target_logits, + target_probs, + ) + draft_token_ids_np = self.buffer_np[:batch_size * max_spec_len] for i, token_ids in enumerate(draft_token_ids): start = i * max_spec_len @@ -52,7 +73,7 @@ def forward( num_draft_tokens, None, # draft_probs target_probs, - None, # bonus_token_ids + bonus_token_ids, sampling_metadata, ) return output_token_ids @@ -65,7 +86,7 @@ def rejection_sample( num_draft_tokens: list[int], # [batch_size, max_spec_len, vocab_size] draft_probs: Optional[torch.Tensor], - # [batch_size, max_spec_len + 1, vocab_size] + # [batch_size, max_spec_len, vocab_size] target_probs: torch.Tensor, # [batch_size] bonus_token_ids: torch.Tensor, @@ -120,14 +141,14 @@ def rejection_sample( is_ngram = draft_probs is None if is_ngram: # [batch_size, max_spec_len, vocab_size] - probs = target_probs[:, :-1].clone() + probs = target_probs.clone() # [batch_size, max_spec_len] safe_draft_token_ids = torch.where( draft_token_ids == PLACEHOLDER_TOKEN_ID, 0, draft_token_ids) # Set probabilities to 0 for draft token positions probs.scatter_(2, safe_draft_token_ids.unsqueeze(-1), 0) else: - probs = torch.clamp(target_probs[:, :-1] - draft_probs, + probs = torch.clamp(target_probs - draft_probs, min=torch.finfo(torch.float32).tiny) probs /= probs.sum(dim=-1, keepdim=True) @@ -167,7 +188,7 @@ def rejection_random_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] draft_token_ids_ptr, # [batch_size, max_spec_len] draft_probs_ptr, # [batch_size, max_spec_len, vocab_size] or None - target_probs_ptr, # [batch_size, max_spec_len + 1, vocab_size] + target_probs_ptr, # [batch_size, max_spec_len, vocab_size] bonus_token_ids_ptr, # [batch_size] recovered_token_ids_ptr, # [batch_size, max_spec_len] uniform_probs_ptr, # [batch_size, UNIFORM_PROBS_LEN] @@ -199,8 +220,8 @@ 
def rejection_random_sample_kernel( draft_prob = tl.load(draft_probs_ptr + seq_idx * max_spec_len * vocab_size + pos * vocab_size + token_id) - target_prob = tl.load(target_probs_ptr + seq_idx * - (max_spec_len + 1) * vocab_size + + target_prob = tl.load(target_probs_ptr + + seq_idx * max_spec_len * vocab_size + pos * vocab_size + token_id) uniform_prob = tl.load(uniform_probs_ptr + seq_idx * max_spec_len + pos) @@ -233,7 +254,7 @@ def rejection_random_sample_kernel( def rejection_greedy_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] draft_token_ids_ptr, # [batch_size, max_spec_len] - target_argmax_ptr, # [batch_size, max_spec_len + 1] + target_argmax_ptr, # [batch_size, max_spec_len] bonus_token_ids_ptr, # [batch_size] is_greedy_ptr, # [batch_size] max_spec_len, @@ -256,8 +277,8 @@ def rejection_greedy_sample_kernel( else: draft_token_id = tl.load(draft_token_ids_ptr + seq_idx * max_spec_len + pos) - target_argmax = tl.load(target_argmax_ptr + seq_idx * - (max_spec_len + 1) + pos) + target_argmax = tl.load(target_argmax_ptr + + seq_idx * max_spec_len + pos) if draft_token_id == target_argmax: # Accept. tl.store( From be950c72bb21cf769f0291468fc7d9e08952a7a9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 12:57:35 -0700 Subject: [PATCH 04/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 128 +++++++++++++++++++++------- 1 file changed, 97 insertions(+), 31 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 5deef5d79096..a83137f4fa80 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -12,6 +12,7 @@ logger = init_logger(__name__) PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 +GREEDY_TEMPERATURE: tl.constexpr = -1 class RejectionSampler(nn.Module): @@ -30,53 +31,59 @@ def __init__(self, max_num_tokens: int = 32 * 1024): def forward( self, draft_token_ids: list[list[int]], - # [batch_size, max_spec_len, vocab_size] + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens, vocab_size] + draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] target_logits: torch.Tensor, # [batch_size] bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - num_draft_tokens = [len(ids) for ids in draft_token_ids] + # [batch_size, max_spec_len] + draft_token_ids_tensor = self._async_copy_to_device( + draft_token_ids, + target_logits.device, + ) + max_spec_len = draft_token_ids_tensor.shape[1] + # [num_tokens, vocab_size] + target_probs = compute_probs( + target_logits, + sampling_metadata.temperature, + cu_num_draft_tokens, + max_spec_len, + ) + output_token_ids = rejection_sample( + draft_token_ids_tensor, + cu_num_draft_tokens, + draft_probs, + target_probs, + bonus_token_ids, + sampling_metadata, + ) + return output_token_ids + + def _async_copy_to_device( + self, + draft_token_ids: list[list[int]], + device: torch.device, + ) -> torch.Tensor: batch_size = len(draft_token_ids) + num_draft_tokens = [len(ids) for ids in draft_token_ids] max_spec_len = max(num_draft_tokens) assert batch_size * max_spec_len <= self.max_num_tokens - if sampling_metadata.all_greedy: - # [batch_size, max_spec_len, vocab_size] - target_probs = target_logits.contiguous() - else: - # [batch_size, max_spec_len, vocab_size] - target_probs = torch.softmax( - target_logits / sampling_metadata.temperature.unsqueeze(-1), - dim=-1, - dtype=torch.float32, - ) - target_probs = torch.where( - sampling_metadata.temperature == -1, - 
target_logits, - target_probs, - ) - draft_token_ids_np = self.buffer_np[:batch_size * max_spec_len] + draft_token_ids_np.fill(PLACEHOLDER_TOKEN_ID) for i, token_ids in enumerate(draft_token_ids): start = i * max_spec_len end = start + len(token_ids) draft_token_ids_np[start:end] = token_ids - draft_token_ids_cpu = self.buffer[:batch_size * max_spec_len] draft_token_ids_cpu = draft_token_ids_cpu.view(batch_size, max_spec_len) - draft_token_ids = draft_token_ids_cpu.to(device=target_probs.device, - non_blocking=True) - output_token_ids = rejection_sample( - draft_token_ids, - num_draft_tokens, - None, # draft_probs - target_probs, - bonus_token_ids, - sampling_metadata, - ) - return output_token_ids + return draft_token_ids_cpu.to(device=device, non_blocking=True) def rejection_sample( @@ -108,7 +115,7 @@ def rejection_sample( device=device, ) output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) - is_greedy = sampling_metadata.temperature == -1 + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) @@ -300,3 +307,62 @@ def rejection_greedy_sample_kernel( tl.store( output_token_ids_ptr + seq_idx * (max_spec_len + 1) + num_generated, bonus_token_id) + + +def compute_probs( + logits: torch.Tensor, # [num_tokens, vocab_size] + temperature: torch.Tensor, # [batch_size] + cu_num_draft_tokens: torch.Tensor, # [batch_size] + max_spec_len: int, +) -> torch.Tensor: + output_prob = torch.empty_like(logits, dtype=torch.float32) + batch_size = temperature.shape[0] + vocab_size = logits.shape[-1] + compute_probs_kernel[(batch_size, max_spec_len)]( + output_prob, + logits, + temperature, + cu_num_draft_tokens, + vocab_size, + triton.next_power_of_two(vocab_size), + ) + return output_prob + + +@triton.jit +def compute_probs_kernel( + output_prob_ptr, # [num_tokens, vocab_size] + logits_ptr, # [num_tokens, vocab_size] + temperature_ptr, # [batch_size] + cu_num_draft_tokens_ptr, # [batch_size] + vocab_size, + PADDED_VOCAB_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + + # Early exit for out-of-range positions. + pos = tl.program_id(1) + if pos >= end_idx - start_idx: + return + + vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) + logits = tl.load(logits_ptr + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size) + temperature = tl.load(temperature_ptr + req_idx) + if temperature == GREEDY_TEMPERATURE: + # Greedy sampling. Just return the logits. + output_prob = logits + else: + # Random sampling. 
+ output_prob = tl.softmax(logits / temperature) + output_prob = output_prob.to(dtype=tl.float32) + + tl.store(output_prob_ptr + (start_idx + pos) * vocab_size + vocab_offset, + output_prob, + mask=vocab_offset < vocab_size) From 1fee1771477e07054de030aa461bdb0caf00fe82 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 13:01:03 -0700 Subject: [PATCH 05/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index a83137f4fa80..dacdc42ec789 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -41,12 +41,8 @@ def forward( bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - # [batch_size, max_spec_len] - draft_token_ids_tensor = self._async_copy_to_device( - draft_token_ids, - target_logits.device, - ) - max_spec_len = draft_token_ids_tensor.shape[1] + num_draft_tokens = [len(ids) for ids in draft_token_ids] + max_spec_len = max(num_draft_tokens) # [num_tokens, vocab_size] target_probs = compute_probs( target_logits, @@ -54,8 +50,14 @@ def forward( cu_num_draft_tokens, max_spec_len, ) + # [batch_size, max_spec_len] + draft_token_ids_tensor = self._async_copy_to_device( + draft_token_ids, + target_logits.device, + ) output_token_ids = rejection_sample( draft_token_ids_tensor, + num_draft_tokens, cu_num_draft_tokens, draft_probs, target_probs, @@ -91,6 +93,8 @@ def rejection_sample( draft_token_ids: torch.Tensor, # [batch_size] num_draft_tokens: list[int], + # [batch_size] + cu_num_draft_tokens: torch.Tensor, # [batch_size, max_spec_len, vocab_size] draft_probs: Optional[torch.Tensor], # [batch_size, max_spec_len, vocab_size] From d30970e2454cf2f55c903c238a54b0f0b1febc91 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 13:06:44 -0700 Subject: [PATCH 06/48] Add parse_outputs Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index dacdc42ec789..52abe96a9b3c 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -87,6 +87,17 @@ def _async_copy_to_device( max_spec_len) return draft_token_ids_cpu.to(device=device, non_blocking=True) + @staticmethod + def parse_output(output_token_ids: torch.Tensor) -> list[list[int]]: + output_token_ids = output_token_ids.tolist() + outputs: list[list[int]] = [[] for _ in output_token_ids] + for i, token_ids in enumerate(output_token_ids): + for token_id in token_ids: + if token_id == PLACEHOLDER_TOKEN_ID: + break + outputs[i].append(token_id) + return outputs + def rejection_sample( # [batch_size, max_spec_len] From 32fefa10bfee7238f8e08a0f2191cb359f92db3b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 13:07:11 -0700 Subject: [PATCH 07/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 52abe96a9b3c..849f08b3049f 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -66,6 +66,17 @@ def forward( ) return output_token_ids + @staticmethod + def parse_output(output_token_ids: torch.Tensor) -> list[list[int]]: + output_token_ids = 
output_token_ids.tolist() + outputs: list[list[int]] = [[] for _ in output_token_ids] + for i, token_ids in enumerate(output_token_ids): + for token_id in token_ids: + if token_id == PLACEHOLDER_TOKEN_ID: + break + outputs[i].append(token_id) + return outputs + def _async_copy_to_device( self, draft_token_ids: list[list[int]], @@ -87,17 +98,6 @@ def _async_copy_to_device( max_spec_len) return draft_token_ids_cpu.to(device=device, non_blocking=True) - @staticmethod - def parse_output(output_token_ids: torch.Tensor) -> list[list[int]]: - output_token_ids = output_token_ids.tolist() - outputs: list[list[int]] = [[] for _ in output_token_ids] - for i, token_ids in enumerate(output_token_ids): - for token_id in token_ids: - if token_id == PLACEHOLDER_TOKEN_ID: - break - outputs[i].append(token_id) - return outputs - def rejection_sample( # [batch_size, max_spec_len] From 4a9397341982138b4328060007110289e962a00d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 13:11:07 -0700 Subject: [PATCH 08/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 849f08b3049f..02d0a0adcf0e 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -69,6 +69,7 @@ def forward( @staticmethod def parse_output(output_token_ids: torch.Tensor) -> list[list[int]]: output_token_ids = output_token_ids.tolist() + # Preallocate outputs. outputs: list[list[int]] = [[] for _ in output_token_ids] for i, token_ids in enumerate(output_token_ids): for token_id in token_ids: @@ -106,16 +107,15 @@ def rejection_sample( num_draft_tokens: list[int], # [batch_size] cu_num_draft_tokens: torch.Tensor, - # [batch_size, max_spec_len, vocab_size] + # [num_tokens, vocab_size] draft_probs: Optional[torch.Tensor], - # [batch_size, max_spec_len, vocab_size] + # [num_tokens, vocab_size] target_probs: torch.Tensor, # [batch_size] bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - batch_size = draft_token_ids.shape[0] - max_spec_len = draft_token_ids.shape[1] + batch_size, max_spec_len = draft_token_ids.shape vocab_size = target_probs.shape[-1] device = target_probs.device assert draft_token_ids.is_contiguous() @@ -123,13 +123,14 @@ def rejection_sample( assert target_probs.is_contiguous() assert bonus_token_ids.is_contiguous() - # Rejection sampling. + # Create output buffer. output_token_ids = torch.empty( (batch_size, max_spec_len + 1), dtype=torch.int64, device=device, ) output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. From f2455fde780ea6bf468033744289550ac128e8f4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 13:22:09 -0700 Subject: [PATCH 09/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 117 +++++++++++++++++++--------- 1 file changed, 79 insertions(+), 38 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 02d0a0adcf0e..30ac7880eb25 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -147,6 +147,71 @@ def rejection_sample( return output_token_ids # Generate uniform probabilities for rejection sampling. 
+ uniform_probs = generate_uniform_probs( + batch_size, + max_spec_len, + num_draft_tokens, + sampling_metadata, + device, + ) + + # Sample recovered tokens for each position. + recovered_token_ids = sample_recovered_tokens( + batch_size, + max_spec_len, + vocab_size, + num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + sampling_metadata, + device, + ) + + # Rejection sampling for random sampling requests. + rejection_random_sample_kernel[(batch_size, )]( + output_token_ids, + draft_token_ids, + draft_probs, + target_probs, + bonus_token_ids, + recovered_token_ids, + uniform_probs, + is_greedy, + max_spec_len, + vocab_size, + IS_NGRAM=draft_probs is None, + ) + return output_token_ids + + +def compute_probs( + logits: torch.Tensor, # [num_tokens, vocab_size] + temperature: torch.Tensor, # [batch_size] + cu_num_draft_tokens: torch.Tensor, # [batch_size] + max_spec_len: int, +) -> torch.Tensor: + output_prob = torch.empty_like(logits, dtype=torch.float32) + batch_size = temperature.shape[0] + vocab_size = logits.shape[-1] + compute_probs_kernel[(batch_size, max_spec_len)]( + output_prob, + logits, + temperature, + cu_num_draft_tokens, + vocab_size, + triton.next_power_of_two(vocab_size), + ) + return output_prob + + +def generate_uniform_probs( + batch_size: int, + max_spec_len: int, + num_draft_tokens: list[int], + sampling_metadata: SamplingMetadata, + device: torch.device, +) -> torch.Tensor: uniform_probs = torch.rand( (batch_size, max_spec_len), dtype=torch.float32, @@ -158,8 +223,20 @@ def rejection_sample( # NOTE(woosuk): We shouldn't use max_spec_len here because # max_spec_len is affected by other requests in the batch. uniform_probs[i][:num_tokens].uniform_(generator=generator) + return uniform_probs - # Sample recovered tokens for each position. + +def sample_recovered_tokens( + batch_size: int, + max_spec_len: int, + vocab_size: int, + num_draft_tokens: list[int], + draft_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor], + target_probs: torch.Tensor, + sampling_metadata: SamplingMetadata, + device: torch.device, +) -> torch.Tensor: # Compute the adjusted probabilities. is_ngram = draft_probs is None if is_ngram: @@ -187,23 +264,7 @@ def rejection_sample( q[i].exponential_(generator=generator) q = q.unsqueeze(dim=1) recovered_token_ids = probs.div_(q).argmax(dim=-1) - recovered_token_ids = recovered_token_ids.view(batch_size, max_spec_len) - - # Rejection sampling for random sampling requests. 
- rejection_random_sample_kernel[(batch_size, )]( - output_token_ids, - draft_token_ids, - draft_probs, - target_probs, - bonus_token_ids, - recovered_token_ids, - uniform_probs, - is_greedy, - max_spec_len, - vocab_size, - is_ngram, - ) - return output_token_ids + return recovered_token_ids.view(batch_size, max_spec_len) @triton.jit @@ -325,26 +386,6 @@ def rejection_greedy_sample_kernel( num_generated, bonus_token_id) -def compute_probs( - logits: torch.Tensor, # [num_tokens, vocab_size] - temperature: torch.Tensor, # [batch_size] - cu_num_draft_tokens: torch.Tensor, # [batch_size] - max_spec_len: int, -) -> torch.Tensor: - output_prob = torch.empty_like(logits, dtype=torch.float32) - batch_size = temperature.shape[0] - vocab_size = logits.shape[-1] - compute_probs_kernel[(batch_size, max_spec_len)]( - output_prob, - logits, - temperature, - cu_num_draft_tokens, - vocab_size, - triton.next_power_of_two(vocab_size), - ) - return output_prob - - @triton.jit def compute_probs_kernel( output_prob_ptr, # [num_tokens, vocab_size] From fbba0fff589efa12d4edc68e248e0cb03a1f6346 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 14:37:57 -0700 Subject: [PATCH 10/48] kernel Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 294 ++++++++++++++++------------ 1 file changed, 171 insertions(+), 123 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 30ac7880eb25..7d18659dee1a 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -13,6 +13,7 @@ logger = init_logger(__name__) PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 GREEDY_TEMPERATURE: tl.constexpr = -1 +TINY: tl.constexpr = 1.1754943508222875e-38 # torch.finfo(torch.float32).tiny class RejectionSampler(nn.Module): @@ -101,7 +102,7 @@ def _async_copy_to_device( def rejection_sample( - # [batch_size, max_spec_len] + # [num_tokens] draft_token_ids: torch.Tensor, # [batch_size] num_draft_tokens: list[int], @@ -115,7 +116,9 @@ def rejection_sample( bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - batch_size, max_spec_len = draft_token_ids.shape + batch_size = len(num_draft_tokens) + max_spec_len = max(num_draft_tokens) + num_tokens = sum(num_draft_tokens) vocab_size = target_probs.shape[-1] device = target_probs.device assert draft_token_ids.is_contiguous() @@ -137,6 +140,7 @@ def rejection_sample( target_argmax = target_probs.argmax(dim=-1) rejection_greedy_sample_kernel[(batch_size, )]( output_token_ids, + cu_num_draft_tokens, draft_token_ids, target_argmax, bonus_token_ids, @@ -147,20 +151,20 @@ def rejection_sample( return output_token_ids # Generate uniform probabilities for rejection sampling. + # [num_tokens] uniform_probs = generate_uniform_probs( - batch_size, - max_spec_len, + num_tokens, num_draft_tokens, sampling_metadata, device, ) # Sample recovered tokens for each position. + # [num_tokens] recovered_token_ids = sample_recovered_tokens( - batch_size, max_spec_len, - vocab_size, num_draft_tokens, + cu_num_draft_tokens, draft_token_ids, draft_probs, target_probs, @@ -171,6 +175,7 @@ def rejection_sample( # Rejection sampling for random sampling requests. 
rejection_random_sample_kernel[(batch_size, )]( output_token_ids, + cu_num_draft_tokens, draft_token_ids, draft_probs, target_probs, @@ -206,53 +211,44 @@ def compute_probs( def generate_uniform_probs( - batch_size: int, - max_spec_len: int, + num_tokens: int, num_draft_tokens: list[int], sampling_metadata: SamplingMetadata, device: torch.device, ) -> torch.Tensor: uniform_probs = torch.rand( - (batch_size, max_spec_len), + (num_tokens, ), dtype=torch.float32, device=device, ) - for i, generator in sampling_metadata.generators.items(): - num_tokens = num_draft_tokens[i] - if num_tokens > 0: - # NOTE(woosuk): We shouldn't use max_spec_len here because - # max_spec_len is affected by other requests in the batch. - uniform_probs[i][:num_tokens].uniform_(generator=generator) + start_idx = 0 + for req_idx, n in enumerate(num_draft_tokens): + if n == 0: + continue + end_idx = start_idx + n + generator = sampling_metadata.generators[req_idx] + uniform_probs[start_idx:end_idx].uniform_(generator=generator) + start_idx = end_idx return uniform_probs def sample_recovered_tokens( - batch_size: int, max_spec_len: int, - vocab_size: int, num_draft_tokens: list[int], + # [batch_size] + cu_num_draft_tokens: torch.Tensor, + # [num_tokens] draft_token_ids: torch.Tensor, + # [num_tokens, vocab_size] draft_probs: Optional[torch.Tensor], + # [num_tokens, vocab_size] target_probs: torch.Tensor, sampling_metadata: SamplingMetadata, device: torch.device, ) -> torch.Tensor: - # Compute the adjusted probabilities. - is_ngram = draft_probs is None - if is_ngram: - # [batch_size, max_spec_len, vocab_size] - probs = target_probs.clone() - # [batch_size, max_spec_len] - safe_draft_token_ids = torch.where( - draft_token_ids == PLACEHOLDER_TOKEN_ID, 0, draft_token_ids) - # Set probabilities to 0 for draft token positions - probs.scatter_(2, safe_draft_token_ids.unsqueeze(-1), 0) - else: - probs = torch.clamp(target_probs - draft_probs, - min=torch.finfo(torch.float32).tiny) - probs /= probs.sum(dim=-1, keepdim=True) - # NOTE(woosuk): Create only one distribution for each request. 
+ batch_size = len(num_draft_tokens) + vocab_size = target_probs.shape[-1] q = torch.empty( (batch_size, vocab_size), dtype=torch.float32, @@ -262,128 +258,122 @@ def sample_recovered_tokens( for i, generator in sampling_metadata.generators.items(): if num_draft_tokens[i] > 0: q[i].exponential_(generator=generator) - q = q.unsqueeze(dim=1) - recovered_token_ids = probs.div_(q).argmax(dim=-1) - return recovered_token_ids.view(batch_size, max_spec_len) + + recovered_token_ids = torch.empty_like(draft_token_ids) + sample_recovered_tokens_kernel[(batch_size, max_spec_len)]( + recovered_token_ids, + cu_num_draft_tokens, + draft_token_ids, + draft_probs, + target_probs, + q, + vocab_size, + triton.next_power_of_two(vocab_size), + IS_NGRAM=draft_probs is None, + ) + return recovered_token_ids @triton.jit -def rejection_random_sample_kernel( +def rejection_greedy_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] - draft_token_ids_ptr, # [batch_size, max_spec_len] - draft_probs_ptr, # [batch_size, max_spec_len, vocab_size] or None - target_probs_ptr, # [batch_size, max_spec_len, vocab_size] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + target_argmax_ptr, # [num_tokens] bonus_token_ids_ptr, # [batch_size] - recovered_token_ids_ptr, # [batch_size, max_spec_len] - uniform_probs_ptr, # [batch_size, UNIFORM_PROBS_LEN] is_greedy_ptr, # [batch_size] max_spec_len, - vocab_size, - IS_NGRAM: tl.constexpr, ): - seq_idx = tl.program_id(0) - is_greedy = tl.load(is_greedy_ptr + seq_idx) - if is_greedy: - # Early exit for greedy sampling requests. + req_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + req_idx) + if not is_greedy: + # Early exit for non-greedy sampling requests. return + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + rejected = False - finished = False - num_generated = 0 - for pos in range(max_spec_len): - if not finished: - token_id = tl.load(draft_token_ids_ptr + seq_idx * max_spec_len + - pos) - if token_id == PLACEHOLDER_TOKEN_ID: - finished = True - else: - if IS_NGRAM: - draft_prob = 1 - else: - # NOTE(woosuk): Here, we assume that draft_prob is nonzero. - draft_prob = tl.load(draft_probs_ptr + - seq_idx * max_spec_len * vocab_size + - pos * vocab_size + token_id) - target_prob = tl.load(target_probs_ptr + - seq_idx * max_spec_len * vocab_size + - pos * vocab_size + token_id) - uniform_prob = tl.load(uniform_probs_ptr + - seq_idx * max_spec_len + pos) - if target_prob / draft_prob >= uniform_prob: - # Accept. - tl.store( - output_token_ids_ptr + seq_idx * (max_spec_len + 1) + - pos, token_id) - num_generated += 1 - else: - # Reject. Use recovered token. - rejected = True - recovered_token_id = tl.load(recovered_token_ids_ptr + - seq_idx * max_spec_len + pos) - tl.store( - output_token_ids_ptr + seq_idx * (max_spec_len + 1) + - pos, recovered_token_id) - num_generated += 1 - finished = True + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + target_argmax_id) + if draft_token_id != target_argmax_id: + # Reject. + rejected = True if not rejected: # If all tokens are accepted, append the bonus token. 
- bonus_token_id = tl.load(bonus_token_ids_ptr + seq_idx) + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + seq_idx * (max_spec_len + 1) + - num_generated, bonus_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, bonus_token_id) @triton.jit -def rejection_greedy_sample_kernel( +def rejection_random_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] - draft_token_ids_ptr, # [batch_size, max_spec_len] - target_argmax_ptr, # [batch_size, max_spec_len] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] is_greedy_ptr, # [batch_size] max_spec_len, + vocab_size, + IS_NGRAM: tl.constexpr, ): - seq_idx = tl.program_id(0) - is_greedy = tl.load(is_greedy_ptr + seq_idx) - if not is_greedy: - # Early exit for non-greedy sampling requests. + req_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + req_idx) + if is_greedy: + # Early exit for greedy sampling requests. return + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + rejected = False - finished = False - num_generated = 0 - for pos in range(max_spec_len): - if not finished: - token_id = tl.load(draft_token_ids_ptr + seq_idx * max_spec_len + - pos) - if token_id == PLACEHOLDER_TOKEN_ID: - finished = True + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + if IS_NGRAM: + draft_prob = 1 + else: + draft_prob = tl.load(draft_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + + draft_token_id) + uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) + if draft_prob == 0 or (target_prob / draft_prob >= uniform_prob): + # Accept. + token_id = draft_token_id else: - draft_token_id = tl.load(draft_token_ids_ptr + - seq_idx * max_spec_len + pos) - target_argmax = tl.load(target_argmax_ptr + - seq_idx * max_spec_len + pos) - if draft_token_id == target_argmax: - # Accept. - tl.store( - output_token_ids_ptr + seq_idx * (max_spec_len + 1) + - pos, draft_token_id) - num_generated += 1 - else: - # Reject. - rejected = True - tl.store( - output_token_ids_ptr + seq_idx * (max_spec_len + 1) + - pos, target_argmax) - num_generated += 1 - finished = True + # Reject. Use recovered token. + rejected = True + token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, + token_id) if not rejected: # If all tokens are accepted, append the bonus token. 
- bonus_token_id = tl.load(bonus_token_ids_ptr + seq_idx) + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) tl.store( - output_token_ids_ptr + seq_idx * (max_spec_len + 1) + - num_generated, bonus_token_id) + output_token_ids_ptr + req_idx * (max_spec_len + 1) + + num_draft_tokens, bonus_token_id) @triton.jit @@ -423,3 +413,61 @@ def compute_probs_kernel( tl.store(output_prob_ptr + (start_idx + pos) * vocab_size + vocab_offset, output_prob, mask=vocab_offset < vocab_size) + + +@triton.jit +def sample_recovered_tokens_kernel( + output_token_ids_ptr, # [num_tokens] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + q_ptr, # [batch_size, vocab_size] + vocab_size, + PADDED_VOCAB_SIZE: tl.constexpr, + IS_NGRAM: tl.constexpr, +): + req_idx = tl.program_id(0) + if req_idx == 0: + start_idx = 0 + else: + start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + # Early exit for out-of-range positions. + pos = tl.program_id(1) + if pos >= num_draft_tokens: + return + + vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) + if IS_NGRAM: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + draft_token_id) + # Temporarily zero out the probability of the draft token. + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + 0) + prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size) + else: + draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + + vocab_offset, + mask=vocab_offset < vocab_size) + target_prob = tl.load(target_probs_ptr + + (start_idx + pos) * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size) + prob = target_prob - draft_prob + prob = tl.maximum(prob, TINY) + prob = prob / prob.sum(axis=-1, keep_dims=True) + + q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, + mask=vocab_offset < vocab_size) + recovered_id = tl.argmax(prob / q, -1) + tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) + if IS_NGRAM: + tl.store( + target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, + orig_prob) From 255d1eec26a3613b33af685fb3717036bab66790 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 14:46:40 -0700 Subject: [PATCH 11/48] kernel Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 65 +++++++++++++++++++---------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 7d18659dee1a..ee7e7377b3a1 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -18,16 +18,30 @@ class RejectionSampler(nn.Module): - def __init__(self, max_num_tokens: int = 32 * 1024): + def __init__( + self, + max_batch_size: int = 8 * 1024, + max_num_draft_tokens: int = 32 * 1024, + ): super().__init__() - self.max_num_tokens = max_num_tokens - self.buffer = torch.empty( - max_num_tokens, + self.max_batch_size = max_batch_size + self.max_num_draft_tokens = max_num_draft_tokens + + self.cu_num_tokens_buffer = torch.empty( + max_batch_size, + dtype=torch.int32, + device="cpu", + pin_memory=is_pin_memory_available(), + ) + self.cu_num_tokens_buffer_np = self.cu_num_tokens_buffer.numpy() + + self.token_ids_buffer 
= torch.empty( + max_num_draft_tokens, dtype=torch.int64, device="cpu", pin_memory=is_pin_memory_available(), ) - self.buffer_np = self.buffer.numpy() + self.token_ids_buffer_np = self.token_ids_buffer.numpy() def forward( self, @@ -51,11 +65,12 @@ def forward( cu_num_draft_tokens, max_spec_len, ) - # [batch_size, max_spec_len] - draft_token_ids_tensor = self._async_copy_to_device( + draft_token_ids_tensor, cu_num_draft_tokens_tensor = \ + self._async_copy_to_device( draft_token_ids, target_logits.device, ) + output_token_ids = rejection_sample( draft_token_ids_tensor, num_draft_tokens, @@ -83,22 +98,26 @@ def _async_copy_to_device( self, draft_token_ids: list[list[int]], device: torch.device, - ) -> torch.Tensor: - batch_size = len(draft_token_ids) - num_draft_tokens = [len(ids) for ids in draft_token_ids] - max_spec_len = max(num_draft_tokens) - assert batch_size * max_spec_len <= self.max_num_tokens - - draft_token_ids_np = self.buffer_np[:batch_size * max_spec_len] - draft_token_ids_np.fill(PLACEHOLDER_TOKEN_ID) - for i, token_ids in enumerate(draft_token_ids): - start = i * max_spec_len - end = start + len(token_ids) - draft_token_ids_np[start:end] = token_ids - draft_token_ids_cpu = self.buffer[:batch_size * max_spec_len] - draft_token_ids_cpu = draft_token_ids_cpu.view(batch_size, - max_spec_len) - return draft_token_ids_cpu.to(device=device, non_blocking=True) + ) -> tuple[torch.Tensor, torch.Tensor]: + flattened_token_ids: list[int] = [] + cu_num_tokens: list[int] = [] + for token_ids in draft_token_ids: + flattened_token_ids.extend(token_ids) + cu_num_tokens.append(len(token_ids)) + + num_draft_tokens = len(flattened_token_ids) + assert num_draft_tokens <= self.max_num_draft_tokens + self.token_ids_buffer_np[:num_draft_tokens] = flattened_token_ids + draft_token_ids_cpu = self.token_ids_buffer[:num_draft_tokens] + draft_token_ids_gpu = draft_token_ids_cpu.to(device=device, + non_blocking=True) + + batch_size = len(cu_num_tokens) + self.cu_num_tokens_buffer_np[:batch_size] = cu_num_tokens + cu_num_draft_tokens_cpu = self.cu_num_tokens_buffer[:batch_size] + cu_num_draft_tokens_gpu = cu_num_draft_tokens_cpu.to(device=device, + non_blocking=True) + return draft_token_ids_gpu, cu_num_draft_tokens_gpu def rejection_sample( From 22c951501c4ddddda2cb40b5d18d92e3c895c6da Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 14:49:25 -0700 Subject: [PATCH 12/48] fix Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index ee7e7377b3a1..f6ab188e73ba 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -46,8 +46,6 @@ def __init__( def forward( self, draft_token_ids: list[list[int]], - # [batch_size] - cu_num_draft_tokens: torch.Tensor, # [num_tokens, vocab_size] draft_probs: Optional[torch.Tensor], # [num_tokens, vocab_size] @@ -56,6 +54,11 @@ def forward( bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: + token_ids, cu_num_draft_tokens = self._async_copy_to_device( + draft_token_ids, + target_logits.device, + ) + num_draft_tokens = [len(ids) for ids in draft_token_ids] max_spec_len = max(num_draft_tokens) # [num_tokens, vocab_size] @@ -65,14 +68,9 @@ def forward( cu_num_draft_tokens, max_spec_len, ) - draft_token_ids_tensor, cu_num_draft_tokens_tensor = \ - self._async_copy_to_device( - draft_token_ids, - target_logits.device, - ) 
output_token_ids = rejection_sample( - draft_token_ids_tensor, + token_ids, num_draft_tokens, cu_num_draft_tokens, draft_probs, From c631935860e64390fe86a0d47b4fcbad84c2dc44 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 14:50:20 -0700 Subject: [PATCH 13/48] comment Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index f6ab188e73ba..4cc8dfaf245a 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -45,6 +45,7 @@ def __init__( def forward( self, + # batch_size x [0, max_spec_len) draft_token_ids: list[list[int]], # [num_tokens, vocab_size] draft_probs: Optional[torch.Tensor], From 566caea3bde8112a34460ff1df92c65d4c9f7217 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 14:50:47 -0700 Subject: [PATCH 14/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 4cc8dfaf245a..aedbb3ad5f5e 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -11,6 +11,7 @@ from vllm.v1.sample.metadata import SamplingMetadata logger = init_logger(__name__) + PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 GREEDY_TEMPERATURE: tl.constexpr = -1 TINY: tl.constexpr = 1.1754943508222875e-38 # torch.finfo(torch.float32).tiny From c427ffdf42010a74c71bfeb1b5aa42e099417e31 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 14:53:41 -0700 Subject: [PATCH 15/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index aedbb3ad5f5e..dd3f9c2d2ec4 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -484,7 +484,7 @@ def sample_recovered_tokens_kernel( q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size) - recovered_id = tl.argmax(prob / q, -1) + recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) if IS_NGRAM: tl.store( From d896f4168e5beed0a4ea22220b52915fcad2b354 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 15:07:12 -0700 Subject: [PATCH 16/48] fix Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 4 +-- vllm/v1/spec_decode/utils.py | 22 +++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 43 +++++++++++++++++++---------- 3 files changed, 52 insertions(+), 17 deletions(-) create mode 100644 vllm/v1/spec_decode/utils.py diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index dd3f9c2d2ec4..49c718636456 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -224,7 +224,7 @@ def compute_probs( temperature, cu_num_draft_tokens, vocab_size, - triton.next_power_of_two(vocab_size), + triton.next_power_of_2(vocab_size), ) return output_prob @@ -287,7 +287,7 @@ def sample_recovered_tokens( target_probs, q, vocab_size, - triton.next_power_of_two(vocab_size), + triton.next_power_of_2(vocab_size), IS_NGRAM=draft_probs is None, ) return recovered_token_ids diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py new file mode 100644 index 000000000000..584140136778 --- /dev/null +++ b/vllm/v1/spec_decode/utils.py 
@@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa +from vllm.v1.worker.gpu_input_batch import InputBatch + + +def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool: + if req_id in input_batch.top_k_reqs or req_id in input_batch.top_p_reqs: + # Spec decode doesn't support top_p/top_k sampling. + return False + elif req_id in input_batch.min_p_reqs: + # Spec decode doesn't support min_p sampling. + return False + elif (req_id in input_batch.frequency_penalties_reqs + or req_id in input_batch.presence_penalties_reqs + or req_id in input_batch.repetition_penalties_reqs): + # Spec decode doesn't support penalties. + return False + elif req_id in input_batch.num_logprobs: + # Spec decode doesn't support logprobs. + return False + + return True diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e8075035bf62..2e4ff4d2d7c8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,9 +35,9 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, - RejectionSampler) +from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -1014,15 +1014,26 @@ def execute_model( sampling_metadata=sampling_metadata, ) else: - target_probs = self.model.sampler.compute_probs( - logits, sampling_metadata) draft_token_ids = [ scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) for req_id in self.input_batch.req_ids ] - sampler_output = self.rejection_sampler(draft_token_ids, - target_probs, - sampling_metadata) + sample_lens = [len(tokens) + 1 for tokens in draft_token_ids] + bonus_logits_idx = np.cumsum(sample_lens) - 1 + sampler_output = self.model.sample( + logits=logits[bonus_logits_idx], + sampling_metadata=sampling_metadata, + ) + bonus_token_ids = sampler_output.sampled_token_ids + + output_token_ids = self.rejection_sampler( + draft_token_ids, + None, # draft_probs + logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. @@ -1058,19 +1069,14 @@ def execute_model( valid_sampled_token_ids = sampled_token_ids.tolist() else: # Includes spec decode tokens. - valid_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID - gen_lens = valid_mask.sum(dim=1).tolist() - # TODO(woosuk): Optimize this. - valid_sampled_token_ids = [ - seq.tolist() - for seq in sampled_token_ids[valid_mask].split(gen_lens) - ] + valid_sampled_token_ids = self.rejection_sampler.parse_output( + sampled_token_ids) if not self.use_spec_decode: spec_token_ids = None else: spec_token_ids = self.generate_draft_token_ids( - valid_sampled_token_ids) + valid_sampled_token_ids, sampling_metadata) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1084,6 +1090,7 @@ def execute_model( def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], + sampling_metadata: SamplingMetadata, ) -> list[list[int]]: # TODO(woosuk): Optimize. 
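As a quick check of the bonus-logit indexing used in execute_model above: each request contributes len(draft_tokens) + 1 logits, so the cumulative sum minus one lands on the last (bonus) position of every request. The sizes below are made up for illustration.

import numpy as np

draft_token_ids = [[11, 12, 13], [], [21, 22]]          # 3, 0, 2 draft tokens
sample_lens = [len(t) + 1 for t in draft_token_ids]     # [4, 1, 3]
bonus_logits_idx = np.cumsum(sample_lens) - 1           # last logit per request
assert bonus_logits_idx.tolist() == [3, 4, 7]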
draft_token_ids: list[list[int]] = [] @@ -1094,6 +1101,12 @@ def generate_draft_token_ids( draft_token_ids.append([]) continue + # Skip requests that require top-p, top-k, etc. + req_id = self.input_batch.req_ids[i] + if not is_spec_decode_supported(req_id, self.input_batch): + draft_token_ids.append([]) + continue + # Add sampled_token_ids to token_ids_cpu. start_idx = self.input_batch.num_tokens_no_spec[i] end_idx = start_idx + num_sampled_ids From cb8e69938ac39163fc3d1a41e5d156d9380b6a8b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 15:14:52 -0700 Subject: [PATCH 17/48] fix Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 49c718636456..8a01633af321 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -245,8 +245,9 @@ def generate_uniform_probs( if n == 0: continue end_idx = start_idx + n - generator = sampling_metadata.generators[req_idx] - uniform_probs[start_idx:end_idx].uniform_(generator=generator) + generator = sampling_metadata.generators.get(req_idx) + if generator is not None: + uniform_probs[start_idx:end_idx].uniform_(generator=generator) start_idx = end_idx return uniform_probs @@ -420,6 +421,7 @@ def compute_probs_kernel( logits = tl.load(logits_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset < vocab_size) + logits = logits.to(dtype=tl.float32) temperature = tl.load(temperature_ptr + req_idx) if temperature == GREEDY_TEMPERATURE: # Greedy sampling. Just return the logits. @@ -427,7 +429,6 @@ def compute_probs_kernel( else: # Random sampling. output_prob = tl.softmax(logits / temperature) - output_prob = output_prob.to(dtype=tl.float32) tl.store(output_prob_ptr + (start_idx + pos) * vocab_size + vocab_offset, output_prob, @@ -480,7 +481,7 @@ def sample_recovered_tokens_kernel( mask=vocab_offset < vocab_size) prob = target_prob - draft_prob prob = tl.maximum(prob, TINY) - prob = prob / prob.sum(axis=-1, keep_dims=True) + prob = prob / tl.sum(prob, axis=-1, keep_dims=True) q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size) From c0bcf5aed6eefe27e68fec703ac54c86f31ba038 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 16:41:00 -0700 Subject: [PATCH 18/48] fix Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 42 +++++++++++++++++++---------- vllm/v1/worker/gpu_model_runner.py | 5 ++-- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 8a01633af321..be93d7fa8385 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -56,6 +56,7 @@ def forward( bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: + print("draft_token_ids", draft_token_ids) token_ids, cu_num_draft_tokens = self._async_copy_to_device( draft_token_ids, target_logits.device, @@ -83,7 +84,10 @@ def forward( return output_token_ids @staticmethod - def parse_output(output_token_ids: torch.Tensor) -> list[list[int]]: + def parse_output( + output_token_ids: torch.Tensor, + vocab_size: int, + ) -> list[list[int]]: output_token_ids = output_token_ids.tolist() # Preallocate outputs. 
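The generators.get fix above matters because only requests with a per-request seed carry a torch.Generator; all other requests fall back to the default RNG. A standalone sketch of that behavior, with made-up sizes:

import torch

num_draft_tokens = [2, 3]                            # two requests
generators = {1: torch.Generator().manual_seed(42)}  # only request 1 is seeded
uniform_probs = torch.rand(sum(num_draft_tokens))    # unseeded fill
start = 0
for req_idx, n in enumerate(num_draft_tokens):
    end = start + n
    generator = generators.get(req_idx)              # None when unseeded
    if generator is not None:
        uniform_probs[start:end].uniform_(generator=generator)
    start = end
print(uniform_probs)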
outputs: list[list[int]] = [[] for _ in output_token_ids] @@ -91,6 +95,9 @@ def parse_output(output_token_ids: torch.Tensor) -> list[list[int]]: for token_id in token_ids: if token_id == PLACEHOLDER_TOKEN_ID: break + # Make sure the token id is in the vocabulary. + if token_id >= vocab_size: + break outputs[i].append(token_id) return outputs @@ -99,11 +106,11 @@ def _async_copy_to_device( draft_token_ids: list[list[int]], device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - flattened_token_ids: list[int] = [] - cu_num_tokens: list[int] = [] - for token_ids in draft_token_ids: - flattened_token_ids.extend(token_ids) - cu_num_tokens.append(len(token_ids)) + flattened_token_ids = sum(draft_token_ids, []) + cu_num_tokens = [0] * len(draft_token_ids) + for i, token_ids in enumerate(draft_token_ids): + prev = cu_num_tokens[i - 1] if i > 0 else 0 + cu_num_tokens[i] = prev + len(token_ids) num_draft_tokens = len(flattened_token_ids) assert num_draft_tokens <= self.max_num_draft_tokens @@ -191,6 +198,8 @@ def rejection_sample( device, ) + print("recovered_token_ids", recovered_token_ids) + print("bonus_token_ids", bonus_token_ids) # Rejection sampling for random sampling requests. rejection_random_sample_kernel[(batch_size, )]( output_token_ids, @@ -420,7 +429,8 @@ def compute_probs_kernel( vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) logits = tl.load(logits_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size) + mask=vocab_offset < vocab_size, + other=float("-inf")) logits = logits.to(dtype=tl.float32) temperature = tl.load(temperature_ptr + req_idx) if temperature == GREEDY_TEMPERATURE: @@ -471,20 +481,24 @@ def sample_recovered_tokens_kernel( 0) prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size) + mask=vocab_offset < vocab_size, + other=0) else: draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size) + mask=vocab_offset < vocab_size, + other=0) target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size) - prob = target_prob - draft_prob - prob = tl.maximum(prob, TINY) - prob = prob / tl.sum(prob, axis=-1, keep_dims=True) + mask=vocab_offset < vocab_size, + other=0) + prob = tl.maximum(target_prob - draft_prob, 0) + # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because + # `tl.argmax` will select the maximum value. q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size) + mask=vocab_offset < vocab_size, + other=float("-inf")) recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) if IS_NGRAM: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2e4ff4d2d7c8..77838401b670 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1021,7 +1021,7 @@ def execute_model( sample_lens = [len(tokens) + 1 for tokens in draft_token_ids] bonus_logits_idx = np.cumsum(sample_lens) - 1 sampler_output = self.model.sample( - logits=logits[bonus_logits_idx], + logits=logits[bonus_logits_idx], # FIXME: synchronization sampling_metadata=sampling_metadata, ) bonus_token_ids = sampler_output.sampled_token_ids @@ -1070,7 +1070,8 @@ def execute_model( else: # Includes spec decode tokens. 
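The NOTE above about skipping the renormalization rests on the argmax(prob / q) trick: assuming q holds Exponential(1) draws (not shown in this hunk), argmax(prob / q) picks index i with probability proportional to prob[i], so rescaling prob by a constant cannot change the result. A quick empirical check:

import torch

torch.manual_seed(0)
prob = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(20_000):
    q = torch.empty(3).exponential_(1.0)
    counts[torch.argmax(prob / q)] += 1
print(counts / counts.sum())   # roughly tensor([0.10, 0.20, 0.70])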
valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids) + sampled_token_ids, self.input_batch.vocab_size) + print("valid_sampled_token_ids", valid_sampled_token_ids) if not self.use_spec_decode: spec_token_ids = None From ae3d7fc2e989323a5b9240ba0100d021464c19ff Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 17:03:10 -0700 Subject: [PATCH 19/48] fix Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 77838401b670..28cebbf747b8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1021,7 +1021,7 @@ def execute_model( sample_lens = [len(tokens) + 1 for tokens in draft_token_ids] bonus_logits_idx = np.cumsum(sample_lens) - 1 sampler_output = self.model.sample( - logits=logits[bonus_logits_idx], # FIXME: synchronization + logits=logits[bonus_logits_idx], # TODO: Optimize. sampling_metadata=sampling_metadata, ) bonus_token_ids = sampler_output.sampled_token_ids @@ -1072,6 +1072,7 @@ def execute_model( valid_sampled_token_ids = self.rejection_sampler.parse_output( sampled_token_ids, self.input_batch.vocab_size) print("valid_sampled_token_ids", valid_sampled_token_ids) + print("-" * 100) if not self.use_spec_decode: spec_token_ids = None From 412e2f4f619e1515c2b7f13d6edf1cac88707271 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 17:03:38 -0700 Subject: [PATCH 20/48] fix Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index be93d7fa8385..6d67e41fc087 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -50,7 +50,7 @@ def forward( draft_token_ids: list[list[int]], # [num_tokens, vocab_size] draft_probs: Optional[torch.Tensor], - # [num_tokens, vocab_size] + # [num_tokens_with_bonus, vocab_size] target_logits: torch.Tensor, # [batch_size] bonus_token_ids: torch.Tensor, @@ -219,14 +219,19 @@ def rejection_sample( def compute_probs( - logits: torch.Tensor, # [num_tokens, vocab_size] + logits: torch.Tensor, # [num_tokens_with_bonus, vocab_size] temperature: torch.Tensor, # [batch_size] cu_num_draft_tokens: torch.Tensor, # [batch_size] max_spec_len: int, ) -> torch.Tensor: - output_prob = torch.empty_like(logits, dtype=torch.float32) batch_size = temperature.shape[0] vocab_size = logits.shape[-1] + num_tokens = logits.shape[0] - batch_size + output_prob = torch.empty( + (num_tokens, vocab_size), + dtype=torch.float32, + device=logits.device, + ) compute_probs_kernel[(batch_size, max_spec_len)]( output_prob, logits, @@ -408,7 +413,7 @@ def rejection_random_sample_kernel( @triton.jit def compute_probs_kernel( output_prob_ptr, # [num_tokens, vocab_size] - logits_ptr, # [num_tokens, vocab_size] + logits_ptr, # [num_tokens_with_bonus, vocab_size] temperature_ptr, # [batch_size] cu_num_draft_tokens_ptr, # [batch_size] vocab_size, @@ -427,7 +432,10 @@ def compute_probs_kernel( return vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) - logits = tl.load(logits_ptr + (start_idx + pos) * vocab_size + + # NOTE(woosuk): We need to add `req_idx` to `start_idx + pos` because + # `logits_ptr` has the shape of `[num_tokens_with_bonus, vocab_size]`, + # not `[num_tokens, vocab_size]`. 
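A worked example of the req_idx offset explained in the NOTE above: with num_draft_tokens = [3, 2], the target logits hold 4 + 3 = 7 rows (one bonus row per request), and the draft-token rows are reached by skipping past each earlier request's bonus row.

num_draft_tokens = [3, 2]
cu_num_draft_tokens = [3, 5]
rows = []
for req_idx, n in enumerate(num_draft_tokens):
    start_idx = 0 if req_idx == 0 else cu_num_draft_tokens[req_idx - 1]
    for pos in range(n):
        rows.append(start_idx + pos + req_idx)
print(rows)   # [0, 1, 2, 4, 5]; rows 3 and 6 are the bonus positions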
+ logits = tl.load(logits_ptr + (start_idx + pos + req_idx) * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf")) From df66124e3774541979aa9861a6e241a34401def3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 17:06:11 -0700 Subject: [PATCH 21/48] remove Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 3 --- vllm/v1/worker/gpu_model_runner.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 6d67e41fc087..d3e03580d82f 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -56,7 +56,6 @@ def forward( bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - print("draft_token_ids", draft_token_ids) token_ids, cu_num_draft_tokens = self._async_copy_to_device( draft_token_ids, target_logits.device, @@ -198,8 +197,6 @@ def rejection_sample( device, ) - print("recovered_token_ids", recovered_token_ids) - print("bonus_token_ids", bonus_token_ids) # Rejection sampling for random sampling requests. rejection_random_sample_kernel[(batch_size, )]( output_token_ids, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 28cebbf747b8..cacab00f05aa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1071,8 +1071,6 @@ def execute_model( # Includes spec decode tokens. valid_sampled_token_ids = self.rejection_sampler.parse_output( sampled_token_ids, self.input_batch.vocab_size) - print("valid_sampled_token_ids", valid_sampled_token_ids) - print("-" * 100) if not self.use_spec_decode: spec_token_ids = None From 704da77ed425f85069c62d1a555cdfa4e4aa1320 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 20:22:06 -0700 Subject: [PATCH 22/48] opt Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 109 ++++++++++++++++++---------- vllm/v1/worker/gpu_model_runner.py | 39 +++++++++- 2 files changed, 106 insertions(+), 42 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index d3e03580d82f..6302565d26b3 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -7,7 +7,6 @@ import triton.language as tl from vllm.logger import init_logger -from vllm.utils import is_pin_memory_available from vllm.v1.sample.metadata import SamplingMetadata logger = init_logger(__name__) @@ -21,28 +20,40 @@ class RejectionSampler(nn.Module): def __init__( self, + pin_memory: bool, + device: torch.device, max_batch_size: int = 8 * 1024, max_num_draft_tokens: int = 32 * 1024, ): super().__init__() self.max_batch_size = max_batch_size self.max_num_draft_tokens = max_num_draft_tokens + self.pin_memory = pin_memory + self.device = device self.cu_num_tokens_buffer = torch.empty( max_batch_size, dtype=torch.int32, device="cpu", - pin_memory=is_pin_memory_available(), + pin_memory=self.pin_memory, ) self.cu_num_tokens_buffer_np = self.cu_num_tokens_buffer.numpy() + self.cu_num_tokens_buffer_device = torch.empty_like( + self.cu_num_tokens_buffer, + device=self.device, + ) self.token_ids_buffer = torch.empty( max_num_draft_tokens, dtype=torch.int64, device="cpu", - pin_memory=is_pin_memory_available(), + pin_memory=self.pin_memory, ) self.token_ids_buffer_np = self.token_ids_buffer.numpy() + self.token_ids_buffer_device = torch.empty_like( + self.token_ids_buffer, + device=self.device, + ) def forward( self, @@ -57,9 +68,7 @@ def forward( 
sampling_metadata: SamplingMetadata, ) -> torch.Tensor: token_ids, cu_num_draft_tokens = self._async_copy_to_device( - draft_token_ids, - target_logits.device, - ) + draft_token_ids) num_draft_tokens = [len(ids) for ids in draft_token_ids] max_spec_len = max(num_draft_tokens) @@ -69,6 +78,7 @@ def forward( sampling_metadata.temperature, cu_num_draft_tokens, max_spec_len, + sampling_metadata.all_greedy, ) output_token_ids = rejection_sample( @@ -103,27 +113,32 @@ def parse_output( def _async_copy_to_device( self, draft_token_ids: list[list[int]], - device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - flattened_token_ids = sum(draft_token_ids, []) - cu_num_tokens = [0] * len(draft_token_ids) + start = 0 for i, token_ids in enumerate(draft_token_ids): - prev = cu_num_tokens[i - 1] if i > 0 else 0 - cu_num_tokens[i] = prev + len(token_ids) + end = start + len(token_ids) + self.token_ids_buffer_np[start:end] = token_ids + self.cu_num_tokens_buffer_np[i] = end + start = end + num_draft_tokens = end - num_draft_tokens = len(flattened_token_ids) assert num_draft_tokens <= self.max_num_draft_tokens - self.token_ids_buffer_np[:num_draft_tokens] = flattened_token_ids - draft_token_ids_cpu = self.token_ids_buffer[:num_draft_tokens] - draft_token_ids_gpu = draft_token_ids_cpu.to(device=device, - non_blocking=True) + draft_token_ids_device = ( + self.token_ids_buffer_device[:num_draft_tokens]) + draft_token_ids_device.copy_( + self.token_ids_buffer[:num_draft_tokens], + non_blocking=True, + ) - batch_size = len(cu_num_tokens) - self.cu_num_tokens_buffer_np[:batch_size] = cu_num_tokens - cu_num_draft_tokens_cpu = self.cu_num_tokens_buffer[:batch_size] - cu_num_draft_tokens_gpu = cu_num_draft_tokens_cpu.to(device=device, - non_blocking=True) - return draft_token_ids_gpu, cu_num_draft_tokens_gpu + batch_size = len(draft_token_ids) + assert batch_size <= self.max_batch_size + cu_num_draft_tokens_device = ( + self.cu_num_tokens_buffer_device[:batch_size]) + cu_num_draft_tokens_device.copy_( + self.cu_num_tokens_buffer[:batch_size], + non_blocking=True, + ) + return draft_token_ids_device, cu_num_draft_tokens_device def rejection_sample( @@ -161,6 +176,7 @@ def rejection_sample( is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE if not sampling_metadata.all_random: + print("GREEDY") # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) rejection_greedy_sample_kernel[(batch_size, )]( @@ -171,9 +187,11 @@ def rejection_sample( bonus_token_ids, is_greedy, max_spec_len, + num_warps=1, ) if sampling_metadata.all_greedy: return output_token_ids + print("RANDOM") # Generate uniform probabilities for rejection sampling. 
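The persistent pinned-CPU and device buffer pair introduced above is what makes the host-to-device copies genuinely asynchronous. A minimal sketch of the same pattern; buffer size and values are arbitrary:

import torch

if torch.cuda.is_available():
    cpu_buf = torch.empty(1024, dtype=torch.int64, pin_memory=True)
    gpu_buf = torch.empty_like(cpu_buf, device="cuda")
    cpu_buf[:4] = torch.tensor([7, 8, 9, 10])
    gpu_buf[:4].copy_(cpu_buf[:4], non_blocking=True)   # overlaps with GPU work
    torch.cuda.synchronize()                            # only needed before reading back
    print(gpu_buf[:4])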
# [num_tokens] @@ -211,6 +229,7 @@ def rejection_sample( max_spec_len, vocab_size, IS_NGRAM=draft_probs is None, + num_warps=1, ) return output_token_ids @@ -220,23 +239,30 @@ def compute_probs( temperature: torch.Tensor, # [batch_size] cu_num_draft_tokens: torch.Tensor, # [batch_size] max_spec_len: int, + all_greedy: bool, ) -> torch.Tensor: batch_size = temperature.shape[0] vocab_size = logits.shape[-1] num_tokens = logits.shape[0] - batch_size - output_prob = torch.empty( + scaled_logits = torch.empty( (num_tokens, vocab_size), dtype=torch.float32, device=logits.device, ) - compute_probs_kernel[(batch_size, max_spec_len)]( - output_prob, + block_size = 8192 + num_blocks = triton.cdiv(vocab_size, block_size) + compute_probs_kernel[(batch_size, max_spec_len, num_blocks)]( + scaled_logits, logits, temperature, cu_num_draft_tokens, vocab_size, - triton.next_power_of_2(vocab_size), + BLOCK_SIZE=block_size, ) + if all_greedy: + output_prob = scaled_logits + else: + output_prob = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32) return output_prob @@ -305,7 +331,9 @@ def sample_recovered_tokens( return recovered_token_ids -@triton.jit +# NOTE(woosuk): To avoid recompilation, we shouldn't specialize on +# `max_spec_len`. +@triton.jit(do_not_specialize=["max_spec_len"]) def rejection_greedy_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] cu_num_draft_tokens_ptr, # [batch_size] @@ -347,7 +375,9 @@ def rejection_greedy_sample_kernel( num_draft_tokens, bonus_token_id) -@triton.jit +# NOTE(woosuk): To avoid recompilation, we shouldn't specialize on +# `max_spec_len`. +@triton.jit(do_not_specialize=["max_spec_len"]) def rejection_random_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] cu_num_draft_tokens_ptr, # [batch_size] @@ -409,12 +439,12 @@ def rejection_random_sample_kernel( @triton.jit def compute_probs_kernel( - output_prob_ptr, # [num_tokens, vocab_size] + output_logits_ptr, # [num_tokens, vocab_size] logits_ptr, # [num_tokens_with_bonus, vocab_size] temperature_ptr, # [batch_size] cu_num_draft_tokens_ptr, # [batch_size] vocab_size, - PADDED_VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, ): req_idx = tl.program_id(0) if req_idx == 0: @@ -428,26 +458,25 @@ def compute_probs_kernel( if pos >= end_idx - start_idx: return - vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) + block_id = tl.program_id(2) + block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) # NOTE(woosuk): We need to add `req_idx` to `start_idx + pos` because # `logits_ptr` has the shape of `[num_tokens_with_bonus, vocab_size]`, # not `[num_tokens, vocab_size]`. logits = tl.load(logits_ptr + (start_idx + pos + req_idx) * vocab_size + - vocab_offset, - mask=vocab_offset < vocab_size, - other=float("-inf")) + block_offset, + mask=block_offset < vocab_size) logits = logits.to(dtype=tl.float32) temperature = tl.load(temperature_ptr + req_idx) if temperature == GREEDY_TEMPERATURE: # Greedy sampling. Just return the logits. - output_prob = logits + scaled_logits = logits else: # Random sampling. 
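For intuition, rejection_greedy_sample_kernel above reduces to the following per-request Python (a simplified reference, not the production path): accept draft tokens while they match the target argmax, emit the first mismatching target token, and append the bonus token only if everything matched.

def greedy_rejection(draft_ids, target_argmax, bonus_id):
    out = []
    for d, t in zip(draft_ids, target_argmax):
        out.append(t)            # the emitted token is always the target argmax
        if d != t:
            return out           # first mismatch stops acceptance
    return out + [bonus_id]      # all drafts accepted -> keep the bonus token


assert greedy_rejection([3, 7, 9], [3, 7, 2], 11) == [3, 7, 2]
assert greedy_rejection([3, 7, 9], [3, 7, 9], 11) == [3, 7, 9, 11]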
- output_prob = tl.softmax(logits / temperature) - - tl.store(output_prob_ptr + (start_idx + pos) * vocab_size + vocab_offset, - output_prob, - mask=vocab_offset < vocab_size) + scaled_logits = logits / temperature + tl.store(output_logits_ptr + (start_idx + pos) * vocab_size + block_offset, + scaled_logits, + mask=block_offset < vocab_size) @triton.jit diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cacab00f05aa..13fde0611fd9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -151,7 +151,6 @@ def __init__( self.use_spec_decode = False if self.speculative_config: self.use_spec_decode = True - self.rejection_sampler = RejectionSampler() # TODO: find a better way to check if we are using ngram. assert self.speculative_config.ngram_prompt_lookup_min, \ "Currently, only ngram spec decode is supported in V1." @@ -164,6 +163,10 @@ def __init__( self.speculative_config.ngram_prompt_lookup_min, self.speculative_config.num_speculative_tokens, ) + self.rejection_sampler = RejectionSampler( + pin_memory=self.pin_memory, + device=self.device, + ) # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -1026,6 +1029,8 @@ def execute_model( ) bonus_token_ids = sampler_output.sampled_token_ids + # torch.cuda.synchronize() + start = time.time() output_token_ids = self.rejection_sampler( draft_token_ids, None, # draft_probs @@ -1034,6 +1039,11 @@ def execute_model( sampling_metadata, ) sampler_output.sampled_token_ids = output_token_ids + end = time.time() + print(f"Rejection sampler CPU took {(end - start) * 1000:.4f} ms") + # torch.cuda.synchronize() + # end = time.time() + # print(f"Rejection sampler GPU took {(end - start) * 1000:.4f} ms") # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. @@ -1071,7 +1081,10 @@ def execute_model( # Includes spec decode tokens. 
valid_sampled_token_ids = self.rejection_sampler.parse_output( sampled_token_ids, self.input_batch.vocab_size) - + avg_len = sum( + len(ids) + for ids in valid_sampled_token_ids) / len(valid_sampled_token_ids) + print(f"Average length of valid sampled token ids: {avg_len:.2f}") if not self.use_spec_decode: spec_token_ids = None else: @@ -1314,6 +1327,28 @@ def _dummy_sampler_run( "initializing the engine.") from e else: raise e + if self.use_spec_decode: + draft_token_ids = [[0] for _ in range(num_reqs)] + num_tokens = sum(len(ids) for ids in draft_token_ids) + num_tokens_with_bonus = num_tokens + num_reqs + # draft_probs = torch.randn( + # num_tokens, logits.shape[-1], device=self.device, + # dtype=logits.dtype) + draft_probs = None + target_logits = torch.randn(num_tokens_with_bonus, + logits.shape[-1], + device=self.device, + dtype=logits.dtype) + bonus_token_ids = torch.zeros(num_reqs, + device=self.device, + dtype=torch.int64) + self.rejection_sampler( + draft_token_ids, + draft_probs, + target_logits, + bonus_token_ids, + dummy_metadata, + ) return sampler_output def profile_run(self) -> None: From 4f95ca9a9792f3547e1c4a83235253ae4e89d003 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 20:24:08 -0700 Subject: [PATCH 23/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 6302565d26b3..c5a1a7f659aa 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -331,8 +331,7 @@ def sample_recovered_tokens( return recovered_token_ids -# NOTE(woosuk): To avoid recompilation, we shouldn't specialize on -# `max_spec_len`. +# NOTE(woosuk): Don't specialize on `max_spec_len` to avoid recompilation. @triton.jit(do_not_specialize=["max_spec_len"]) def rejection_greedy_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] @@ -375,8 +374,7 @@ def rejection_greedy_sample_kernel( num_draft_tokens, bonus_token_id) -# NOTE(woosuk): To avoid recompilation, we shouldn't specialize on -# `max_spec_len`. +# NOTE(woosuk): Don't specialize on `max_spec_len` to avoid recompilation. @triton.jit(do_not_specialize=["max_spec_len"]) def rejection_random_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] From 803c9dea18176dc02e55b7e5ac9e17e294bf89ac Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 21:20:40 -0700 Subject: [PATCH 24/48] opt softmax & fix recompilation Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/utils.py | 13 +++++++++++++ vllm/v1/sample/rejection_sampler.py | 9 +++++---- vllm/v1/worker/gpu_model_runner.py | 29 +++++++++++++---------------- 3 files changed, 31 insertions(+), 20 deletions(-) create mode 100644 vllm/v1/sample/ops/utils.py diff --git a/vllm/v1/sample/ops/utils.py b/vllm/v1/sample/ops/utils.py new file mode 100644 index 000000000000..d921e8ac46b5 --- /dev/null +++ b/vllm/v1/sample/ops/utils.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + + +# NOTE(woosuk): torch.compile generates faster softmax kernels. 
+def compiled_softmax(logits: torch.Tensor) -> torch.Tensor: + torch._dynamo.mark_dynamic(logits, index=0) + return _softmax(logits) + + +@torch.compile +def _softmax(logits: torch.Tensor) -> torch.Tensor: + return torch.softmax(logits, dim=-1, dtype=torch.float32) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index c5a1a7f659aa..1117c044e11c 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -8,12 +8,12 @@ from vllm.logger import init_logger from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.sample.ops.utils import compiled_softmax logger = init_logger(__name__) PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 GREEDY_TEMPERATURE: tl.constexpr = -1 -TINY: tl.constexpr = 1.1754943508222875e-38 # torch.finfo(torch.float32).tiny class RejectionSampler(nn.Module): @@ -72,6 +72,7 @@ def forward( num_draft_tokens = [len(ids) for ids in draft_token_ids] max_spec_len = max(num_draft_tokens) + assert max_spec_len > 0 # [num_tokens, vocab_size] target_probs = compute_probs( target_logits, @@ -176,7 +177,6 @@ def rejection_sample( is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE if not sampling_metadata.all_random: - print("GREEDY") # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) rejection_greedy_sample_kernel[(batch_size, )]( @@ -191,7 +191,6 @@ def rejection_sample( ) if sampling_metadata.all_greedy: return output_token_ids - print("RANDOM") # Generate uniform probabilities for rejection sampling. # [num_tokens] @@ -244,6 +243,7 @@ def compute_probs( batch_size = temperature.shape[0] vocab_size = logits.shape[-1] num_tokens = logits.shape[0] - batch_size + scaled_logits = torch.empty( (num_tokens, vocab_size), dtype=torch.float32, @@ -259,10 +259,11 @@ def compute_probs( vocab_size, BLOCK_SIZE=block_size, ) + if all_greedy: output_prob = scaled_logits else: - output_prob = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32) + output_prob = compiled_softmax(scaled_logits) return output_prob diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 13fde0611fd9..f1820ba9392c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1029,21 +1029,16 @@ def execute_model( ) bonus_token_ids = sampler_output.sampled_token_ids - # torch.cuda.synchronize() - start = time.time() - output_token_ids = self.rejection_sampler( - draft_token_ids, - None, # draft_probs - logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids - end = time.time() - print(f"Rejection sampler CPU took {(end - start) * 1000:.4f} ms") - # torch.cuda.synchronize() - # end = time.time() - # print(f"Rejection sampler GPU took {(end - start) * 1000:.4f} ms") + has_draft_tokens = any(len(ids) > 0 for ids in draft_token_ids) + if has_draft_tokens: + output_token_ids = self.rejection_sampler( + draft_token_ids, + None, # draft_probs + logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. @@ -1339,9 +1334,11 @@ def _dummy_sampler_run( logits.shape[-1], device=self.device, dtype=logits.dtype) + # NOTE(woosuk): Here, we should use int32 because the sampler + # uses int32 for bonus_token_ids. 
bonus_token_ids = torch.zeros(num_reqs, device=self.device, - dtype=torch.int64) + dtype=torch.int32) self.rejection_sampler( draft_token_ids, draft_probs, From 9cc93499d7dd5d87c3119e64b1e66b15779b0584 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 21:22:46 -0700 Subject: [PATCH 25/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/ops/utils.py b/vllm/v1/sample/ops/utils.py index d921e8ac46b5..7c83040246ae 100644 --- a/vllm/v1/sample/ops/utils.py +++ b/vllm/v1/sample/ops/utils.py @@ -2,8 +2,13 @@ import torch -# NOTE(woosuk): torch.compile generates faster softmax kernels. def compiled_softmax(logits: torch.Tensor) -> torch.Tensor: + """Faster softmax kernel generated by torch.compile. + + Args: + logits: [n, vocab_size] + """ + # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic. torch._dynamo.mark_dynamic(logits, index=0) return _softmax(logits) From 2b69e513ab215951facfea9561829043021809f0 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 22:00:48 -0700 Subject: [PATCH 26/48] remove envs Signed-off-by: Woosuk Kwon --- vllm/envs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index a36d20a4f8b5..e8978f082c30 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -35,7 +35,6 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None - VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 From 75e93aa713e6de00a4da6831f85ebca962130f3f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 22:32:00 -0700 Subject: [PATCH 27/48] fix Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f1820ba9392c..146ec8fff31d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1076,10 +1076,6 @@ def execute_model( # Includes spec decode tokens. 
valid_sampled_token_ids = self.rejection_sampler.parse_output( sampled_token_ids, self.input_batch.vocab_size) - avg_len = sum( - len(ids) - for ids in valid_sampled_token_ids) / len(valid_sampled_token_ids) - print(f"Average length of valid sampled token ids: {avg_len:.2f}") if not self.use_spec_decode: spec_token_ids = None else: From 5a86ff35094adad2f6490eb26e3db6cd6817a5bf Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 15 Mar 2025 22:32:19 -0700 Subject: [PATCH 28/48] fix Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 584140136778..d5329ef7b5ab 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa from vllm.v1.worker.gpu_input_batch import InputBatch From e2232d3814612a3b1ff54d28f1df2f23599283b9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 00:24:36 -0700 Subject: [PATCH 29/48] handle 0 Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 1117c044e11c..c0ae7f0b5ded 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -418,7 +418,9 @@ def rejection_random_sample_kernel( (start_idx + pos) * vocab_size + draft_token_id) uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) - if draft_prob == 0 or (target_prob / draft_prob >= uniform_prob): + # NOTE(woosuk): While the draft probability should never be 0, + # we check it to avoid NaNs. If it happens to be 0, we reject. + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: # Accept. token_id = draft_token_id else: From 67183dbca1e66e1d513ab744b227cd76b385bca1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 22:14:29 -0700 Subject: [PATCH 30/48] SpecDecodeMetadata Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 193 +++++++--------------------- vllm/v1/spec_decode/metadata.py | 61 +++++++++ vllm/v1/worker/gpu_model_runner.py | 185 +++++++++++++++----------- 3 files changed, 224 insertions(+), 215 deletions(-) create mode 100644 vllm/v1/spec_decode/metadata.py diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index c0ae7f0b5ded..64e600c13127 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -9,83 +9,43 @@ from vllm.logger import init_logger from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.utils import compiled_softmax +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata logger = init_logger(__name__) PLACEHOLDER_TOKEN_ID: tl.constexpr = -1 GREEDY_TEMPERATURE: tl.constexpr = -1 +# Maximum number of speculative draft tokens allowed per request in a single +# step. This value is chosen to be large enough to handle typical use cases. 
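The draft_prob > 0 guard added above turns a zero draft probability from an automatic accept into a reject. Per request, the rule reads like this simplified reference; the uniform draws and distributions below are made up:

import torch

def accept_draft_prefix(draft_ids, draft_p, target_p, uniform):
    accepted = []
    for pos, tok in enumerate(draft_ids):
        d = draft_p[pos, tok]
        t = target_p[pos, tok]
        if d > 0 and t / d >= uniform[pos]:
            accepted.append(int(tok))
        else:
            break                          # first rejection ends the prefix
    return accepted


vocab_size = 16
draft_p = torch.full((3, vocab_size), 1.0 / vocab_size)
target_p = torch.full((3, vocab_size), 1.0 / vocab_size)
target_p[1, 9] = 0.0                       # target puts no mass on token 9 at pos 1
uniform = torch.full((3,), 0.5)
print(accept_draft_prefix(torch.tensor([5, 9, 2]), draft_p, target_p, uniform))  # [5]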
+MAX_SPEC_LEN = 32 class RejectionSampler(nn.Module): - def __init__( - self, - pin_memory: bool, - device: torch.device, - max_batch_size: int = 8 * 1024, - max_num_draft_tokens: int = 32 * 1024, - ): - super().__init__() - self.max_batch_size = max_batch_size - self.max_num_draft_tokens = max_num_draft_tokens - self.pin_memory = pin_memory - self.device = device - - self.cu_num_tokens_buffer = torch.empty( - max_batch_size, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) - self.cu_num_tokens_buffer_np = self.cu_num_tokens_buffer.numpy() - self.cu_num_tokens_buffer_device = torch.empty_like( - self.cu_num_tokens_buffer, - device=self.device, - ) - - self.token_ids_buffer = torch.empty( - max_num_draft_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=self.pin_memory, - ) - self.token_ids_buffer_np = self.token_ids_buffer.numpy() - self.token_ids_buffer_device = torch.empty_like( - self.token_ids_buffer, - device=self.device, - ) - def forward( self, - # batch_size x [0, max_spec_len) - draft_token_ids: list[list[int]], + metadata: SpecDecodeMetadata, # [num_tokens, vocab_size] draft_probs: Optional[torch.Tensor], - # [num_tokens_with_bonus, vocab_size] + # [num_tokens, vocab_size] target_logits: torch.Tensor, # [batch_size] bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - token_ids, cu_num_draft_tokens = self._async_copy_to_device( - draft_token_ids) - - num_draft_tokens = [len(ids) for ids in draft_token_ids] - max_spec_len = max(num_draft_tokens) - assert max_spec_len > 0 + assert 0 < metadata.max_spec_len <= MAX_SPEC_LEN # [num_tokens, vocab_size] target_probs = compute_probs( target_logits, - sampling_metadata.temperature, - cu_num_draft_tokens, - max_spec_len, - sampling_metadata.all_greedy, + metadata.cu_num_draft_tokens, + sampling_metadata, ) output_token_ids = rejection_sample( - token_ids, - num_draft_tokens, - cu_num_draft_tokens, + metadata.draft_token_ids, + metadata.num_draft_tokens, + metadata.max_spec_len, + metadata.cu_num_draft_tokens, draft_probs, target_probs, bonus_token_ids, @@ -111,42 +71,13 @@ def parse_output( outputs[i].append(token_id) return outputs - def _async_copy_to_device( - self, - draft_token_ids: list[list[int]], - ) -> tuple[torch.Tensor, torch.Tensor]: - start = 0 - for i, token_ids in enumerate(draft_token_ids): - end = start + len(token_ids) - self.token_ids_buffer_np[start:end] = token_ids - self.cu_num_tokens_buffer_np[i] = end - start = end - num_draft_tokens = end - - assert num_draft_tokens <= self.max_num_draft_tokens - draft_token_ids_device = ( - self.token_ids_buffer_device[:num_draft_tokens]) - draft_token_ids_device.copy_( - self.token_ids_buffer[:num_draft_tokens], - non_blocking=True, - ) - - batch_size = len(draft_token_ids) - assert batch_size <= self.max_batch_size - cu_num_draft_tokens_device = ( - self.cu_num_tokens_buffer_device[:batch_size]) - cu_num_draft_tokens_device.copy_( - self.cu_num_tokens_buffer[:batch_size], - non_blocking=True, - ) - return draft_token_ids_device, cu_num_draft_tokens_device - def rejection_sample( # [num_tokens] draft_token_ids: torch.Tensor, # [batch_size] num_draft_tokens: list[int], + max_spec_len: int, # [batch_size] cu_num_draft_tokens: torch.Tensor, # [num_tokens, vocab_size] @@ -158,8 +89,7 @@ def rejection_sample( sampling_metadata: SamplingMetadata, ) -> torch.Tensor: batch_size = len(num_draft_tokens) - max_spec_len = max(num_draft_tokens) - num_tokens = sum(num_draft_tokens) + num_tokens = draft_token_ids.shape[0] 
vocab_size = target_probs.shape[-1] device = target_probs.device assert draft_token_ids.is_contiguous() @@ -234,36 +164,30 @@ def rejection_sample( def compute_probs( - logits: torch.Tensor, # [num_tokens_with_bonus, vocab_size] - temperature: torch.Tensor, # [batch_size] + logits: torch.Tensor, # [num_tokens, vocab_size] cu_num_draft_tokens: torch.Tensor, # [batch_size] - max_spec_len: int, - all_greedy: bool, + sampling_metadata: SamplingMetadata, ) -> torch.Tensor: - batch_size = temperature.shape[0] - vocab_size = logits.shape[-1] - num_tokens = logits.shape[0] - batch_size + if sampling_metadata.all_greedy: + return logits - scaled_logits = torch.empty( - (num_tokens, vocab_size), + num_tokens = logits.shape[0] + expanded_temperature = torch.empty( + (num_tokens, ), dtype=torch.float32, device=logits.device, ) - block_size = 8192 - num_blocks = triton.cdiv(vocab_size, block_size) - compute_probs_kernel[(batch_size, max_spec_len, num_blocks)]( - scaled_logits, - logits, - temperature, + expand_kernel[(num_tokens, )]( + expanded_temperature, + sampling_metadata.temperature, cu_num_draft_tokens, - vocab_size, - BLOCK_SIZE=block_size, + GREEDY_TEMPERATURE, # replace_from + 1, # replace_to + MAX_NUM_TOKENS=MAX_SPEC_LEN, + num_warps=1, ) - - if all_greedy: - output_prob = scaled_logits - else: - output_prob = compiled_softmax(scaled_logits) + scaled_logits = logits / expanded_temperature.unsqueeze(-1) + output_prob = compiled_softmax(scaled_logits) return output_prob @@ -438,46 +362,29 @@ def rejection_random_sample_kernel( num_draft_tokens, bonus_token_id) -@triton.jit -def compute_probs_kernel( - output_logits_ptr, # [num_tokens, vocab_size] - logits_ptr, # [num_tokens_with_bonus, vocab_size] - temperature_ptr, # [batch_size] - cu_num_draft_tokens_ptr, # [batch_size] - vocab_size, - BLOCK_SIZE: tl.constexpr, +@triton.jit(do_not_specialize=["replace_from", "replace_to"]) +def expand_kernel( + output_ptr, # [num_tokens] + input_ptr, # [batch_size] + cu_num_tokens_ptr, # [batch_size] + replace_from, + replace_to, + MAX_NUM_TOKENS: tl.constexpr, ): req_idx = tl.program_id(0) - if req_idx == 0: + if req_idx == 0: # noqa: SIM108 start_idx = 0 else: - start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1) - end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) - - # Early exit for out-of-range positions. - pos = tl.program_id(1) - if pos >= end_idx - start_idx: - return - - block_id = tl.program_id(2) - block_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - # NOTE(woosuk): We need to add `req_idx` to `start_idx + pos` because - # `logits_ptr` has the shape of `[num_tokens_with_bonus, vocab_size]`, - # not `[num_tokens, vocab_size]`. - logits = tl.load(logits_ptr + (start_idx + pos + req_idx) * vocab_size + - block_offset, - mask=block_offset < vocab_size) - logits = logits.to(dtype=tl.float32) - temperature = tl.load(temperature_ptr + req_idx) - if temperature == GREEDY_TEMPERATURE: - # Greedy sampling. Just return the logits. - scaled_logits = logits - else: - # Random sampling. 
- scaled_logits = logits / temperature - tl.store(output_logits_ptr + (start_idx + pos) * vocab_size + block_offset, - scaled_logits, - mask=block_offset < vocab_size) + start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_tokens_ptr + req_idx) + num_tokens = end_idx - start_idx + + src_val = tl.load(input_ptr + req_idx) + src_val = tl.where(src_val == replace_from, replace_to, src_val) + offset = tl.arange(0, MAX_NUM_TOKENS) + tl.store(output_ptr + start_idx + offset, + src_val, + mask=offset < num_tokens) @triton.jit diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py new file mode 100644 index 000000000000..c95d6fca0d8a --- /dev/null +++ b/vllm/v1/spec_decode/metadata.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import numpy as np +import torch + + +@dataclass +class SpecDecodeMetadata: + + # [num_tokens] + draft_token_ids: torch.Tensor + # [batch_size] + num_draft_tokens: list[int] + # [batch_size] + cu_num_draft_tokens: torch.Tensor + # [num_tokens] + target_logits_indices: torch.Tensor + # [batch_size] + bonus_logits_indices: torch.Tensor + # [num_tokens + batch_size] + logits_indices: torch.Tensor + + def __post_init__(self): + self.max_spec_len = max(self.num_draft_tokens) + + @classmethod + def make_dummy_for_profiling( + cls, + draft_token_ids: list[list[int]], + device: torch.device, + ) -> "SpecDecodeMetadata": + batch_size = len(draft_token_ids) + num_draft_tokens = [len(ids) for ids in draft_token_ids] + flattened_draft_token_ids = sum(draft_token_ids, []) + num_tokens = len(flattened_draft_token_ids) + + draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids, + dtype=torch.int32, + device=device) + cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) + cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to( + device) + + target_logits_indices = torch.zeros(num_tokens, + dtype=torch.int32, + device=device) + bonus_logits_indices = torch.zeros(batch_size, + dtype=torch.int32, + device=device) + logits_indices = torch.zeros(num_tokens + batch_size, + dtype=torch.int32, + device=device) + return cls( + draft_token_ids=draft_token_ids_tensor, + num_draft_tokens=num_draft_tokens, + cu_num_draft_tokens=cu_num_draft_tokens_tensor, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9929ce1178d5..c06c8183a763 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -36,6 +36,7 @@ ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache @@ -170,10 +171,7 @@ def __init__( self.speculative_config.ngram_prompt_lookup_min, self.speculative_config.num_speculative_tokens, ) - self.rejection_sampler = RejectionSampler( - pin_memory=self.pin_memory, - device=self.device, - ) + self.rejection_sampler = RejectionSampler() # Request states. 
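Functionally, the expand_kernel introduced above behaves like repeat_interleave over the per-request temperatures, with the GREEDY_TEMPERATURE sentinel (-1) mapped to 1.0 so the later division is a no-op. An equivalent, slower PyTorch sketch with made-up values:

import torch

temperature = torch.tensor([0.7, -1.0, 1.3])          # -1 marks a greedy request
num_draft_tokens = torch.tensor([2, 1, 3])
expanded = torch.repeat_interleave(temperature, num_draft_tokens)
expanded = torch.where(expanded == -1.0, torch.ones_like(expanded), expanded)
print(expanded)   # tensor([0.7000, 0.7000, 1.0000, 1.3000, 1.3000, 1.3000])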
self.requests: dict[str, CachedRequestState] = {} @@ -464,7 +462,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[FlashAttentionMetadata, torch.Tensor]: + ) -> tuple[FlashAttentionMetadata, torch.Tensor, + Optional[SpecDecodeMetadata]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 num_reqs = self.input_batch.num_reqs @@ -589,22 +588,33 @@ def _prepare_inputs( use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 - if use_spec_decode: - logits_indices = self._calc_spec_decode_metadata( - scheduler_output, cu_num_tokens) - else: + if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. logits_indices = attn_metadata.query_start_loc[1:] - 1 + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return attn_metadata, logits_indices + return attn_metadata, logits_indices, spec_decode_metadata def _compute_cascade_attn_prefix_len( self, @@ -744,50 +754,79 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): def _calc_spec_decode_metadata( self, - scheduler_output: "SchedulerOutput", - cu_num_tokens: np.ndarray, - ) -> torch.Tensor: - # Get the number of spec decode tokens for each request. - num_reqs = self.input_batch.num_reqs - num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32) - for i, req_id in enumerate(self.input_batch.req_ids): - num_spec_decode_tokens[i] = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) - - # Get spec decode logits indices. - # E.g., num_scheduled_tokens: [4, 100, 3, 100, 2] - # cu_num_tokens: [4, 104, 107, 207, 209] - # num_spec_tokens_list: [3, 0, 2, 0, 1] - # num_sampled_tokens: [4, 1, 3, 1, 2] - # spec_decode_logits_indices: - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - num_sampled_tokens = num_spec_decode_tokens + 1 - # logits_start_loc: [0, 103, 104, 206, 207] - logits_start_loc = cu_num_tokens - num_sampled_tokens - # [0, 103, 104, 206, 207] -> - # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] - logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens) - # The following three lines: - # [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11] - cu_num_sampled_tokens = np.cumsum(num_sampled_tokens) - # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9] - # -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - cumsums_sampled_offsets = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens) - # Step 3. 
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - # - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] - # -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - total_num_sampled_tokens = num_sampled_tokens.sum() - sampled_arange = (self.arange_np[:total_num_sampled_tokens] - - cumsums_sampled_offsets) - - # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] -> - # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - spec_decode_logits_indices = logits_start_loc + sampled_arange - return torch.from_numpy(spec_decode_logits_indices).to( + num_draft_tokens: np.ndarray, + cu_num_scheduled_tokens: np.ndarray, + ) -> SpecDecodeMetadata: + # Inputs: + # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] + # num_draft_tokens: [ 3, 0, 2, 0, 1] + # Outputs: + # cu_num_draft_tokens: [ 3, 3, 5, 5, 6] + # logits_indices: [ 0, 1, 2, 3, 103, 104, 105, 106, + # 206, 207, 208] + # target_logits_indices: [ 0, 1, 2, 5, 6, 9] + # bonus_logits_indices: [ 3, 4, 7, 8, 10] + + # Compute the logits indices. + # [4, 1, 3, 1, 2] + num_sampled_tokens = num_draft_tokens + 1 + # Step 1. [4, 5, 8, 9, 11] + cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32) + total_num_sampled_tokens = cu_num_sampled_tokens[-1] + # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9] + cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens, + num_sampled_tokens) + # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets + # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + logits_indices = np.repeat( + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] + logits_indices += arange + + # Compute the bonus logits indices. + bonus_logits_indices = cu_num_sampled_tokens - 1 + + # Compute the draft logits indices. + # [3, 3, 5, 5, 6] + cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) + total_num_draft_tokens = cu_num_draft_tokens[-1] + # [0, 0, 0, 3, 3, 5] + cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens, + num_draft_tokens) + # [0, 1, 2, 0, 1, 0] + arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets + # [0, 0, 0, 5, 5, 9] + target_logits_indices = np.repeat( + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + # [0, 1, 2, 5, 6, 9] + target_logits_indices += arange + + # TODO: Optimize the CPU -> GPU copy. + cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( + self.device, non_blocking=True) + logits_indices = torch.from_numpy(logits_indices).to(self.device, + non_blocking=True) + target_logits_indices = torch.from_numpy(target_logits_indices).to( + self.device, non_blocking=True) + bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( self.device, non_blocking=True) + # Compute the draft token ids. + # draft_token_indices: [ 1, 2, 3, 105, 106, 208] + draft_token_ids = self.input_ids[logits_indices] + draft_token_ids = draft_token_ids[target_logits_indices + 1] + + metadata = SpecDecodeMetadata( + draft_token_ids=draft_token_ids, + num_draft_tokens=num_draft_tokens.tolist(), + cu_num_draft_tokens=cu_num_draft_tokens, + target_logits_indices=target_logits_indices, + bonus_logits_indices=bonus_logits_indices, + logits_indices=logits_indices, + ) + return metadata + def _execute_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: @@ -943,7 +982,8 @@ def execute_model( encoder_outputs = [] # Prepare the decoder inputs. 
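The index arithmetic in _calc_spec_decode_metadata above can be reproduced directly in NumPy; the arrays below are the same example values used in the code comments.

import numpy as np

cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209])
num_draft_tokens = np.array([3, 0, 2, 0, 1])

num_sampled_tokens = num_draft_tokens + 1
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
arange = np.arange(cu_num_sampled_tokens[-1]) - np.repeat(
    cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                           num_sampled_tokens) + arange
bonus_logits_indices = cu_num_sampled_tokens - 1

cu_num_draft_tokens = np.cumsum(num_draft_tokens)
arange_d = np.arange(cu_num_draft_tokens[-1]) - np.repeat(
    cu_num_draft_tokens - num_draft_tokens, num_draft_tokens)
target_logits_indices = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                                  num_draft_tokens) + arange_d

assert logits_indices.tolist() == [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
assert bonus_logits_indices.tolist() == [3, 4, 7, 8, 10]
assert target_logits_indices.tolist() == [0, 1, 2, 5, 6, 9]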
- attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + attn_metadata, logits_indices, spec_decode_metadata = ( + self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1018,34 +1058,30 @@ def execute_model( # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata - if not self.use_spec_decode: + if spec_decode_metadata is None: sampler_output = self.model.sample( logits=logits, sampling_metadata=sampling_metadata, ) else: - draft_token_ids = [ - scheduler_output.scheduled_spec_decode_tokens.get(req_id, []) - for req_id in self.input_batch.req_ids - ] - sample_lens = [len(tokens) + 1 for tokens in draft_token_ids] - bonus_logits_idx = np.cumsum(sample_lens) - 1 + # TODO(woosuk): Optimize the memory usage. + bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] sampler_output = self.model.sample( - logits=logits[bonus_logits_idx], # TODO: Optimize. + logits=bonus_logits, sampling_metadata=sampling_metadata, ) bonus_token_ids = sampler_output.sampled_token_ids - has_draft_tokens = any(len(ids) > 0 for ids in draft_token_ids) - if has_draft_tokens: - output_token_ids = self.rejection_sampler( - draft_token_ids, - None, # draft_probs - logits, - bonus_token_ids, - sampling_metadata, - ) - sampler_output.sampled_token_ids = output_token_ids + # TODO(woosuk): Optimize the memory usage. + target_logits = logits[spec_decode_metadata.target_logits_indices] + output_token_ids = self.rejection_sampler( + spec_decode_metadata, + None, # draft_probs + target_logits, + bonus_token_ids, + sampling_metadata, + ) + sampler_output.sampled_token_ids = output_token_ids # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. @@ -1083,6 +1119,7 @@ def execute_model( # Includes spec decode tokens. 
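With those indices, the bonus and target logits used above are plain row-gathers from the single flattened logits tensor; a toy tensor stands in for the real logits here:

import torch

logits = torch.arange(11 * 4, dtype=torch.float32).view(11, 4)   # 11 sampled positions
bonus_logits_indices = torch.tensor([3, 4, 7, 8, 10])
target_logits_indices = torch.tensor([0, 1, 2, 5, 6, 9])
print(logits[bonus_logits_indices].shape)    # torch.Size([5, 4])
print(logits[target_logits_indices].shape)   # torch.Size([6, 4])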
valid_sampled_token_ids = self.rejection_sampler.parse_output( sampled_token_ids, self.input_batch.vocab_size) + if not self.use_spec_decode: spec_token_ids = None else: @@ -1327,6 +1364,10 @@ def _dummy_sampler_run( raise e if self.use_spec_decode: draft_token_ids = [[0] for _ in range(num_reqs)] + dummy_spec_decode_metadata = ( + SpecDecodeMetadata.make_dummy_for_profiling( + draft_token_ids, self.device)) + num_tokens = sum(len(ids) for ids in draft_token_ids) num_tokens_with_bonus = num_tokens + num_reqs # draft_probs = torch.randn( @@ -1343,7 +1384,7 @@ def _dummy_sampler_run( device=self.device, dtype=torch.int32) self.rejection_sampler( - draft_token_ids, + dummy_spec_decode_metadata, draft_probs, target_logits, bonus_token_ids, From 2adc2afb64ffb11b4e7bd6eb9ecbc87f165943e6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 22:36:09 -0700 Subject: [PATCH 31/48] scaled softmax Signed-off-by: Woosuk Kwon --- vllm/v1/sample/ops/utils.py | 18 +++++++++++++++--- vllm/v1/sample/rejection_sampler.py | 10 +++++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/vllm/v1/sample/ops/utils.py b/vllm/v1/sample/ops/utils.py index 7c83040246ae..a54e20603064 100644 --- a/vllm/v1/sample/ops/utils.py +++ b/vllm/v1/sample/ops/utils.py @@ -1,18 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Union + import torch -def compiled_softmax(logits: torch.Tensor) -> torch.Tensor: +def compiled_softmax( + logits: torch.Tensor, + temperature: Union[float, torch.Tensor] = 1.0, +) -> torch.Tensor: """Faster softmax kernel generated by torch.compile. Args: logits: [n, vocab_size] + temperature: [n] or float """ # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic. torch._dynamo.mark_dynamic(logits, index=0) - return _softmax(logits) + if isinstance(temperature, torch.Tensor): + torch._dynamo.mark_dynamic(temperature, index=0) + return _softmax(logits, temperature) @torch.compile -def _softmax(logits: torch.Tensor) -> torch.Tensor: +def _softmax( + logits: torch.Tensor, + temperature: Union[float, torch.Tensor], +) -> torch.Tensor: + logits = logits / temperature return torch.softmax(logits, dim=-1, dtype=torch.float32) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 64e600c13127..217e4d174f8a 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -66,7 +66,7 @@ def parse_output( if token_id == PLACEHOLDER_TOKEN_ID: break # Make sure the token id is in the vocabulary. 
- if token_id >= vocab_size: + if not (0 <= token_id < vocab_size): break outputs[i].append(token_id) return outputs @@ -172,12 +172,13 @@ def compute_probs( return logits num_tokens = logits.shape[0] + batch_size = cu_num_draft_tokens.shape[0] expanded_temperature = torch.empty( - (num_tokens, ), + (num_tokens, 1), dtype=torch.float32, device=logits.device, ) - expand_kernel[(num_tokens, )]( + expand_kernel[(batch_size, )]( expanded_temperature, sampling_metadata.temperature, cu_num_draft_tokens, @@ -186,8 +187,7 @@ def compute_probs( MAX_NUM_TOKENS=MAX_SPEC_LEN, num_warps=1, ) - scaled_logits = logits / expanded_temperature.unsqueeze(-1) - output_prob = compiled_softmax(scaled_logits) + output_prob = compiled_softmax(logits, expanded_temperature) return output_prob From d74961664732c95fd04c8f01633fcd04e686bd38 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 22:47:48 -0700 Subject: [PATCH 32/48] minor opt Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 217e4d174f8a..f168033c5275 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -58,17 +58,14 @@ def parse_output( output_token_ids: torch.Tensor, vocab_size: int, ) -> list[list[int]]: - output_token_ids = output_token_ids.tolist() - # Preallocate outputs. - outputs: list[list[int]] = [[] for _ in output_token_ids] - for i, token_ids in enumerate(output_token_ids): - for token_id in token_ids: - if token_id == PLACEHOLDER_TOKEN_ID: - break - # Make sure the token id is in the vocabulary. - if not (0 <= token_id < vocab_size): - break - outputs[i].append(token_id) + output_token_ids_np = output_token_ids.cpu().numpy() + # Create mask for valid tokens. + valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) & + (output_token_ids_np < vocab_size)) + outputs = [ + row[valid_mask[i]].tolist() + for i, row in enumerate(output_token_ids_np) + ] return outputs From c400c99cf57897bd80ebdceadd008b7bd7c84f9c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 22:57:13 -0700 Subject: [PATCH 33/48] docstrings Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 5679aa3422a1..842a04704644 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -50,10 +50,37 @@ def forward( draft_probs: Optional[torch.Tensor], # [num_tokens, vocab_size] target_logits: torch.Tensor, - # [batch_size] + # [batch_size, 1] bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: + ''' + Args: + metadata: + Metadata for spec decoding. + draft_probs (Optional[torch.Tensor]): + Probability distribution for the draft tokens. Shape is + [num_tokens, vocab_size]. Can be None if probabilities are + not provided, which is the case for ngram spec decode. + target_logits (torch.Tensor): + Target model's logits probability distribution. + Shape is [num_tokens, vocab_size]. Here, probabilities from + different requests are flattened into a single tensor because + this is the shape of the output logits. + bonus_token_ids_tensor (torch.Tensor): + A tensor containing bonus tokens. Shape is [batch_size, 1]. 
+ Bonus tokens are added to the end of the sequence if all + proposed tokens are accepted. We generate the bonus tokens + outside of the rejection sampler with the default sampling + strategy. It allows for more flexibility in the sampling + process such as top_p, top_k sampling. + sampling_metadata (SamplingMetadata): + Additional metadata needed for sampling, such as temperature, + top-k/top-p parameters, or other relevant information. + Returns: + output_token_ids (torch.Tensor): + A tensor containing the final output token IDs. + ''' assert 0 < metadata.max_spec_len <= MAX_SPEC_LEN # [num_tokens, vocab_size] target_probs = compute_probs( From e197c3b6215a138cab751c40c51a228c9f59eb9a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 23:07:05 -0700 Subject: [PATCH 34/48] docstring Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 47 +++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 842a04704644..70dd9435fcb2 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -172,7 +172,7 @@ def rejection_sample( uniform_probs = generate_uniform_probs( num_tokens, num_draft_tokens, - sampling_metadata, + sampling_metadata.generators, device, ) @@ -213,6 +213,22 @@ def compute_probs( cu_num_draft_tokens: torch.Tensor, # [batch_size] sampling_metadata: SamplingMetadata, ) -> torch.Tensor: + """Compute probability distribution from logits based on sampling metadata. + + This function applies temperature scaling to the logits and converts + them to probabilities using softmax. For greedy decoding + + Args: + logits: Input logits tensor to be converted to probabilities + cu_num_draft_tokens: Cumulative number of of draft tokens. + sampling_metadata: Metadata containing sampling parameters such + as temperature and whether greedy sampling is used + + Returns: + torch.Tensor: Probability distribution (softmax of scaled logits) + if non-greedy sampling is used, otherwise returns the + original logits. + """ if sampling_metadata.all_greedy: return logits @@ -239,9 +255,34 @@ def compute_probs( def generate_uniform_probs( num_tokens: int, num_draft_tokens: list[int], - sampling_metadata: SamplingMetadata, + generators: dict[int, torch.Generator], device: torch.device, ) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + if available. + + This method creates a tensor of shape `(num_tokens, )` filled + with uniform random values in the range [0, 1). If `generators` is provided, + the requests with their own seeds will use the provided `torch.Generator` + for reproducibility. The samples for the other requests will be generated + without a seed. + + Args: + num_tokens : int + Total number of tokens. + num_draft_tokens : List[List[int]] + Number of draft tokens per request. + generators : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. + device : torch.device + The device on which to allocate the tensor. + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(num_tokens, )` containing uniform + random values in the range [0, 1). 
+ """ uniform_probs = torch.rand( (num_tokens, ), dtype=torch.float32, @@ -252,7 +293,7 @@ def generate_uniform_probs( if n == 0: continue end_idx = start_idx + n - generator = sampling_metadata.generators.get(req_idx) + generator = generators.get(req_idx) if generator is not None: uniform_probs[start_idx:end_idx].uniform_(generator=generator) start_idx = end_idx From 830ccd4bf3339876d3ffc904cca979c562d0689e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 23:09:02 -0700 Subject: [PATCH 35/48] trailing white space Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 70dd9435fcb2..d79d21970d6b 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -22,13 +22,13 @@ class RejectionSampler(nn.Module): """ - The implementation strictly follows the algorithm described in + The implementation strictly follows the algorithm described in https://arxiv.org/abs/2211.17192. However, we want to clarify the terminology used in the implementation: - accepted tokens: tokens that are accepted based on the relationship + accepted tokens: tokens that are accepted based on the relationship between the "raw" draft and target probabilities. recovered tokens: tokens that are sampled based on the adjusted probability - distribution, which is derived from both the draft and target + distribution, which is derived from both the draft and target probabilities. bonus tokens: If all proposed tokens are accepted, the bonus token is added to the @@ -38,8 +38,8 @@ class RejectionSampler(nn.Module): sampling process. For example, we can use top_p, top_k sampling for bonus tokens, while spec decode does not support these sampling strategies. - output tokens: - Tokens are finally generated with the rejection sampler. + output tokens: + Tokens are finally generated with the rejection sampler. output tokens = accepted tokens + recovered tokens + bonus tokens """ @@ -68,11 +68,11 @@ def forward( different requests are flattened into a single tensor because this is the shape of the output logits. bonus_token_ids_tensor (torch.Tensor): - A tensor containing bonus tokens. Shape is [batch_size, 1]. - Bonus tokens are added to the end of the sequence if all - proposed tokens are accepted. We generate the bonus tokens - outside of the rejection sampler with the default sampling - strategy. It allows for more flexibility in the sampling + A tensor containing bonus tokens. Shape is [batch_size, 1]. + Bonus tokens are added to the end of the sequence if all + proposed tokens are accepted. We generate the bonus tokens + outside of the rejection sampler with the default sampling + strategy. It allows for more flexibility in the sampling process such as top_p, top_k sampling. sampling_metadata (SamplingMetadata): Additional metadata needed for sampling, such as temperature, @@ -259,7 +259,7 @@ def generate_uniform_probs( device: torch.device, ) -> torch.Tensor: """ - Generates a batch of uniform random samples, with optional seeding + Generates a batch of uniform random samples, with optional seeding if available. This method creates a tensor of shape `(num_tokens, )` filled @@ -274,13 +274,13 @@ def generate_uniform_probs( num_draft_tokens : List[List[int]] Number of draft tokens per request. 
generators : Optional[Dict[int, torch.Generator]] - A dictionary mapping indices in the batch to + A dictionary mapping indices in the batch to `torch.Generator` objects. device : torch.device The device on which to allocate the tensor. Returns: uniform_rand : torch.Tensor - A tensor of shape `(num_tokens, )` containing uniform + A tensor of shape `(num_tokens, )` containing uniform random values in the range [0, 1). """ uniform_probs = torch.rand( From dcd9db24a28d1340f47041cd09b0c7aeb1d8671e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 23:27:28 -0700 Subject: [PATCH 36/48] minor Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index d79d21970d6b..baa29a5cba47 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -529,7 +529,9 @@ def sample_recovered_tokens_kernel( other=float("-inf")) recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) + if IS_NGRAM: + # Restore the original probability. tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob) From 418b36b9c30f069a9c5ead6d403b57920ac4b54c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Mar 2025 23:49:51 -0700 Subject: [PATCH 37/48] is_all_greedy Signed-off-by: Woosuk Kwon --- vllm/v1/sample/metadata.py | 2 +- vllm/v1/sample/rejection_sampler.py | 19 ++++++++++++++----- vllm/v1/worker/gpu_input_batch.py | 10 +++++----- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 7e339e2a597d..e97e1235fb36 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -9,7 +9,7 @@ @dataclass class SamplingMetadata: - temperature: torch.Tensor + temperature: Optional[torch.Tensor] all_greedy: bool all_random: bool diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index baa29a5cba47..1634044e6429 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -150,7 +150,10 @@ def rejection_sample( ) output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) - is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE + if sampling_metadata.all_greedy: + is_greedy = None + else: + is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE if not sampling_metadata.all_random: # Rejection sampling for greedy sampling requests. target_argmax = target_probs.argmax(dim=-1) @@ -342,7 +345,7 @@ def sample_recovered_tokens( return recovered_token_ids -# NOTE(woosuk): Don't specialize on `max_spec_len` to avoid recompilation. +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @triton.jit(do_not_specialize=["max_spec_len"]) def rejection_greedy_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] @@ -350,11 +353,16 @@ def rejection_greedy_sample_kernel( draft_token_ids_ptr, # [num_tokens] target_argmax_ptr, # [num_tokens] bonus_token_ids_ptr, # [batch_size] - is_greedy_ptr, # [batch_size] + is_greedy_ptr, # [batch_size] or None max_spec_len, ): req_idx = tl.program_id(0) - is_greedy = tl.load(is_greedy_ptr + req_idx) + # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, + # re-compilation may happen during runtime when is_greedy_ptr is None. 
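+    # (Triton specializes the kernel on whether a pointer argument is None,
+    # so the None and non-None cases compile to separate kernel variants.)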
+ if is_greedy_ptr is None: + is_greedy = True + else: + is_greedy = tl.load(is_greedy_ptr + req_idx) if not is_greedy: # Early exit for non-greedy sampling requests. return @@ -385,7 +393,7 @@ def rejection_greedy_sample_kernel( num_draft_tokens, bonus_token_id) -# NOTE(woosuk): Don't specialize on `max_spec_len` to avoid recompilation. +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @triton.jit(do_not_specialize=["max_spec_len"]) def rejection_random_sample_kernel( output_token_ids_ptr, # [batch_size, max_spec_len + 1] @@ -448,6 +456,7 @@ def rejection_random_sample_kernel( num_draft_tokens, bonus_token_id) +# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation. @triton.jit(do_not_specialize=["replace_from", "replace_to"]) def expand_kernel( output_ptr, # [num_tokens] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 202e75369be8..55d5429a8935 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -525,11 +525,11 @@ def refresh_sampling_metadata(self): def _make_sampling_metadata(self) -> SamplingMetadata: num_reqs = self.num_reqs - # NOTE(woosuk): Even if all requests are greedy, we copy the - # temperature tensor for simplicity. The temperature tensor is used - # for speculative decoding. - temperature = copy_slice(self.temperature_cpu_tensor, self.temperature, - num_reqs) + if not self.all_greedy: + temperature = copy_slice(self.temperature_cpu_tensor, + self.temperature, num_reqs) + else: + temperature = None if not self.no_top_p: copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) if not self.no_top_k: From 731622bda971659324e0890e5a2a63aaf2d1c81d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Mar 2025 01:36:43 -0700 Subject: [PATCH 38/48] fix test Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 27 ++++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 84139a40b544..1dd5a4cb05c9 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -6,7 +6,8 @@ import torch.nn.functional as F from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler +from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, + RejectionSampler) DEVICE = "cpu" @@ -89,9 +90,11 @@ def test_early_mismatch(sampler): device=logits.device) output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) - expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], - dtype=torch.int, - device=logits.device) + expected = torch.tensor( + [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], + dtype=torch.int, + device=logits.device, + ) assert torch.equal(output, expected) @@ -107,7 +110,7 @@ def test_multiple_sequences(sampler): [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) - expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]], + expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) @@ -155,10 +158,12 @@ def test_multiple_mismatches(sampler): [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) - expected = 
torch.tensor([[1, 2, 7, INVALID_TOKEN_ID], - [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], - dtype=torch.int, - device=logits.device) + expected = torch.tensor( + [[1, 2, 7, PLACEHOLDER_TOKEN_ID], + [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], + dtype=torch.int, + device=logits.device, + ) assert torch.equal(output, expected) @@ -166,9 +171,9 @@ def test_multiple_mismatches(sampler): "spec_tokens,output_tokens,expected", [ ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus - ([[1]], [[2, 3]], [[2, INVALID_TOKEN_ID]]), # First mismatch + ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], - [[1, 5, INVALID_TOKEN_ID], [3, 4, 7]]), # Mixed matches + [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches ]) def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): """Parametrized test for various matching scenarios""" From 36247ae41433c43c26b0533b85b8c0555dfc7185 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Mar 2025 08:10:57 -0700 Subject: [PATCH 39/48] rename Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 45 ++++++++++++++--------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 1dd5a4cb05c9..19e9f9ad764d 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -13,7 +13,7 @@ @pytest.fixture -def sampler(): +def rejection_sampler(): return RejectionSampler() @@ -62,7 +62,7 @@ def create_sampling_metadata( ########################### Tests for Greedy Sampling ################### -def test_perfect_match(sampler): +def test_perfect_match(rejection_sampler): """Test when output tokens perfectly match speculated tokens""" spec_tokens = [[1, 2, 3]] output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token @@ -72,14 +72,15 @@ def test_perfect_match(sampler): bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, + metadata) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) -def test_early_mismatch(sampler): +def test_early_mismatch(rejection_sampler): """Test when there's an early mismatch in tokens""" spec_tokens = [[1, 2, 3]] output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1 @@ -89,7 +90,8 @@ def test_early_mismatch(sampler): bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, + metadata) expected = torch.tensor( [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, @@ -98,7 +100,7 @@ def test_early_mismatch(sampler): assert torch.equal(output, expected) -def test_multiple_sequences(sampler): +def test_multiple_sequences(rejection_sampler): """Test handling multiple sequences of speculated tokens""" spec_tokens = [[1, 2], [3]] output_tokens = [[1, 2, 5], [3, @@ -109,14 +111,15 @@ def test_multiple_sequences(sampler): bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, + 
metadata) expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) -def test_single_token_sequence(sampler): +def test_single_token_sequence(rejection_sampler): """Test handling sequences with single token""" spec_tokens = [[1]] output_tokens = [[1, 2]] # Single token with bonus token 2 @@ -126,12 +129,13 @@ def test_single_token_sequence(sampler): bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, + metadata) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) -def test_empty_sequence(sampler): +def test_empty_sequence(rejection_sampler): """Test handling empty sequence of speculated tokens""" spec_tokens: list[list[int]] = [[]] output_tokens = [[5]] # Just the bonus token @@ -141,12 +145,13 @@ def test_empty_sequence(sampler): bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, + metadata) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) -def test_multiple_mismatches(sampler): +def test_multiple_mismatches(rejection_sampler): """Test handling multiple sequences with mismatches""" spec_tokens = [[1, 2, 3], [4, 5, 6]] output_tokens = [[1, 2, 7, 6], [4, 8, 6, @@ -157,7 +162,8 @@ def test_multiple_mismatches(sampler): bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, + metadata) expected = torch.tensor( [[1, 2, 7, PLACEHOLDER_TOKEN_ID], [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], @@ -175,14 +181,16 @@ def test_multiple_mismatches(sampler): ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches ]) -def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): +def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, + expected): """Parametrized test for various matching scenarios""" metadata = create_sampling_metadata(all_greedy=True) logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], device=logits.device) - output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) + output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, + metadata) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) @@ -195,7 +203,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): @pytest.mark.parametrize("batch_size", [1, 4, 8]) @pytest.mark.parametrize("frac_seeded", [0.0, 0.5]) @pytest.mark.parametrize("n_rep", [20]) -def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, +def test_deterministic_when_seeded(rejection_sampler, k: int, vocab_size: int, batch_size: int, frac_seeded: float, n_rep: int): draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -222,8 +230,9 @@ def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, sampling_metadata = 
create_sampling_metadata(all_greedy=False, generators=seeded_seqs) - rep_result = sampler(draft_token_ids.tolist(), draft_probs, - bonus_token_ids, target_probs, sampling_metadata) + rep_result = rejection_sampler(draft_token_ids.tolist(), draft_probs, + bonus_token_ids, target_probs, + sampling_metadata) results.append(rep_result) From 52f9f4e58a4b8b651bd12c3331908d7a72a7532c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Mar 2025 08:14:27 -0700 Subject: [PATCH 40/48] rename Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 2 +- vllm/v1/spec_decode/metadata.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 19e9f9ad764d..a9ec027952de 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -9,7 +9,7 @@ from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, RejectionSampler) -DEVICE = "cpu" +DEVICE = "cuda" @pytest.fixture diff --git a/vllm/v1/spec_decode/metadata.py b/vllm/v1/spec_decode/metadata.py index c95d6fca0d8a..1cf650d5fa56 100644 --- a/vllm/v1/spec_decode/metadata.py +++ b/vllm/v1/spec_decode/metadata.py @@ -25,7 +25,7 @@ def __post_init__(self): self.max_spec_len = max(self.num_draft_tokens) @classmethod - def make_dummy_for_profiling( + def make_dummy( cls, draft_token_ids: list[list[int]], device: torch.device, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ed42b1d18c1f..ae5b72cb4f50 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1363,9 +1363,8 @@ def _dummy_sampler_run( raise e if self.use_spec_decode: draft_token_ids = [[0] for _ in range(num_reqs)] - dummy_spec_decode_metadata = ( - SpecDecodeMetadata.make_dummy_for_profiling( - draft_token_ids, self.device)) + dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids, self.device) num_tokens = sum(len(ids) for ids in draft_token_ids) num_tokens_with_bonus = num_tokens + num_reqs From 9cbab9f1fa5314131a122637f605f05083828b17 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Mar 2025 09:57:40 -0700 Subject: [PATCH 41/48] fix Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 194 ++++++++++++++++------ vllm/v1/sample/rejection_sampler.py | 12 +- 2 files changed, 151 insertions(+), 55 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index a9ec027952de..25e510d3dd70 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -8,6 +8,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, RejectionSampler) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata DEVICE = "cuda" @@ -17,10 +18,11 @@ def rejection_sampler(): return RejectionSampler() -def create_logits_tensor(token_ids: list[list[int]], +def create_logits_tensor(output_token_ids: list[list[int]], vocab_size: int = 100) -> torch.Tensor: """Helper function to create logits tensor that will produce desired token ids on argmax""" + token_ids = [tokens[:-1] for tokens in output_token_ids] num_total_tokens = sum(len(tokens) for tokens in token_ids) logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) start_loc = 0 @@ -32,15 +34,22 @@ def create_logits_tensor(token_ids: 
list[list[int]], def create_sampling_metadata( - all_greedy: bool, - generators: Optional[dict[int, Any]] = None) -> SamplingMetadata: + all_greedy: bool, + temperature: Optional[torch.Tensor] = None, + generators: Optional[dict[int, Any]] = None, +) -> SamplingMetadata: """Create a v1 sampling metadata object with all_greedy set to the given value. Either all greedy or all random sampling is used. """ generators = generators or {} + if all_greedy: + temperature = None + else: + assert temperature is not None + return SamplingMetadata( - temperature=torch.tensor([]), + temperature=temperature, all_greedy=all_greedy, all_random=not all_greedy, top_p=None, @@ -71,9 +80,16 @@ def test_perfect_match(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - - output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, - metadata) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) @@ -89,9 +105,16 @@ def test_early_mismatch(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - - output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, - metadata) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor( [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, @@ -110,9 +133,16 @@ def test_multiple_sequences(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - - output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, - metadata) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device) @@ -128,9 +158,16 @@ def test_single_token_sequence(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - - output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, - metadata) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) @@ -144,9 +181,16 @@ def test_empty_sequence(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) - - output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, - metadata) + 
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) assert torch.equal(output, expected) @@ -161,9 +205,16 @@ def test_multiple_mismatches(rejection_sampler): logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor( [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) - - output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, - metadata) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected = torch.tensor( [[1, 2, 7, PLACEHOLDER_TOKEN_ID], [4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]], @@ -188,9 +239,16 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, logits = create_logits_tensor(output_tokens) bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], device=logits.device) - - output = rejection_sampler(spec_tokens, None, bonus_token_tensor, logits, - metadata) + spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens, + device=logits.device) + + output = rejection_sampler( + spec_decode_metadata, + draft_probs=None, + target_logits=logits, + bonus_token_ids=bonus_token_tensor, + sampling_metadata=metadata, + ) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) @@ -203,21 +261,31 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, @pytest.mark.parametrize("batch_size", [1, 4, 8]) @pytest.mark.parametrize("frac_seeded", [0.0, 0.5]) @pytest.mark.parametrize("n_rep", [20]) -def test_deterministic_when_seeded(rejection_sampler, k: int, vocab_size: int, - batch_size: int, frac_seeded: float, - n_rep: int): - draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size * (k + 1), - vocab_size, - dtype=torch.float32) +def test_deterministic_when_seeded( + rejection_sampler, + k: int, + vocab_size: int, + batch_size: int, + frac_seeded: float, + n_rep: int, +): + num_tokens = batch_size * k + draft_probs = torch.rand(num_tokens, + vocab_size, + dtype=torch.float32, + device=DEVICE) + draft_probs = F.softmax(draft_probs, dim=-1) + target_logits = torch.rand_like(draft_probs) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), - dtype=torch.int64) + dtype=torch.int64, + device=DEVICE) draft_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, k), - dtype=torch.int64) + dtype=torch.int64, + device=DEVICE) seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded @@ -228,11 +296,21 @@ def test_deterministic_when_seeded(rejection_sampler, k: int, vocab_size: int, for i in range(batch_size) if seeded_mask[i] } + temperature = torch.ones(batch_size, + dtype=torch.float32, + device=DEVICE) sampling_metadata = create_sampling_metadata(all_greedy=False, + temperature=temperature, generators=seeded_seqs) - rep_result = rejection_sampler(draft_token_ids.tolist(), draft_probs, - bonus_token_ids, target_probs, - sampling_metadata) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids.tolist(), device=DEVICE) + rep_result = 
rejection_sampler( + spec_decode_metadata, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) results.append(rep_result) @@ -271,10 +349,10 @@ def test_rejection_sampling_approximates_target_distribution(): num_reference_probs = 100 # Prepare draft, target, and reference probability distributions - draft_probs, target_probs = (F.softmax( - torch.rand(vocab_size, dtype=torch.float32), - dim=-1, - ) for _ in range(2)) + draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32), + dim=-1) + target_logits = torch.rand(vocab_size, dtype=torch.float32) + target_probs = F.softmax(target_logits, dim=-1) reference_probs = F.softmax( torch.rand(num_reference_probs, vocab_size, dtype=torch.float32), dim=-1, @@ -287,7 +365,7 @@ def test_rejection_sampling_approximates_target_distribution(): for num_samples in sample_sizes: # Sample using rejection sampling. rej_sample_probs = estimate_rejection_sampling_pdf( - draft_probs, target_probs, k, vocab_size, num_samples) + draft_probs, target_logits, k, vocab_size, num_samples) rej_sample_probs = rej_sample_probs.to(DEVICE) # Average distance from reference probs. @@ -327,7 +405,7 @@ def get_ratio_first_to_last(elements: list[float]) -> float: def estimate_rejection_sampling_pdf( draft_probs: torch.Tensor, - target_probs: torch.Tensor, + target_logits: torch.Tensor, k: int, vocab_size: int, num_samples: int, @@ -337,35 +415,45 @@ def estimate_rejection_sampling_pdf( Args: draft_probs: Draft probability distribution. - target_probs: Target probability distribution. + target_logits: Target logits. num_samples: Number of samples to draw. Returns: Estimated probability distribution of the output tokens. """ - sampler = RejectionSampler() - # Repeat draft probs num_samples times. + rejection_sampler = RejectionSampler() + num_tokens = num_samples * k + # Repeat draft probs num_samples * k times. draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1) - # Repeat target probs num_samples * (k + 1) times. - target_probs = target_probs.reshape(1, 1, vocab_size).repeat( - num_samples, k + 1, 1).reshape(num_samples * (k + 1), vocab_size) + # Repeat target probs num_tokens times. + target_logits = target_logits.reshape(1, vocab_size).repeat( + num_tokens, vocab_size) # Randomly sample draft token ids from draft probs. draft_token_ids = torch.multinomial(draft_probs[:, 0, :], num_samples=k, replacement=True).reshape( num_samples, k) + draft_probs = draft_probs.view(num_tokens, vocab_size) # Bonus tokens not used but required. 
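+    # The bonus column is sliced off (output_token_ids[:, :-1]) before the
+    # histogram is computed, so its value does not affect the estimate.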
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, device=DEVICE).repeat(num_samples, 1) - sampling_metadata = create_sampling_metadata(all_greedy=False) - output_token_ids = sampler(draft_token_ids.tolist(), draft_probs, - bonus_token_ids, target_probs, - sampling_metadata) + temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE) + sampling_metadata = create_sampling_metadata(all_greedy=False, + temperature=temperature) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids.tolist(), device=bonus_token_ids.device) + output_token_ids = rejection_sampler( + spec_decode_metadata, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) output_token_ids = output_token_ids[:, :-1].flatten() hist = torch.histogram(output_token_ids.to(dtype=torch.float, diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 1634044e6429..ffa0aebd689b 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -81,7 +81,7 @@ def forward( output_token_ids (torch.Tensor): A tensor containing the final output token IDs. ''' - assert 0 < metadata.max_spec_len <= MAX_SPEC_LEN + assert metadata.max_spec_len <= MAX_SPEC_LEN # [num_tokens, vocab_size] target_probs = compute_probs( target_logits, @@ -129,10 +129,15 @@ def rejection_sample( draft_probs: Optional[torch.Tensor], # [num_tokens, vocab_size] target_probs: torch.Tensor, - # [batch_size] + # [batch_size, 1] bonus_token_ids: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: + assert draft_token_ids.ndim == 1 + assert draft_probs is None or draft_probs.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 + assert target_probs.ndim == 2 + batch_size = len(num_draft_tokens) num_tokens = draft_token_ids.shape[0] vocab_size = target_probs.shape[-1] @@ -141,6 +146,7 @@ def rejection_sample( assert draft_probs is None or draft_probs.is_contiguous() assert target_probs.is_contiguous() assert bonus_token_ids.is_contiguous() + assert target_probs.shape == (num_tokens, vocab_size) # Create output buffer. output_token_ids = torch.empty( @@ -232,6 +238,8 @@ def compute_probs( if non-greedy sampling is used, otherwise returns the original logits. """ + assert logits.ndim == 2 + assert cu_num_draft_tokens.ndim == 1 if sampling_metadata.all_greedy: return logits From 8b7a39855907b9c09856d2fc160c810280e63f8d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Mar 2025 10:12:10 -0700 Subject: [PATCH 42/48] fix test Signed-off-by: Woosuk Kwon --- tests/v1/sample/test_rejection_sampler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py index 25e510d3dd70..8c423e367ef5 100644 --- a/tests/v1/sample/test_rejection_sampler.py +++ b/tests/v1/sample/test_rejection_sampler.py @@ -428,8 +428,7 @@ def estimate_rejection_sampling_pdf( vocab_size).repeat(num_samples, k, 1) # Repeat target probs num_tokens times. - target_logits = target_logits.reshape(1, vocab_size).repeat( - num_tokens, vocab_size) + target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1) # Randomly sample draft token ids from draft probs. 
draft_token_ids = torch.multinomial(draft_probs[:, 0, :], From 40f334a2c78358f1f322a211961934c30a0b3fa4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Mar 2025 12:47:00 -0700 Subject: [PATCH 43/48] comment Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e9cd1565e5d1..7338f0037dc4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1367,8 +1367,9 @@ def _dummy_sampler_run( logits.shape[-1], device=self.device, dtype=logits.dtype) - # NOTE(woosuk): Here, we should use int32 because the sampler - # uses int32 for bonus_token_ids. + # NOTE(woosuk): Here, we should use int32 because the sampler uses + # int32 for bonus_token_ids. If the dtype mismatches, re-compilation + # will occur at runtime. bonus_token_ids = torch.zeros(num_reqs, device=self.device, dtype=torch.int32) From 6935bfd60ff9c2072868f05be2dd8b1f8b0e2d27 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Mar 2025 12:51:50 -0700 Subject: [PATCH 44/48] comment Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index ffa0aebd689b..e903dba7f903 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -301,6 +301,8 @@ def generate_uniform_probs( ) start_idx = 0 for req_idx, n in enumerate(num_draft_tokens): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. if n == 0: continue end_idx = start_idx + n @@ -335,6 +337,8 @@ def sample_recovered_tokens( ) q.exponential_() for i, generator in sampling_metadata.generators.items(): + # Do not generate random numbers for requests with no draft tokens. + # This can be important for reproducibility. 
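+        # Calling exponential_() here would advance the generator state even
+        # though the drawn values would never be used.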
if num_draft_tokens[i] > 0: q[i].exponential_(generator=generator) From 0baa33e20838db036fc5bc79e661d36319af004d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 17 Mar 2025 13:49:30 -0700 Subject: [PATCH 45/48] fix shape mismatch Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7338f0037dc4..657333c6d84c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1358,12 +1358,11 @@ def _dummy_sampler_run( draft_token_ids, self.device) num_tokens = sum(len(ids) for ids in draft_token_ids) - num_tokens_with_bonus = num_tokens + num_reqs # draft_probs = torch.randn( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens_with_bonus, + target_logits = torch.randn(num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype) From aaf23161452eb1b69726c2775dd1fc3f34e61f6f Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 12:18:50 -0700 Subject: [PATCH 46/48] fix docstrings Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index e903dba7f903..6595ee94d4c0 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -224,19 +224,20 @@ def compute_probs( ) -> torch.Tensor: """Compute probability distribution from logits based on sampling metadata. - This function applies temperature scaling to the logits and converts - them to probabilities using softmax. For greedy decoding + This function applies temperature scaling to the logits and converts + them to probabilities using softmax. For greedy decoding, it returns + the original logits. Args: - logits: Input logits tensor to be converted to probabilities - cu_num_draft_tokens: Cumulative number of of draft tokens. - sampling_metadata: Metadata containing sampling parameters such - as temperature and whether greedy sampling is used + logits: Input logits tensor to be converted to probabilities. + cu_num_draft_tokens: Cumulative number of draft tokens. + sampling_metadata: Metadata containing sampling parameters such as + temperature and whether greedy sampling is used. Returns: - torch.Tensor: Probability distribution (softmax of scaled logits) - if non-greedy sampling is used, otherwise returns the - original logits. + torch.Tensor: Probability distribution (softmax of scaled logits) + if non-greedy sampling is used, otherwise returns the + original logits. """ assert logits.ndim == 2 assert cu_num_draft_tokens.ndim == 1 From 531068e0e3b5697e01168d5324d39d0c708075ed Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 12:22:28 -0700 Subject: [PATCH 47/48] fix dtype Signed-off-by: Woosuk Kwon --- vllm/v1/outputs.py | 2 +- vllm/v1/sample/rejection_sampler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index edae654b5d33..6f46417170f6 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -46,7 +46,7 @@ class SamplerOutput: # [num_reqs, max_num_generated_tokens] # Different requests can have different number of generated tokens. # All requests are padded to max_num_generated_tokens. - # INVALID_TOKEN_ID (-1 by default) is used for padding. 
+ # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding. sampled_token_ids: torch.Tensor logprobs_tensors: Optional[LogprobsTensors] diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 6595ee94d4c0..563710109a93 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -151,7 +151,7 @@ def rejection_sample( # Create output buffer. output_token_ids = torch.empty( (batch_size, max_spec_len + 1), - dtype=torch.int64, + dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids. device=device, ) output_token_ids.fill_(PLACEHOLDER_TOKEN_ID) From 69c88b8baa5e6af17b3b5b30f95039d5f6a85784 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 18 Mar 2025 12:27:04 -0700 Subject: [PATCH 48/48] add comment Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 563710109a93..6284ae4b490a 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -526,6 +526,8 @@ def sample_recovered_tokens_kernel( orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) # Temporarily zero out the probability of the draft token. + # This is essentially the same as target_prob - draft_prob, except that + # n-gram does not have draft_prob. We regard it as 1. tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0)