diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md
index b56cf61e782c..ae0b2a64b09b 100644
--- a/docs/design/cuda_graphs.md
+++ b/docs/design/cuda_graphs.md
@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and curren
 ```python
 class BatchDescriptor(NamedTuple):
     num_tokens: int
-    uniform_decode: bool = False
+    num_reqs: int | None = None
+    uniform: bool = False
+    has_lora: bool = False
 ```
 
-where `num_tokens` can be the padded token length, and `uniform_decode` is determined by if `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform_decode, and the num_scheduled_tokens is divisible by that desired `max_query_len`.
+where `num_tokens` can be the padded token length, and `uniform` indicates whether all requests in the batch have the same query length. Many attention backends only support full cudagraphs when the batch is uniform; pure decode batches are uniform but may have a query length greater than 1 (query length 1 implies `num_tokens == num_reqs`). This occurs in the validation pass of spec-decode, where "decode" batches have a query length of `1+num_spec_tokens`.
 
-The goal of this structure is to uniquely identify a (padded) batch with minimal possible items corresponding to a CUDA Graphs item. We are safe to exclude items like `uniform_query_len` because it is a constant at runtime for a certain setup currently. For example, it should be either `1` for a commonly pure decode or `1+num_spec_tokens` for a validation phase of speculative decode.
+The goal of this structure is to uniquely identify a (padded) batch with the minimal possible items corresponding to a CUDA Graphs item.
 
 !!! note
     The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
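To make the field semantics above concrete, here is a small standalone sketch (not taken from the patch; the batch sizes and the assumed `num_spec_tokens = 2` are made up) showing how `num_tokens`, `num_reqs`, and `uniform` relate for a plain decode batch, a spec-decode validation batch, and a mixed prefill-decode batch:

```python
# Standalone sketch: relationship between num_tokens, num_reqs and uniform.
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False
    has_lora: bool = False


# Pure decode, no spec tokens: one token per request, so num_tokens == num_reqs.
plain_decode = BatchDescriptor(num_tokens=32, num_reqs=32, uniform=True)

# Spec-decode validation with num_spec_tokens=2: every request has query
# length 1 + 2 = 3, so the batch is still uniform but num_tokens != num_reqs.
spec_decode = BatchDescriptor(num_tokens=96, num_reqs=32, uniform=True)

# Mixed prefill + decode: query lengths differ, so the batch is not uniform.
mixed = BatchDescriptor(num_tokens=96, num_reqs=5, uniform=False)

assert plain_decode.num_tokens == plain_decode.num_reqs
assert spec_decode.num_tokens == spec_decode.num_reqs * (1 + 2)
assert not mixed.uniform
```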
""" return BatchDescriptor( - self.num_tokens, uniform_decode=False, has_lora=self.has_lora + self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora ) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 07a0ab41a9e0..416993539e54 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -708,31 +708,12 @@ def build( if num_decodes > 0: pure_decode = num_prefills == 0 - # possible required padding for cudagraph replay use_cudagraph = ( self.enable_cuda_graph and pure_decode and num_decode_tokens <= self._decode_cudagraph_max_bs ) - if use_cudagraph: - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_decode_tokens - ) - # Carefully fulfill the padding region with reasonable value - # on cpu. - # Make sure paged_kv_indptr_cpu is not decreasing - self.paged_kv_indptr_cpu[ - 1 + num_decodes : 1 + num_input_tokens - ].fill_(paged_kv_indptr_cpu[-1]) - # Fill the remaining paged_kv_last_page_len_cpu with 1. - # This is because flashinfer treats 0 as a full page - # instead of empty. - self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_( - 1 - ) - - else: - num_input_tokens = num_decode_tokens + num_input_tokens = num_decode_tokens attn_metadata.decode_wrapper = self._get_decode_wrapper( num_input_tokens, use_cudagraph diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 07dfbc766acd..c63b577ddf6b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -845,7 +845,9 @@ def split_decodes_and_prefills( if require_uniform: is_prefill = query_lens != query_lens[0] else: - is_prefill = query_lens > decode_threshold + # 0-query len indicates a padded request; leave this at the back + # of the batch with the prefills + is_prefill = query_lens > decode_threshold | query_lens == 0 if not torch.any(is_prefill): return num_reqs, 0, num_tokens, 0 diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index b480ac78f23c..32ee4cb85b13 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -4,6 +4,9 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor +from vllm.logger import init_logger + +logger = init_logger(__name__) class CudagraphDispatcher: @@ -29,6 +32,11 @@ def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config self.cudagraph_mode = self.compilation_config.cudagraph_mode + self.uniform_decode_query_len = ( + 1 + if not self.vllm_config.speculative_config + else 1 + self.vllm_config.speculative_config.num_speculative_tokens + ) # Dict to store valid cudagraph dispatching keys. 
diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py
index b480ac78f23c..32ee4cb85b13 100644
--- a/vllm/v1/cudagraph_dispatcher.py
+++ b/vllm/v1/cudagraph_dispatcher.py
@@ -4,6 +4,9 @@
 
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class CudagraphDispatcher:
@@ -29,6 +32,11 @@ def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
         self.compilation_config = vllm_config.compilation_config
         self.cudagraph_mode = self.compilation_config.cudagraph_mode
+        self.uniform_decode_query_len = (
+            1
+            if not self.vllm_config.speculative_config
+            else 1 + self.vllm_config.speculative_config.num_speculative_tokens
+        )
 
         # Dict to store valid cudagraph dispatching keys.
         self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
@@ -55,6 +63,32 @@ def __init__(self, vllm_config: VllmConfig):
 
         self.keys_initialized = False
 
+    def _create_padded_batch_descriptor(
+        self, num_tokens: int, uniform_decode: bool, has_lora: bool
+    ) -> BatchDescriptor:
+        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+        uniform_decode_query_len = self.uniform_decode_query_len
+        num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+        if uniform_decode:
+            num_reqs = num_tokens // uniform_decode_query_len
+            assert num_tokens % uniform_decode_query_len == 0
+            assert num_reqs <= max_num_seqs
+            return BatchDescriptor(
+                num_tokens=num_tokens_padded,
+                num_reqs=num_reqs,
+                uniform=uniform_decode,
+                has_lora=has_lora,
+            )
+        num_reqs = min(num_tokens_padded, max_num_seqs)
+
+        return BatchDescriptor(
+            num_tokens=num_tokens_padded,
+            num_reqs=num_reqs,
+            uniform=uniform_decode,
+            has_lora=has_lora,
+        )
+
     def add_cudagraph_key(
         self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
     ):
@@ -86,9 +120,7 @@ def initialize_cudagraph_keys(
         ):
             self.add_cudagraph_key(
                 cudagraph_mode.mixed_mode(),
-                BatchDescriptor(
-                    num_tokens=bs, uniform_decode=False, has_lora=has_lora
-                ),
+                self._create_padded_batch_descriptor(bs, False, has_lora),
             )
 
         # if decode cudagraph mode is FULL, and we don't already have mixed
@@ -109,40 +141,59 @@ def initialize_cudagraph_keys(
         for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
             self.add_cudagraph_key(
                 CUDAGraphMode.FULL,
-                BatchDescriptor(
-                    num_tokens=bs, uniform_decode=True, has_lora=has_lora
-                ),
+                self._create_padded_batch_descriptor(bs, True, has_lora),
             )
+
         self.keys_initialized = True
 
+    def _is_compatible(
+        self, batch_descriptor: BatchDescriptor, candidate: BatchDescriptor
+    ) -> bool:
+        """Check if the candidate cudagraph can handle the batch request."""
+        if candidate.num_reqs is None:
+            return True
+        assert batch_descriptor.num_reqs is not None
+        return candidate.num_reqs >= batch_descriptor.num_reqs
+
     def dispatch(
-        self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
+        self,
+        num_tokens: int,
+        num_reqs: int,
+        uniform_decode: bool,
+        has_lora: bool,
+        use_cascade_attn: bool = False,
     ) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
         """
         Given conditions(e.g.,batch descriptor and if using cascade attention),
         dispatch to a cudagraph runtime mode and the valid batch descriptor.
         A new batch descriptor is returned as we might dispatch a uniform batch
         to a graph that supports a more general batch (uniform to non-uniform).
+
+        `num_reqs` is reserved for future use; it makes sure callsites have
+        access to this information.
         """
         # if not initialized, just skip dispatching.
         if not self.keys_initialized:
             return CUDAGraphMode.NONE, None
 
-        non_uniform_key = batch_descriptor.non_uniform
-        # if a batch use cascade attention, bypass checking full cudagraphs
+        batch_descriptor = self._create_padded_batch_descriptor(
+            num_tokens, uniform_decode, has_lora
+        )
+        relaxed_batch_descriptor = batch_descriptor.relax_for_mixed_batch_cudagraphs()
+
         if not use_cascade_attn:
             # check if key exists for full cudagraph
             if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
                 return CUDAGraphMode.FULL, batch_descriptor
-            # otherwise, check if non-uniform key exists
-            if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, non_uniform_key
+            # otherwise, check if the relaxed key exists
+            if relaxed_batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, relaxed_batch_descriptor
 
-        # also check if non-uniform key exists for more "general"
+        # also check if the relaxed key exists for more "general"
         # piecewise cudagraph
-        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
-            return CUDAGraphMode.PIECEWISE, non_uniform_key
+        if relaxed_batch_descriptor in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
+            return CUDAGraphMode.PIECEWISE, relaxed_batch_descriptor
 
-        # finally, just return no cudagraphs
-        return CUDAGraphMode.NONE, None
+        # finally, just return no cudagraphs and a trivial batch descriptor
+        return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index fbd3e5f31316..3b8fcad3a23b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1055,17 +1055,13 @@ def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
         num_scheduled_tokens: np.ndarray,
-        max_num_scheduled_tokens: int,
     ) -> tuple[
         torch.Tensor,
         SpecDecodeMetadata | None,
-        UBatchSlices | None,
-        torch.Tensor | None,
     ]:
         """
         :return: tuple[
             logits_indices, spec_decode_metadata,
-            ubatch_slices, num_tokens_across_dp,
         ]
         """
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1175,28 +1171,6 @@ def _prepare_inputs(
             self.query_start_loc.copy_to_gpu()
         query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
 
-        num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
-        num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
-        uniform_decode = (
-            max_num_scheduled_tokens == self.uniform_decode_query_len
-        ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
-
-        # Disable DP padding when running eager to avoid excessive padding when
-        # running prefills. This lets us set enforce_eager on the prefiller in
-        # a P/D setup and still use CUDA graphs (enabled by this padding) on the
-        # decoder.
-        allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-
-        ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
-            num_tokens_unpadded=num_tokens_unpadded,
-            parallel_config=self.parallel_config,
-            allow_microbatching=True,
-            allow_dp_padding=allow_dp_padding,
-            num_tokens_padded=num_tokens_padded,
-            uniform_decode=uniform_decode,
-            num_scheduled_tokens_per_request=num_scheduled_tokens,
-        )
-
         self.seq_lens.np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
         )
@@ -1287,8 +1261,6 @@ def _prepare_inputs(
         return (
             logits_indices,
             spec_decode_metadata,
-            ubatch_slices,
-            num_tokens_across_dp,
         )
 
     def _build_attention_metadata(
@@ -1473,6 +1445,7 @@ def _build_attention_metadata(
     def _compute_cascade_attn_prefix_lens(
         self,
         num_scheduled_tokens: np.ndarray,
+        num_computed_tokens: np.ndarray,
        num_common_prefix_blocks: list[int],
     ) -> list[list[int]] | None:
         """
@@ -1495,6 +1468,7 @@ def _compute_cascade_attn_prefix_lens(
             # 0 if cascade attention should not be used
             cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
                 num_scheduled_tokens,
+                num_computed_tokens,
                 num_common_prefix_blocks[kv_cache_gid],
                 attn_group.kv_cache_spec,
                 attn_group.get_metadata_builder(),
@@ -1507,6 +1481,7 @@ def _compute_cascade_attn_prefix_lens(
     def _compute_cascade_attn_prefix_len(
         self,
         num_scheduled_tokens: np.ndarray,
+        num_computed_tokens: np.ndarray,
         num_common_prefix_blocks: int,
         kv_cache_spec: KVCacheSpec,
         attn_metadata_builder: AttentionMetadataBuilder,
@@ -1573,10 +1548,7 @@ def _compute_cascade_attn_prefix_len(
         # and the second kernel will get an empty input. While this is not
         # a fundamental problem, our current implementation does not support
         # this case.
-        num_reqs = len(num_scheduled_tokens)
-        common_prefix_len = min(
-            common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()
-        )
+        common_prefix_len = min(common_prefix_len, num_computed_tokens.min())
 
         # common_prefix_len should be a multiple of the block size.
         common_prefix_len = (
             common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size
@@ -2158,18 +2130,7 @@ def _pool(
             pooler_output=pooler_output,
         )
 
-    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
-        if (
-            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-            and hasattr(self, "cudagraph_batch_sizes")
-            and self.cudagraph_batch_sizes
-            and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]
-        ):
-            # Use CUDA graphs.
-            # Add padding to the batch size.
-            return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
-
-        # Eager mode.
+    def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
         # Pad tokens to multiple of tensor_parallel_size when
         # enabled collective fusion for SP
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
@@ -2527,36 +2488,76 @@ def execute_model(
             tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
             num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
             max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
-
-            (
-                logits_indices,
-                spec_decode_metadata,
-                ubatch_slices,
-                num_tokens_across_dp,
-            ) = self._prepare_inputs(
-                scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
+            num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
+            num_tokens_padded = self._pad_for_sequence_parallelism(
+                num_tokens_unpadded
+            )
+            uniform_decode = (
+                (max_num_scheduled_tokens == self.uniform_decode_query_len)
+                and (num_tokens_unpadded == num_reqs * max_num_scheduled_tokens)
+                and (num_tokens_padded == num_tokens_unpadded)
             )
 
             cascade_attn_prefix_lens = None
             # Disable cascade attention when using microbatching (DBO)
-            if self.cascade_attn_enabled and ubatch_slices is None:
+            if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
                 # Pre-compute cascade attention prefix lengths
-                # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
                 cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
                     num_scheduled_tokens_np,
+                    self.input_batch.num_computed_tokens_cpu[:num_reqs],
                     scheduler_output.num_common_prefix_blocks,
                 )
 
-            # TODO(lucas): move cudagraph dispatching here:
-            # https://github.com/vllm-project/vllm/issues/23789
+            # Disable DP padding when running eager to avoid excessive padding when
+            # running prefills. This lets us set enforce_eager on the prefiller in
+            # a P/D setup and still use CUDA graphs (enabled by this padding) on the
+            # decoder.
+            allow_dp_padding = (
+                self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+            )
+
+            ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
+                num_tokens_unpadded=num_tokens_unpadded,
+                parallel_config=self.parallel_config,
+                allow_microbatching=True,
+                allow_dp_padding=allow_dp_padding,
+                num_tokens_padded=num_tokens_padded,
+                uniform_decode=uniform_decode,
+                num_scheduled_tokens_per_request=num_scheduled_tokens_np,
+            )
+
+            # Dispatch returns a batch descriptor padded to the selected cudagraph
+            # size, or a trivial descriptor with the input token count when no
+            # cudagraph is used.
+            cudagraph_runtime_mode, batch_descriptor = (
+                self.cudagraph_dispatcher.dispatch(
+                    num_tokens=num_tokens_padded,
+                    num_reqs=num_reqs,
+                    uniform_decode=uniform_decode,
+                    has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
+                    use_cascade_attn=cascade_attn_prefix_lens is not None,
+                )
+            )
+
+            num_tokens_padded = batch_descriptor.num_tokens
+            num_reqs_padded = (
+                batch_descriptor.num_reqs
+                if batch_descriptor.num_reqs is not None
+                else num_reqs
+            )
+
+            (
+                logits_indices,
+                spec_decode_metadata,
+            ) = self._prepare_inputs(
+                scheduler_output,
+                num_scheduled_tokens_np,
+            )
 
-            total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
             use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
             attn_metadata, spec_decode_common_attn_metadata = (
                 self._build_attention_metadata(
-                    total_num_scheduled_tokens=total_num_scheduled_tokens,
+                    total_num_scheduled_tokens=num_tokens_padded,
                     max_num_scheduled_tokens=max_num_scheduled_tokens,
-                    num_reqs=num_reqs,
+                    num_reqs=num_reqs_padded,
                     ubatch_slices=ubatch_slices,
                     logits_indices=logits_indices,
                     use_spec_decode=use_spec_decode,
@@ -2573,7 +2574,7 @@ def execute_model(
             elif num_tokens_across_dp is not None:
                 num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
             else:
-                num_input_tokens = self._get_num_input_tokens(
+                num_input_tokens = self._pad_for_sequence_parallelism(
                     scheduler_output.total_num_scheduled_tokens
                 )
@@ -2587,21 +2588,6 @@ def execute_model(
                 scheduler_output, num_input_tokens, intermediate_tensors
             )
 
-            uniform_decode = (
-                max_num_scheduled_tokens == self.uniform_decode_query_len
-            ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
-            batch_descriptor = BatchDescriptor(
-                num_tokens=num_input_tokens,
-                uniform_decode=uniform_decode,
-                has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
-            )
-            cudagraph_runtime_mode, batch_descriptor = (
-                self.cudagraph_dispatcher.dispatch(
-                    batch_descriptor,
-                    use_cascade_attn=cascade_attn_prefix_lens is not None,
-                )
-            )
-
             # Set cudagraph mode to none if calc_kv_scales is true.
             # KV scales calculation involves dynamic operations that are incompatible
             # with CUDA graph capture.
@@ -3565,17 +3551,34 @@ def _dummy_run(
             )
 
         # filter out the valid batch descriptor
-        _cg_mode, batch_descriptor = (
-            self.cudagraph_dispatcher.dispatch(
-                BatchDescriptor(
-                    num_tokens=num_tokens_after_padding,
-                    uniform_decode=uniform_decode,
-                    has_lora=activate_lora and self.lora_config is not None,
-                )
+        has_lora = activate_lora and self.lora_config is not None
+        if is_profile or cudagraph_runtime_mode is not None:
+            # During profiling or CUDA graph capture, don't call dispatch.
+            # Use the explicitly provided mode, or NONE for profiling.
+            _cg_mode = (
+                cudagraph_runtime_mode
+                if cudagraph_runtime_mode is not None
+                else CUDAGraphMode.NONE
             )
-            if not is_profile
-            else (CUDAGraphMode.NONE, None)
-        )
+            batch_descriptor = BatchDescriptor(
+                num_tokens=num_tokens_after_padding,
+                num_reqs=num_reqs,
+                uniform=uniform_decode,
+                has_lora=has_lora,
+            )
+        else:
+            # Normal execution: dispatch to find the appropriate CUDA graph
+            _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
+                num_tokens=num_tokens_after_padding,
+                num_reqs=num_reqs,
+                uniform_decode=uniform_decode,
+                has_lora=has_lora,
+                use_cascade_attn=False,
+            )
+
+        num_tokens_after_padding = batch_descriptor.num_tokens
+
         if cudagraph_runtime_mode is not None:
             # we allow forcing NONE when the dispatcher disagrees to support
             # warm ups for cudagraph capture
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 19061fcffdf1..4cd7fc20562d 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -538,7 +538,9 @@ def execute_model(
             intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
+        num_input_tokens = self.model_runner._pad_for_sequence_parallelism(
+            num_scheduled_tokens
+        )
         all_gather_tensors = {
             "residual": not is_residual_scattered_for_sp(
                 self.vllm_config, num_input_tokens
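As a reference for the new dispatch flow, the sketch below mirrors the lookup order in `CudagraphDispatcher.dispatch`: an exact match against the FULL keys, then the relaxed key against FULL, then the relaxed key against PIECEWISE, and finally no cudagraph. The `BatchDescriptor` here is a simplified local copy, the capture sizes (64/128 tokens) are invented, and cascade attention and LoRA are ignored; this is an illustration of the dispatch logic, not the dispatcher implementation.

```python
# Illustrative-only sketch of the dispatch lookup order; not the real dispatcher.
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False
    has_lora: bool = False

    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
        # Drop the request count and uniformity so the key can match a
        # mixed prefill-decode (PIECEWISE) graph.
        return BatchDescriptor(self.num_tokens, None, False, self.has_lora)


FULL = "FULL"
PIECEWISE = "PIECEWISE"
NONE = "NONE"

# Made-up capture keys: FULL graphs for uniform decode at 64/128 tokens,
# PIECEWISE graphs for any 128-token batch.
keys = {
    FULL: {BatchDescriptor(64, 64, True), BatchDescriptor(128, 128, True)},
    PIECEWISE: {BatchDescriptor(128)},
}


def dispatch(desc: BatchDescriptor) -> tuple[str, BatchDescriptor]:
    relaxed = desc.relax_for_mixed_batch_cudagraphs()
    if desc in keys[FULL]:
        return FULL, desc
    if relaxed in keys[FULL]:
        return FULL, relaxed
    if relaxed in keys[PIECEWISE]:
        return PIECEWISE, relaxed
    return NONE, BatchDescriptor(desc.num_tokens)


# A uniform decode batch of 128 tokens hits a FULL graph directly.
assert dispatch(BatchDescriptor(128, 128, True))[0] == FULL
# A mixed 128-token batch only matches a PIECEWISE graph after relaxation.
assert dispatch(BatchDescriptor(128, 5, False))[0] == PIECEWISE
# A batch with no matching key runs without cudagraphs.
assert dispatch(BatchDescriptor(256, 8, False))[0] == NONE
```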