[V1][CUDA Graph] Fix attention metadata tensor sizes for padded batches #24002
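
In this change, `CudagraphDispatcher` stores its dispatch keys in dictionaries keyed by `(num_tokens, uniform_decode)` and records, per captured graph, the number of requests (`num_reqs`) the graph's attention metadata was sized for. Dispatch then selects a compatible captured graph instead of requiring an exact `BatchDescriptor` match, so padded batches no longer run against undersized attention metadata tensors.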
The diff is against `vllm/v1/cudagraph_dispatcher.py`. The first hunk adds the `ceil` import and a `CUDAGraphKey` type alias:

```diff
@@ -1,9 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from math import ceil
 from typing import Optional

+from typing_extensions import TypeAlias
+
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.forward_context import BatchDescriptor
 from vllm.logger import init_logger

 logger = init_logger(__name__)
+CUDAGraphKey: TypeAlias = tuple[int, bool]


 class CudagraphDispatcher:
```
The dispatch-key storage changes from a set of `BatchDescriptor`s to a dict keyed by `(num_tokens, uniform_decode)`:

```diff
@@ -31,9 +38,11 @@ def __init__(self, vllm_config: VllmConfig):
         self.cudagraph_mode = self.compilation_config.cudagraph_mode

         # Dict to store valid cudagraph dispatching keys.
-        self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
-            CUDAGraphMode.PIECEWISE: set(),
-            CUDAGraphMode.FULL: set(),
+        self.cudagraph_keys: dict[
+            CUDAGraphMode, dict[CUDAGraphKey, BatchDescriptor]
+        ] = {
+            CUDAGraphMode.PIECEWISE: {},
+            CUDAGraphMode.FULL: {},
         }

         not_use_piecewise_compilation = (
```
`add_cudagraph_key` now indexes each descriptor under that key:

```diff
@@ -61,7 +70,8 @@ def add_cudagraph_key(
         assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
             f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
         )
-        self.cudagraph_keys[runtime_mode].add(batch_descriptor)
+        key = (batch_descriptor.num_tokens, batch_descriptor.uniform_decode)
+        self.cudagraph_keys[runtime_mode][key] = batch_descriptor

     def initialize_cudagraph_keys(
         self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
```
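A minimal standalone sketch of the keyed-lookup idea behind the new `dict[CUDAGraphKey, BatchDescriptor]` storage (illustrative only, not vLLM code; `Descriptor` and `keys` are made-up stand-ins):

```python
from typing import NamedTuple, Optional

# Hypothetical stand-in for vLLM's BatchDescriptor, for illustration only.
class Descriptor(NamedTuple):
    num_tokens: int
    uniform_decode: bool
    num_reqs: Optional[int] = None

# Keys are (num_tokens, uniform_decode); values keep the full descriptor,
# including num_reqs, which a plain set of descriptors could not retrieve
# without an exact match on every field.
keys: dict[tuple[int, bool], Descriptor] = {}

d = Descriptor(num_tokens=512, uniform_decode=False, num_reqs=256)
keys[(d.num_tokens, d.uniform_decode)] = d

# Lookup ignores num_reqs, so a padded batch with a different request count
# still finds the captured entry and can then check compatibility.
print(keys.get((512, False)))
# Descriptor(num_tokens=512, uniform_decode=False, num_reqs=256)
```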
Mixed-mode keys now carry a `num_reqs` bound when captured as FULL graphs:

```diff
@@ -71,11 +81,24 @@ def initialize_cudagraph_keys(
         # Note: we create all valid keys for cudagraph here but do not
         # guarantee all keys would be used. For example, if we allow lazy
         # capturing in future PR, some keys may never be triggered.
-        if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
+        # Add mixed mode keys with proper num_reqs calculation
+        if (mixed_mode := cudagraph_mode.mixed_mode()) in (
+            CUDAGraphMode.PIECEWISE,
+            CUDAGraphMode.FULL,
+        ):
             for bs in self.compilation_config.cudagraph_capture_sizes:
+                num_reqs = (
+                    self.calculate_num_reqs_for_tokens(
+                        bs, uniform_decode_query_len, False
+                    )
+                    if mixed_mode == CUDAGraphMode.FULL
+                    else None
+                )
                 self.add_cudagraph_key(
-                    cudagraph_mode.mixed_mode(),
-                    BatchDescriptor(num_tokens=bs, uniform_decode=False),
+                    mixed_mode,
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=False, num_reqs=num_reqs
+                    ),
                 )

         # if decode cudagraph mode is FULL, and we don't already have mixed
```
Uniform-decode keys get the same treatment, and two helpers are added, `calculate_num_reqs_for_tokens` and `_is_compatible`:

```diff
@@ -94,12 +117,38 @@ def initialize_cudagraph_keys(
                 if x <= max_num_tokens and x >= uniform_decode_query_len
             ]
             for bs in cudagraph_capture_sizes_for_decode:
+                num_reqs = self.calculate_num_reqs_for_tokens(
+                    bs, uniform_decode_query_len, True
+                )
                 self.add_cudagraph_key(
                     CUDAGraphMode.FULL,
-                    BatchDescriptor(num_tokens=bs, uniform_decode=True),
+                    BatchDescriptor(
+                        num_tokens=bs, uniform_decode=True, num_reqs=num_reqs
+                    ),
                 )

         self.keys_initialized = True

+    def calculate_num_reqs_for_tokens(
+        self, num_tokens: int, uniform_decode_query_len: int, uniform_decode: bool
+    ) -> int:
+        max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+
+        if uniform_decode:
+            num_reqs = ceil(num_tokens / uniform_decode_query_len)
+            return min(num_reqs, max_num_seqs)
+        else:
+            return min(num_tokens, max_num_seqs)
+
+    def _is_compatible(
+        self, batch_descriptor: BatchDescriptor, candidate: BatchDescriptor
+    ) -> bool:
+        """Check if candidate cudagraph can handle the batch request."""
+        if candidate.num_reqs is None:
+            return True
+        assert batch_descriptor.num_reqs is not None
+        return candidate.num_reqs >= batch_descriptor.num_reqs
+
     def dispatch(
         self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
     ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
```
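A worked example of the `num_reqs` math, replicated standalone (the concrete numbers are illustrative, not from the PR): for a uniform-decode graph, every request contributes `uniform_decode_query_len` tokens, so the request count is the token count divided by that length, capped by the scheduler's `max_num_seqs`; for mixed batches the worst case is one token per request.

```python
from math import ceil

def calc_num_reqs(num_tokens: int, uniform_decode_query_len: int,
                  uniform_decode: bool, max_num_seqs: int) -> int:
    # Mirrors calculate_num_reqs_for_tokens above, with max_num_seqs
    # passed in explicitly so the sketch is self-contained.
    if uniform_decode:
        return min(ceil(num_tokens / uniform_decode_query_len), max_num_seqs)
    return min(num_tokens, max_num_seqs)

# Uniform decode with a query length of 4 (e.g. spec decode):
# 512 tokens -> ceil(512 / 4) = 128 requests.
assert calc_num_reqs(512, 4, True, max_num_seqs=256) == 128
# Mixed prefill/decode: worst case is one token per request, capped at 256.
assert calc_num_reqs(512, 4, False, max_num_seqs=256) == 256
```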
`dispatch` replaces the exact-match lookups with an ordered candidate search that also checks request-count compatibility (a review thread from LucasWilkinson on this hunk was marked resolved and is now outdated):

```diff
@@ -113,21 +162,19 @@ def dispatch(
         if not self.keys_initialized:
             return CUDAGraphMode.NONE, None

-        non_uniform_key = batch_descriptor.non_uniform
+        num_tokens, uniform_decode = (
+            batch_descriptor.num_tokens,
+            batch_descriptor.uniform_decode,
+        )
+
+        candidates = []
         # if a batch use cascade attention, bypass checking full cudagraphs
         if not use_cascade_attn:
-            # check if key exists for full cudagraph
-            if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, batch_descriptor
-
-            # otherwise, check if non-uniform key exists
-            if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
-                return CUDAGraphMode.FULL, non_uniform_key
+            candidates = [(CUDAGraphMode.FULL, (num_tokens, uniform_decode))]
+            if uniform_decode:
+                candidates.append((CUDAGraphMode.FULL, (num_tokens, False)))
+        candidates.append((CUDAGraphMode.PIECEWISE, (num_tokens, False)))

-        # also check if non-uniform key exists for more "general"
-        # piecewise cudagraph
-        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
-            return CUDAGraphMode.PIECEWISE, non_uniform_key
+        for mode, key in candidates:
+            candidate = self.cudagraph_keys[mode].get(key)
+            if candidate and self._is_compatible(batch_descriptor, candidate):
+                return mode, candidate

         # finally, just return no cudagraphs
         return CUDAGraphMode.NONE, None
```
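The candidate list encodes a fallback chain. A standalone sketch of the same idea (illustrative, not vLLM code; the tables and numbers are made up): try the exact FULL key, then, for uniform-decode batches, the non-uniform FULL key, then the PIECEWISE key, accepting a candidate only if it was captured for at least as many requests as the batch needs.

```python
# Hypothetical key tables mirroring the dispatcher's dicts; entries map
# (num_tokens, uniform_decode) -> num_reqs the graph was captured with
# (None meaning "no request-count constraint", as for piecewise graphs).
full = {(512, True): 128, (512, False): 256}
piecewise = {(512, False): None}

def dispatch(num_tokens: int, uniform_decode: bool, batch_num_reqs: int):
    candidates = [("FULL", full, (num_tokens, uniform_decode))]
    if uniform_decode:
        candidates.append(("FULL", full, (num_tokens, False)))
    candidates.append(("PIECEWISE", piecewise, (num_tokens, False)))
    for mode, table, key in candidates:
        if key in table:
            cap = table[key]
            if cap is None or cap >= batch_num_reqs:
                return mode, key
    return "NONE", None

print(dispatch(512, True, 128))   # ('FULL', (512, True))
print(dispatch(512, True, 200))   # uniform graph too small -> ('FULL', (512, False))
print(dispatch(512, False, 300))  # FULL capped at 256 -> ('PIECEWISE', (512, False))
```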
**Review comment:** Can we make this a dataclass, with utilities for easier and more understandable handling, as well as more docs?
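One possible reading of that suggestion, sketched under assumptions (the class name, field set, and helper names below are hypothetical, not part of the PR):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class CudagraphCandidate:
    """A captured cudagraph's shape: token count, decode uniformity, and the
    request count its attention metadata was sized for (None = no limit)."""
    num_tokens: int
    uniform_decode: bool
    num_reqs: Optional[int] = None

    @property
    def key(self) -> tuple[int, bool]:
        # The dispatch key used for dictionary lookup.
        return (self.num_tokens, self.uniform_decode)

    def can_handle(self, batch_num_reqs: int) -> bool:
        # Same check as _is_compatible, expressed as a method.
        return self.num_reqs is None or self.num_reqs >= batch_num_reqs
```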