From a7a16dfa830d3bbca6cb9a22c76c27fb18e4ba41 Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Wed, 2 Apr 2025 20:42:31 +0000 Subject: [PATCH 01/13] Add full cuda graph support in v1 Signed-off-by: Chanh Nguyen --- vllm/compilation/backends.py | 4 +- vllm/compilation/monitor.py | 4 +- vllm/config.py | 27 +++++++++----- vllm/v1/attention/backends/flash_attn.py | 12 +++--- vllm/v1/worker/gpu_model_runner.py | 47 +++++++++++++++++++++--- 5 files changed, 68 insertions(+), 26 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 45988c2e9b0d..d9164f2c1fcc 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -279,8 +279,8 @@ def call_module(self, target: torch.fx.node.Target, class VllmBackend: """The compilation backend for `torch.compile` with vLLM. - It is used for compilation level of `CompilationLevel.PIECEWISE`, - where we customize the compilation. + It is used for compilation level of `CompilationLevel.PIECEWISE` or + `CompilationLevel.FULL_GRAPH`, where we customize the compilation. The major work of this backend is to split the graph into piecewise graphs, and pass them to the piecewise backend. diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 786c7c1e1859..1334680d7386 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -17,7 +17,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): torch_compile_start_time = time.time() compilation_config: CompilationConfig = vllm_config.compilation_config - if compilation_config.level == CompilationLevel.PIECEWISE and \ + if compilation_config.level >= CompilationLevel.PIECEWISE and \ compilation_config.debug_dump_path: import depyf path = os.path.join(compilation_config.debug_dump_path, @@ -29,7 +29,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): def end_monitoring_torch_compile(vllm_config: VllmConfig): compilation_config: CompilationConfig = vllm_config.compilation_config - if compilation_config.level == CompilationLevel.PIECEWISE: + if compilation_config.level >= CompilationLevel.PIECEWISE: logger.info("torch.compile takes %.2f s in total", compilation_config.compilation_time) global context_manager diff --git a/vllm/config.py b/vllm/config.py index 2669d1a13b37..b7319605eb8f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3051,6 +3051,7 @@ class CompilationLevel: DYNAMO_AS_IS = 1 DYNAMO_ONCE = 2 PIECEWISE = 3 + FULL_GRAPH = 4 class CompilationConfig(BaseModel): @@ -3063,6 +3064,7 @@ class CompilationConfig(BaseModel): - 1: dynamo as is. - 2: dynamo once. - 3: piecewise compilation. + - 4: full compilation. - debug_dump_path: the path to dump the debug information. - cache_dir: the directory to store the compiled graph, to accelerate Inductor compilation. By default, it will use @@ -3074,6 +3076,7 @@ class CompilationConfig(BaseModel): We use string to avoid serialization issues when using compilation in a distributed setting. When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph). + When the compilation level is 4, the backend is used for the full graph. - custom_ops: fine-grained control over which custom ops to enable/disable. Use 'all' to enable all, 'none' to disable all. 
Also specify a list of custom op names to enable (prefixed with a '+'), @@ -3246,7 +3249,7 @@ def __repr__(self) -> str: @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" - if cli_value in ["0", "1", "2", "3"]: + if cli_value in ["0", "1", "2", "3", "4"]: return cls(level=int(cli_value)) # do not use `eval`, it is dangerous and can execute arbitrary code dict_value = ast.literal_eval(cli_value) @@ -3313,7 +3316,7 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: # TODO: pass user-specified backend to piecewise compilation # merge with the config use_inductor - assert self.level == CompilationLevel.PIECEWISE + assert self.level >= CompilationLevel.PIECEWISE from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) @@ -3368,13 +3371,16 @@ def init_with_cudagraph_sizes(self, self.max_capture_size] = self.max_capture_size def set_splitting_ops_for_v1(self): - # If default, override splitting ops for piecewise cudagraph on V1. # NOTE: this function needs to be called - if not self.splitting_ops: - self.splitting_ops = [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - ] + if not self.splitting_ops: + if self.level == CompilationLevel.PIECEWISE: + self.splitting_ops = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + elif self.level == CompilationLevel.FULL_GRAPH: + self.splitting_ops = [] + @dataclass @@ -3600,7 +3606,8 @@ def __post_init__(self): self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_noop = False - self.compilation_config.level = CompilationLevel.PIECEWISE + if self.compilation_config.level < CompilationLevel.PIECEWISE: + self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() self._set_cudagraph_sizes() @@ -3773,7 +3780,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): logger.debug("disabled custom ops: %s", vllm_config.compilation_config.disabled_custom_ops) if check_compile and \ - vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + vllm_config.compilation_config.level >= CompilationLevel.PIECEWISE \ and compilation_counter.num_models_seen == num_models_seen: # If the model supports compilation, # compilation_counter.num_models_seen should be increased diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 92e4ffd0371a..04112be40069 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -109,14 +109,13 @@ def reorder_batch(self, input_batch: "InputBatch", def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - self.runner.device, non_blocking=True) - seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, - non_blocking=True) + query_start_loc = self.runner.query_start_loc[:num_reqs + 1] + seq_lens = self.runner.seq_lens[:num_reqs] + + block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True).long() + slot_mapping = self.runner.slot_mapping[:num_actual_tokens] use_cascade = common_prefix_len > 
0 if use_cascade: @@ -152,7 +151,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, ) return attn_metadata - class FlashAttentionImpl(AttentionImpl): def __init__( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 513806332efe..045f81ad2644 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -182,7 +182,7 @@ def __init__( ) self.use_cuda_graph = (self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE + >= CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. @@ -203,6 +203,19 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + self.query_start_loc = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device=self.device) + self.seq_lens = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) + self.slot_mapping = torch.zeros( + self.max_num_tokens, + # CPU slot_mapping is int32, but + # this one must be int64 + dtype=torch.int64, + device=self.device) + # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -581,6 +594,19 @@ def _prepare_inputs( scheduler_output.num_common_prefix_blocks, ) + self.query_start_loc[:num_reqs + 1].copy_( + self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + self.slot_mapping[:total_num_scheduled_tokens].copy_( + self.slot_mapping_cpu[:total_num_scheduled_tokens], + non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + self.seq_lens[num_reqs:].fill_(0) + self.query_start_loc[num_reqs + 1:].fill_(-1) + attn_metadata = self.attn_metadata_builder.build( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, @@ -1376,6 +1402,7 @@ def _get_prompt_logprobs_dict( def _dummy_run( self, num_tokens: int, + initialize_attention_metadata: bool = False, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs @@ -1392,6 +1419,16 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + if initialize_attention_metadata: + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_tokens, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + common_prefix_len=0, + ) + else: + attn_metadata = None + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model @@ -1420,7 +1457,7 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, + with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_tokens): hidden_states = model( @@ -1587,7 +1624,7 @@ def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( "Skipping CUDA graph capture. Please add " - "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) + "-O %s or -O %s to use CUDA graphs.", CompilationLevel.PIECEWISE, CompilationLevel.FULL_GRAPH) return start_time = time.perf_counter() @@ -1600,8 +1637,8 @@ def capture_model(self) -> None: for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. 
cudagraph_num_of_warmups): - self._dummy_run(num_tokens) - self._dummy_run(num_tokens) + self._dummy_run(num_tokens, initialize_attention_metadata=True) + self._dummy_run(num_tokens, initialize_attention_metadata=True) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] From 6aa65db7925f63f30aceabf67ed0a33c29cf3d31 Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Mon, 7 Apr 2025 20:57:02 +0000 Subject: [PATCH 02/13] Linting and such Signed-off-by: Chanh Nguyen --- requirements/test.txt | 22 +++++++++++++++++++--- vllm/config.py | 5 ++--- vllm/v1/attention/backends/flash_attn.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 6 ++++-- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index 236b8be32805..8dde94f313c8 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -23,6 +23,10 @@ anyio==4.6.2.post1 # via httpx argcomplete==3.5.1 # via datamodel-code-generator +async-timeout==5.0.1 + # via + # aiohttp + # redis attrs==24.2.0 # via # aiohttp @@ -117,6 +121,10 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval +exceptiongroup==1.2.2 + # via + # anyio + # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -556,9 +564,7 @@ sentence-transformers==3.2.1 sentencepiece==0.2.0 # via mistral-common setuptools==75.8.0 - # via - # pytablewriter - # torch + # via pytablewriter shellingham==1.5.4 # via typer six==1.16.0 @@ -605,6 +611,12 @@ timm==1.0.11 # via -r requirements/test.in tokenizers==0.21.0 # via transformers +toml==0.10.2 + # via datamodel-code-generator +tomli==2.2.1 + # via + # black + # pytest torch==2.6.0 # via # -r requirements/test.in @@ -670,12 +682,16 @@ typer==0.15.2 # via fastsafetensors typing-extensions==4.12.2 # via + # anyio + # black # huggingface-hub # librosa # mistral-common + # multidict # pqdm # pydantic # pydantic-core + # rich # torch # typer tzdata==2024.2 diff --git a/vllm/config.py b/vllm/config.py index b7319605eb8f..155879de8d6a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3372,15 +3372,14 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called - if not self.splitting_ops: - if self.level == CompilationLevel.PIECEWISE: + if not self.splitting_ops: + if self.level == CompilationLevel.PIECEWISE: self.splitting_ops = [ "vllm.unified_attention", "vllm.unified_attention_with_output", ] elif self.level == CompilationLevel.FULL_GRAPH: self.splitting_ops = [] - @dataclass diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 04112be40069..5a0668970792 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -109,9 +109,8 @@ def reorder_batch(self, input_batch: "InputBatch", def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() - query_start_loc = self.runner.query_start_loc[:num_reqs + 1] + query_start_loc = self.runner.query_start_loc[:num_reqs + 1] seq_lens = self.runner.seq_lens[:num_reqs] - block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) @@ -151,6 +150,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, ) return attn_metadata + class FlashAttentionImpl(AttentionImpl): def __init__( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 045f81ad2644..11182dae2764 100644 --- 
a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1624,7 +1624,8 @@ def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( "Skipping CUDA graph capture. Please add " - "-O %s or -O %s to use CUDA graphs.", CompilationLevel.PIECEWISE, CompilationLevel.FULL_GRAPH) + "-O %s or -O %s to use CUDA graphs.", + CompilationLevel.PIECEWISE, CompilationLevel.FULL_GRAPH) return start_time = time.perf_counter() @@ -1637,7 +1638,8 @@ def capture_model(self) -> None: for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens, initialize_attention_metadata=True) + self._dummy_run(num_tokens, + initialize_attention_metadata=True) self._dummy_run(num_tokens, initialize_attention_metadata=True) end_time = time.perf_counter() From c338c99bab846fec12d71c16477f16416d22be7a Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Tue, 8 Apr 2025 06:26:59 +0000 Subject: [PATCH 03/13] Add unit test Signed-off-by: Chanh Nguyen --- .../compile/piecewise/test_full_cudagraph.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/compile/piecewise/test_full_cudagraph.py diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py new file mode 100644 index 000000000000..e0f1cc20c27c --- /dev/null +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig, CompilationLevel + + +def run_model(compilation_config: CompilationConfig): + prompts = ["Hello, my name is"] + sampling_params = SamplingParams(temperature=0.0, max_tokens=20) + + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + compilation_config=compilation_config) + + return llm.generate(prompts, sampling_params) + + +def test_full_cudagraph(monkeypatch): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + m.setenv("VLLM_FLASH_ATTN_VERSION", "3") + + full_cudagraph_responses = run_model( + compilation_config=CompilationConfig( + level=CompilationLevel.FULL_GRAPH, + use_cudagraph=True, + )) + + piecewise_responses = run_model(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=False, + )) + + assert full_cudagraph_responses[0].outputs[ + 0].text == piecewise_responses[0].outputs[0].text From 15edc9bf26709f090524afb0aa6b62a77e4f8b3d Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Tue, 8 Apr 2025 07:38:37 +0000 Subject: [PATCH 04/13] Add unit test Signed-off-by: Chanh Nguyen --- tests/compile/piecewise/test_full_cudagraph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index e0f1cc20c27c..549fa6604532 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -26,7 +26,7 @@ def test_full_cudagraph(monkeypatch): piecewise_responses = run_model(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, - use_cudagraph=False, + use_cudagraph=True, )) assert full_cudagraph_responses[0].outputs[ From 4700afd50b8a2bf197ff925c1401b0e3352c82d4 Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Fri, 11 Apr 2025 22:16:48 +0000 Subject: [PATCH 05/13] Took suggestions Signed-off-by: Chanh Nguyen --- .../compile/piecewise/test_full_cudagraph.py | 101 ++++++++++++++---- vllm/config.py | 2 +- 
vllm/v1/worker/gpu_model_runner.py | 3 +- 3 files changed, 83 insertions(+), 23 deletions(-) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 549fa6604532..9ce92060b17d 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -1,33 +1,92 @@ # SPDX-License-Identifier: Apache-2.0 +import contextlib +import os + +import pytest + from vllm import LLM, SamplingParams from vllm.config import CompilationConfig, CompilationLevel +MODEL = "Qwen/Qwen2-1.5B-Instruct" -def run_model(compilation_config: CompilationConfig): - prompts = ["Hello, my name is"] - sampling_params = SamplingParams(temperature=0.0, max_tokens=20) - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - compilation_config=compilation_config) +@contextlib.contextmanager +def temporary_environ(env_vars): + """ + Temporarily set environment variables and restore them afterward. + We have to do this vs monkeypatch because monkeypatch doesn't work + with "module" scoped fixtures. + """ + original_env = {k: os.environ.get(k) for k in env_vars} + try: + os.environ.update(env_vars) + yield + finally: + for k, v in original_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v - return llm.generate(prompts, sampling_params) +@pytest.fixture(scope="module") +def full_cudagraph_llm(): + + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": "3" + }): + return LLM(model=MODEL, + gpu_memory_utilization=0.2, + compilation_config=CompilationConfig( + level=CompilationLevel.FULL_GRAPH, + use_cudagraph=True, + )) + + +@pytest.fixture(scope="module") +def piecewise_llm(): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": "3" + }): + return LLM(model=MODEL, + gpu_memory_utilization=0.5, + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + )) + + +def generate_text(llm: LLM, batch_size: int, max_tokens: int): + prompts = ["Hello, my name is"] * batch_size + sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + + return llm.generate(prompts, sampling_params) -def test_full_cudagraph(monkeypatch): - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - m.setenv("VLLM_FLASH_ATTN_VERSION", "3") - full_cudagraph_responses = run_model( - compilation_config=CompilationConfig( - level=CompilationLevel.FULL_GRAPH, - use_cudagraph=True, - )) +@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10), + (16, 10), (25, 10), + (32, 10), (45, 10), + (64, 10), (25, 5), + (25, 20)]) +def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, + piecewise_llm): + """ + Load full cudagraph model and piecewise model once, and at the same time to + reuse them across various test cases. - piecewise_responses = run_model(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - )) + Test various batch sizes and max_tokens to ensure that the full cudagraph + compilation works for padded cases too. 
+ """ + piecewise_responses = generate_text(piecewise_llm, + batch_size=batch_size, + max_tokens=max_tokens) + full_cudagraph_responses = generate_text(full_cudagraph_llm, + batch_size=batch_size, + max_tokens=max_tokens) - assert full_cudagraph_responses[0].outputs[ - 0].text == piecewise_responses[0].outputs[0].text + # Check that all responses are the same + for i in range(len(piecewise_responses)): + assert piecewise_responses[i].outputs[ + 0].text == full_cudagraph_responses[i].outputs[0].text diff --git a/vllm/config.py b/vllm/config.py index ece6a43d6504..ade938ad128e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3188,7 +3188,7 @@ class CompilationConfig(BaseModel): We use string to avoid serialization issues when using compilation in a distributed setting. When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph). - When the compilation level is 4, the backend is used for the full graph. + When the compilation level is 4, the backend is used for the full graph. This improves performance for smaller models. - custom_ops: fine-grained control over which custom ops to enable/disable. Use 'all' to enable all, 'none' to disable all. Also specify a list of custom op names to enable (prefixed with a '+'), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b1835d9a2544..3cbab4e54c60 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1643,7 +1643,8 @@ def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( "Skipping CUDA graph capture. Please add " - "-O %s or -O %s to use CUDA graphs.", + "-O %s for piecewise CUDA graphs (attention is skipped) or " + "-O %s for full CUDA graphs (attention included).", CompilationLevel.PIECEWISE, CompilationLevel.FULL_GRAPH) return From 353ea66d80e6c7a2ed07d8cd907fbb1a54e14d2f Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Tue, 15 Apr 2025 00:25:32 +0000 Subject: [PATCH 06/13] Revert requirements/test.txt Signed-off-by: Chanh Nguyen --- requirements/test.txt | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index bad69095531f..476b4a2cc0ec 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -23,10 +23,6 @@ anyio==4.6.2.post1 # via httpx argcomplete==3.5.1 # via datamodel-code-generator -async-timeout==5.0.1 - # via - # aiohttp - # redis attrs==24.2.0 # via # aiohttp @@ -123,10 +119,6 @@ encodec==0.1.1 # via vocos evaluate==0.4.3 # via lm-eval -exceptiongroup==1.2.2 - # via - # anyio - # pytest fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -571,7 +563,9 @@ sentence-transformers==3.2.1 sentencepiece==0.2.0 # via mistral-common setuptools==75.8.0 - # via pytablewriter + # via + # pytablewriter + # torch shellingham==1.5.4 # via typer six==1.16.0 @@ -618,12 +612,6 @@ timm==1.0.11 # via -r requirements/test.in tokenizers==0.21.0 # via transformers -toml==0.10.2 - # via datamodel-code-generator -tomli==2.2.1 - # via - # black - # pytest torch==2.6.0 # via # -r requirements/test.in @@ -689,16 +677,12 @@ typer==0.15.2 # via fastsafetensors typing-extensions==4.12.2 # via - # anyio - # black # huggingface-hub # librosa # mistral-common - # multidict # pqdm # pydantic # pydantic-core - # rich # torch # typer tzdata==2024.2 From 1dcdb375c1148892907bc1b1dbfcef6787a12f2e Mon Sep 17 
00:00:00 2001 From: Chanh Nguyen Date: Fri, 25 Apr 2025 10:17:49 +0000 Subject: [PATCH 07/13] Responded to comments Signed-off-by: Chanh Nguyen --- docs/source/design/v1/torch_compile.md | 6 ++++++ vllm/config.py | 24 +++++++++++++++--------- vllm/v1/worker/gpu_model_runner.py | 25 +++++++++++++++++-------- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index 57dba680b97c..1c6f2a674cae 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -137,3 +137,9 @@ By default, vLLM will try to determine a set of sizes to capture cudagraph. You `VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"` Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture. + +### Full Cudagraph capture + +It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. + +Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled. diff --git a/vllm/config.py b/vllm/config.py index 0c91a0e4202a..6e207c7617a6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3553,14 +3553,13 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called - if not self.splitting_ops: - if self.level == CompilationLevel.PIECEWISE: - self.splitting_ops = [ - "vllm.unified_attention", - "vllm.unified_attention_with_output", - ] - elif self.level == CompilationLevel.FULL_GRAPH: - self.splitting_ops = [] + if self.level == CompilationLevel.PIECEWISE: + self.splitting_ops = [ + "vllm.unified_attention", + "vllm.unified_attention_with_output", + ] + else: + assert not self.splitting_ops @dataclass @@ -3787,7 +3786,8 @@ def __post_init__(self): self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_noop = False - if self.compilation_config.level < CompilationLevel.PIECEWISE: + # Default to PIECEWISE except for when FULL_GRAPH is desired. + if self.compilation_config.level != CompilationLevel.FULL_GRAPH: self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() @@ -3810,6 +3810,12 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION + if self.compilation_config.level == CompilationLevel.FULL_GRAPH and \ + not self.model_config.disable_cascade_attn: + logger.warning_once( + "CompilationLevel.FULL_GRAPH (-O4) is not supported with " + "cascade attention. 
Disabling cascade attention.") + self.model_config.disable_cascade_attn = True if self.model_config and self.model_config.use_mla and \ not (current_platform.is_cuda() or current_platform.is_rocm()): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ae70229ad91a..87d737d08a55 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -43,6 +43,7 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -135,6 +136,15 @@ def __init__( raise NotImplementedError( "Non-Attention backend is not supported by V1 GPUModelRunner.") + if self.vllm_config.compilation_config.level == CompilationLevel.FULL_GRAPH: # noqa: E501 + attn_backend_name = self.attn_backend.__name__ + flash_attn_version = get_flash_attn_version() + assert attn_backend_name == "FlashAttentionBackend" and \ + flash_attn_version == 3, \ + (f"CompilationLevel.FULL_GRAPH (-O4) is only supported with " + f"FA3. Current attention backend is {attn_backend_name} and " + f"FlashAttention version is {flash_attn_version}.") + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn @@ -1417,7 +1427,7 @@ def _get_prompt_logprobs_dict( def _dummy_run( self, num_tokens: int, - initialize_attention_metadata: bool = False, + skip_attn: bool = False, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs @@ -1434,15 +1444,15 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) - if initialize_attention_metadata: + if skip_attn: + attn_metadata = None + else: attn_metadata = self.attn_metadata_builder.build( num_reqs=num_tokens, num_actual_tokens=num_tokens, max_query_len=num_tokens, common_prefix_len=0, ) - else: - attn_metadata = None with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): @@ -1625,7 +1635,7 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - hidden_states = self._dummy_run(self.max_num_tokens) + hidden_states = self._dummy_run(self.max_num_tokens, skip_attn=True) if get_pp_group().is_last_rank: sampler_output = self._dummy_sampler_run(hidden_states) else: @@ -1654,9 +1664,8 @@ def capture_model(self) -> None: for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. 
cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - initialize_attention_metadata=True) - self._dummy_run(num_tokens, initialize_attention_metadata=True) + self._dummy_run(num_tokens) + self._dummy_run(num_tokens) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] From 7e821e0830d671402d84a377d0dbdcce3eba7a93 Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Wed, 30 Apr 2025 10:33:42 +0000 Subject: [PATCH 08/13] Updated UI Signed-off-by: Chanh Nguyen --- .../compile/piecewise/test_full_cudagraph.py | 39 ++++++++++--------- vllm/compilation/backends.py | 4 +- vllm/config.py | 26 +++++++------ vllm/v1/worker/gpu_model_runner.py | 19 +++++---- 4 files changed, 47 insertions(+), 41 deletions(-) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 9ce92060b17d..112983aa91e7 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -5,7 +5,7 @@ import pytest from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig MODEL = "Qwen/Qwen2-1.5B-Instruct" @@ -31,17 +31,13 @@ def temporary_environ(env_vars): @pytest.fixture(scope="module") def full_cudagraph_llm(): - with temporary_environ({ "VLLM_USE_V1": "1", "VLLM_FLASH_ATTN_VERSION": "3" }): return LLM(model=MODEL, gpu_memory_utilization=0.2, - compilation_config=CompilationConfig( - level=CompilationLevel.FULL_GRAPH, - use_cudagraph=True, - )) + compilation_config=CompilationConfig(full_cuda_graph=True)) @pytest.fixture(scope="module") @@ -52,26 +48,18 @@ def piecewise_llm(): }): return LLM(model=MODEL, gpu_memory_utilization=0.5, - compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, - use_cudagraph=True, - )) + compilation_config=CompilationConfig()) def generate_text(llm: LLM, batch_size: int, max_tokens: int): - prompts = ["Hello, my name is"] * batch_size + prompts = ["I pledge allegiance to"] * batch_size sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) return llm.generate(prompts, sampling_params) -@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10), - (16, 10), (25, 10), - (32, 10), (45, 10), - (64, 10), (25, 5), - (25, 20)]) -def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, - piecewise_llm): +@pytest.mark.parametrize("batch_size", [1, 7, 16, 25, 32, 45, 64]) +def test_full_cudagraph(batch_size, full_cudagraph_llm, piecewise_llm): """ Load full cudagraph model and piecewise model once, and at the same time to reuse them across various test cases. @@ -79,6 +67,11 @@ def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, Test various batch sizes and max_tokens to ensure that the full cudagraph compilation works for padded cases too. 
""" + + # For batch size > 1, PyTorch is not always deterministic so keep the + # output short and use a predictable prompt such as "I pledge allegiance to" + # See https://github.com/vllm-project/vllm/issues/5898 + max_tokens = 5 piecewise_responses = generate_text(piecewise_llm, batch_size=batch_size, max_tokens=max_tokens) @@ -90,3 +83,13 @@ def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, for i in range(len(piecewise_responses)): assert piecewise_responses[i].outputs[ 0].text == full_cudagraph_responses[i].outputs[0].text + + +def test_full_cudagraph_with_invalid_backend(): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": + "2" #FA2 not supported with full_cuda_graph + }), pytest.raises(RuntimeError): + LLM(model=MODEL, + compilation_config=CompilationConfig(full_cuda_graph=True)) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4c24f393de78..7012131d0532 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -283,8 +283,8 @@ def call_module(self, target: torch.fx.node.Target, class VllmBackend: """The compilation backend for `torch.compile` with vLLM. - It is used for compilation level of `CompilationLevel.PIECEWISE` or - `CompilationLevel.FULL_GRAPH`, where we customize the compilation. + It is used for compilation level of `CompilationLevel.PIECEWISE`, + where we customize the compilation. The major work of this backend is to split the graph into piecewise graphs, and pass them to the piecewise backend. diff --git a/vllm/config.py b/vllm/config.py index a61a66276273..a0014223d54e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3370,7 +3370,6 @@ class CompilationLevel: DYNAMO_AS_IS = 1 DYNAMO_ONCE = 2 PIECEWISE = 3 - FULL_GRAPH = 4 class CompilationConfig(BaseModel): @@ -3395,7 +3394,6 @@ class CompilationConfig(BaseModel): We use string to avoid serialization issues when using compilation in a distributed setting. When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph). - When the compilation level is 4, the backend is used for the full graph. This improves performance for smaller models. - custom_ops: fine-grained control over which custom ops to enable/disable. Use 'all' to enable all, 'none' to disable all. Also specify a list of custom op names to enable (prefixed with a '+'), @@ -3428,6 +3426,10 @@ class CompilationConfig(BaseModel): are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an internally managed buffer. Default is False. + - full_cuda_graph: whether to use a full cuda graph for the entire forward + pass rather than splitting certain operations such as attention into subgraphs. + Thus this flag cannot be used together with splitting_ops. This may provide + performance benefits for smaller models. - Inductor compilation: - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. 
@@ -3472,6 +3474,7 @@ class CompilationConfig(BaseModel): cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[list[int]] = None cudagraph_copy_inputs: bool = False + full_cuda_graph: bool = False class PassConfig(BaseModel): """ @@ -3695,13 +3698,16 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called - if self.level == CompilationLevel.PIECEWISE: - self.splitting_ops = [ + if self.splitting_ops and self.full_cuda_graph: + raise ValueError("full_cuda_graph cannot be used together with " + "splitting_ops, as Full CUDA graph will override " + f"the splitting_ops: {self.splitting_ops}") + + if not self.splitting_ops: + self.splitting_ops = [] if self.full_cuda_graph else [ "vllm.unified_attention", "vllm.unified_attention_with_output", ] - else: - assert not self.splitting_ops @dataclass @@ -3940,9 +3946,7 @@ def __post_init__(self): self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_noop = False - # Default to PIECEWISE except for when FULL_GRAPH is desired. - if self.compilation_config.level != CompilationLevel.FULL_GRAPH: - self.compilation_config.level = CompilationLevel.PIECEWISE + self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() if self.parallel_config is not None and \ @@ -3976,10 +3980,10 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - if self.compilation_config.level == CompilationLevel.FULL_GRAPH and \ + if self.compilation_config.full_cuda_graph and \ not self.model_config.disable_cascade_attn: logger.warning_once( - "CompilationLevel.FULL_GRAPH (-O4) is not supported with " + "full_cuda_graph is not supported with " "cascade attention. Disabling cascade attention.") self.model_config.disable_cascade_attn = True diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fced13be565a..8dbbd4943e24 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,6 +12,7 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, @@ -46,7 +47,6 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -139,14 +139,15 @@ def __init__( raise NotImplementedError( "Non-Attention backend is not supported by V1 GPUModelRunner.") - if self.vllm_config.compilation_config.level == CompilationLevel.FULL_GRAPH: # noqa: E501 + if self.vllm_config.compilation_config.full_cuda_graph: attn_backend_name = self.attn_backend.__name__ flash_attn_version = get_flash_attn_version() - assert attn_backend_name == "FlashAttentionBackend" and \ - flash_attn_version == 3, \ - (f"CompilationLevel.FULL_GRAPH (-O4) is only supported with " - f"FA3. 
Current attention backend is {attn_backend_name} and " - f"FlashAttention version is {flash_attn_version}.") + if attn_backend_name != "FlashAttentionBackend" or \ + flash_attn_version != 3: + raise ValueError( + f"full_cuda_graph is only supported with " + f"FA3. Current attention backend is {attn_backend_name}, " + f"FlashAttention version is {flash_attn_version}.") self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) @@ -1710,9 +1711,7 @@ def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( "Skipping CUDA graph capture. Please add " - "-O %s for piecewise CUDA graphs (attention is skipped) or " - "-O %s for full CUDA graphs (attention included).", - CompilationLevel.PIECEWISE, CompilationLevel.FULL_GRAPH) + "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) return start_time = time.perf_counter() From 0d1a7967bc96d477b5c3186acc87bb1634c39ffe Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Wed, 30 Apr 2025 10:41:00 +0000 Subject: [PATCH 09/13] minor fixes Signed-off-by: Chanh Nguyen --- docs/source/design/v1/torch_compile.md | 2 +- vllm/compilation/monitor.py | 4 ++-- vllm/config.py | 7 +++---- vllm/v1/worker/gpu_model_runner.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index 9e4825932dd3..f024c96781fe 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -140,6 +140,6 @@ Then it will only capture cudagraph for the specified sizes. It can be useful to ### Full Cudagraph capture -It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. +It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config {"'full_cuda_graph': True"}` Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled. diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 1334680d7386..786c7c1e1859 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -17,7 +17,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): torch_compile_start_time = time.time() compilation_config: CompilationConfig = vllm_config.compilation_config - if compilation_config.level >= CompilationLevel.PIECEWISE and \ + if compilation_config.level == CompilationLevel.PIECEWISE and \ compilation_config.debug_dump_path: import depyf path = os.path.join(compilation_config.debug_dump_path, @@ -29,7 +29,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): def end_monitoring_torch_compile(vllm_config: VllmConfig): compilation_config: CompilationConfig = vllm_config.compilation_config - if compilation_config.level >= CompilationLevel.PIECEWISE: + if compilation_config.level == CompilationLevel.PIECEWISE: logger.info("torch.compile takes %.2f s in total", compilation_config.compilation_time) global context_manager diff --git a/vllm/config.py b/vllm/config.py index a0014223d54e..bf4e52a61dc4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3382,7 +3382,6 @@ class CompilationConfig(BaseModel): - 1: dynamo as is. - 2: dynamo once. - 3: piecewise compilation. - - 4: full compilation. 
- debug_dump_path: the path to dump the debug information. - cache_dir: the directory to store the compiled graph, to accelerate Inductor compilation. By default, it will use @@ -3575,7 +3574,7 @@ def __repr__(self) -> str: @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" - if cli_value in ["0", "1", "2", "3", "4"]: + if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) # do not use `eval`, it is dangerous and can execute arbitrary code dict_value = ast.literal_eval(cli_value) @@ -3642,7 +3641,7 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: # TODO: pass user-specified backend to piecewise compilation # merge with the config use_inductor - assert self.level >= CompilationLevel.PIECEWISE + assert self.level == CompilationLevel.PIECEWISE from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) @@ -4167,7 +4166,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): logger.debug("disabled custom ops: %s", vllm_config.compilation_config.disabled_custom_ops) if check_compile and \ - vllm_config.compilation_config.level >= CompilationLevel.PIECEWISE \ + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ and compilation_counter.num_models_seen == num_models_seen: # If the model supports compilation, # compilation_counter.num_models_seen should be increased diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8dbbd4943e24..6742ca2d9b19 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -205,7 +205,7 @@ def __init__( ) self.use_cuda_graph = (self.vllm_config.compilation_config.level - >= CompilationLevel.PIECEWISE + == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. From 166e6a6b729adb9cfaf55ba1d9b352eb4d6c53b2 Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Wed, 30 Apr 2025 17:54:26 +0000 Subject: [PATCH 10/13] Make test deterministic Signed-off-by: Chanh Nguyen --- .../compile/piecewise/test_full_cudagraph.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 112983aa91e7..a71a40cda73e 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -52,14 +52,21 @@ def piecewise_llm(): def generate_text(llm: LLM, batch_size: int, max_tokens: int): - prompts = ["I pledge allegiance to"] * batch_size - sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + prompts = ["Hi my name is"] * batch_size + sampling_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens, + top_p=0.95) return llm.generate(prompts, sampling_params) -@pytest.mark.parametrize("batch_size", [1, 7, 16, 25, 32, 45, 64]) -def test_full_cudagraph(batch_size, full_cudagraph_llm, piecewise_llm): +@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10), + (16, 10), (25, 10), + (32, 10), (45, 10), + (64, 10), (8, 5), + (8, 20), (8, 200)]) +def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, + piecewise_llm): """ Load full cudagraph model and piecewise model once, and at the same time to reuse them across various test cases. 
@@ -67,11 +74,6 @@ def test_full_cudagraph(batch_size, full_cudagraph_llm, piecewise_llm): Test various batch sizes and max_tokens to ensure that the full cudagraph compilation works for padded cases too. """ - - # For batch size > 1, PyTorch is not always deterministic so keep the - # output short and use a predictable prompt such as "I pledge allegiance to" - # See https://github.com/vllm-project/vllm/issues/5898 - max_tokens = 5 piecewise_responses = generate_text(piecewise_llm, batch_size=batch_size, max_tokens=max_tokens) From 6b523ac1e152e610393250ca9d2ae87f22259122 Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Wed, 7 May 2025 05:18:25 +0000 Subject: [PATCH 11/13] Updated per comments Signed-off-by: Chanh Nguyen --- docs/source/design/v1/torch_compile.md | 2 +- vllm/v1/worker/gpu_model_runner.py | 27 +++++++++++++++----------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index f024c96781fe..4d8ce0fd9227 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -140,6 +140,6 @@ Then it will only capture cudagraph for the specified sizes. It can be useful to ### Full Cudagraph capture -It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config {"'full_cuda_graph': True"}` +It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config "{'full_cuda_graph': True}"` Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a13c80196aef..bd8c87fd9efc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -236,12 +236,9 @@ def __init__( self.seq_lens = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=self.device) - self.slot_mapping = torch.zeros( - self.max_num_tokens, - # CPU slot_mapping is int32, but - # this one must be int64 - dtype=torch.int64, - device=self.device) + self.slot_mapping = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) # None in the first PP rank. The rest are set after load_model. 
self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -295,7 +292,7 @@ def __init__( pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, + dtype=torch.int64, device="cpu", pin_memory=self.pin_memory) self.slot_mapping_np = self.slot_mapping_cpu.numpy() @@ -1514,7 +1511,7 @@ def _get_prompt_logprobs_dict( def _dummy_run( self, num_tokens: int, - skip_attn: bool = False, + skip_attn: bool = True, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs @@ -1534,11 +1531,18 @@ def _dummy_run( if skip_attn: attn_metadata = None else: + query_start_loc = self.query_start_loc[:num_reqs + 1] + seq_lens = self.seq_lens[:num_reqs] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens) + attn_metadata = self.attn_metadata_builder.build( num_reqs=num_tokens, num_actual_tokens=num_tokens, max_query_len=num_tokens, common_prefix_len=0, + common_attn_metadata=common_attn_metadata, ) with self.maybe_dummy_run_with_lora(self.lora_config, @@ -1731,7 +1735,7 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - hidden_states = self._dummy_run(self.max_num_tokens, skip_attn=True) + hidden_states = self._dummy_run(self.max_num_tokens) if get_pp_group().is_last_rank: sampler_output = self._dummy_sampler_run(hidden_states) else: @@ -1755,11 +1759,12 @@ def capture_model(self) -> None: # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): + skip_attn = not self.vllm_config.compilation_config.full_cuda_graph for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens) - self._dummy_run(num_tokens) + self._dummy_run(num_tokens, skip_attn=skip_attn) + self._dummy_run(num_tokens, skip_attn=skip_attn) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] From 59e52e6abacd821b375b3e7f9ebc9ea472351e77 Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Wed, 7 May 2025 20:45:18 +0000 Subject: [PATCH 12/13] Disable aot_schedule Signed-off-by: Chanh Nguyen --- vllm/v1/attention/backends/flash_attn.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b7f313ee74f1..7af80b78f0b0 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -291,6 +291,7 @@ class FlashAttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner"): model_config = runner.model_config + compilation_config = runner.vllm_config.compilation_config self.runner = runner self.num_heads_q = model_config.get_num_attention_heads( @@ -300,7 +301,13 @@ def __init__(self, runner: "GPUModelRunner"): self.headdim = model_config.get_head_size() self.page_size = self.runner.block_size - self.aot_schedule = (get_flash_attn_version() == 3) + if get_flash_attn_version() == 3: + # TODO(cnguyen): Support AOT scheduler with full CUDA graph + self.aot_schedule = not compilation_config.full_cuda_graph + if not self.aot_schedule: + logger.warning( + "AOT Scheduler is disabled when using full_cuda_graph") + # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None From 22fa9df3a00cb218c7542f42ac49f16aaf4fbcff Mon Sep 17 00:00:00 2001 From: Chanh Nguyen Date: Wed, 7 May 2025 22:34:59 +0000 Subject: [PATCH 13/13] fix Signed-off-by: Chanh Nguyen --- vllm/v1/attention/backends/flash_attn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 7af80b78f0b0..605dff3749fb 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -302,11 +302,12 @@ def __init__(self, runner: "GPUModelRunner"): self.page_size = self.runner.block_size if get_flash_attn_version() == 3: - # TODO(cnguyen): Support AOT scheduler with full CUDA graph self.aot_schedule = not compilation_config.full_cuda_graph if not self.aot_schedule: logger.warning( - "AOT Scheduler is disabled when using full_cuda_graph") + "AOT Schedule is disabled when using full_cuda_graph") + else: + self.aot_schedule = False # Sliding window size to be used with the AOT scheduler will be # populated on first build() call.
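
Taken together, the series exposes full CUDA graph capture through a `full_cuda_graph` flag on `CompilationConfig` (later patches replace the initial `CompilationLevel.FULL_GRAPH` / `-O4` approach with this flag), with the constraints stated in the added `docs/source/design/v1/torch_compile.md`: only the FlashAttention 3 backend is supported, and cascade attention is disabled automatically. A minimal usage sketch follows, mirroring the test added in `tests/compile/piecewise/test_full_cudagraph.py`; the model name, prompt, and sampling settings below are simply the ones that test uses, not requirements.

# Sketch of enabling full CUDA graph capture, following the test added in
# this series. Requires vLLM V1 and the FlashAttention 3 backend; the model
# runner rejects other attention backends at startup.
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",
    # Capture attention inside the CUDA graph instead of splitting it out
    # into piecewise subgraphs.
    compilation_config=CompilationConfig(full_cuda_graph=True),
)

outputs = llm.generate(["Hi my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)

The same option can be passed on the command line as `--compilation-config "{'full_cuda_graph': True}"`, as described in the documentation added by this series.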