From 499e424063865c34ee7f0c18fe2166c664703764 Mon Sep 17 00:00:00 2001 From: yuz207 Date: Tue, 14 Oct 2025 22:29:11 +0000 Subject: [PATCH 1/9] refactor(worker): simplify mask_work computation in GPUModelRunner Replaced explicit conditional logic and nonzero indexing with a cumulative product approach to compute mask_work. This change streamlines the code for better readability and maintainability without altering functionality. --- vllm/v1/worker/gpu_model_runner.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0909d0f8dd0a..a0d32fd5a20c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2336,16 +2336,9 @@ def _build_nwor_acceptance_mask( row = row.to(dtype=draft_ids.dtype) draft_slice = draft_ids[start:end] - comparison = (row == draft_slice).flatten() - - if bool(comparison.all().item()): - accepted = draft_count - else: - reject = torch.nonzero(~comparison, as_tuple=False) - accepted = int(reject[0, 0].item()) if reject.numel() > 0 else draft_count - - if accepted > 0: - mask_work[start : start + accepted] = True + comparison = (row == draft_slice) + prefix = torch.cumprod(comparison.to(torch.int32), dim=0) + mask_work[start:end] = prefix.to(torch.bool) start = end if start != total_tokens: From 48e0ea7f526d0bc4bf0c5a1fc68b6b47ec979b20 Mon Sep 17 00:00:00 2001 From: yuz207 Date: Tue, 14 Oct 2025 22:43:34 +0000 Subject: [PATCH 2/9] feat: add SCV mode scaffolding --- vllm/envs.py | 3 +++ vllm/v1/worker/gpu_model_runner.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index f876a0765496..5336660dd1be 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -200,6 +200,7 @@ VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False VLLM_DISABLE_NWOR: bool = False VLLM_NWOR_MODE: str = "stage" + VLLM_SCV_MODE: str = "off" VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False @@ -1315,6 +1316,8 @@ def get_vllm_port() -> int | None: "VLLM_DISABLE_NWOR": lambda: bool(int(os.getenv("VLLM_DISABLE_NWOR", "0"))), # Select NWOR mode: "stage" (default) or "immediate" to bypass staging. "VLLM_NWOR_MODE": lambda: os.getenv("VLLM_NWOR_MODE", "stage"), + # Speculative chunk verify mode: "off" (default), "graph", or "adaptive". + "VLLM_SCV_MODE": lambda: os.getenv("VLLM_SCV_MODE", "off"), # Used to force set up loopback IP "VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""), # Used to set the process name prefix for vLLM processes. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a0d32fd5a20c..9ad5c21537ba 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -509,6 +509,13 @@ def __init__( # Cached outputs. 
         self._deferred_write_manager = DeferredWriteManager(mode=envs.VLLM_NWOR_MODE)
         self._latest_nwor_window_metrics: dict[str, int | str] | None = None
+        self._scv_mode = envs.VLLM_SCV_MODE.lower()
+
+    def _scv_enabled(self) -> bool:
+        if self._scv_mode not in ("off", "graph", "adaptive"):
+            logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode)
+            self._scv_mode = "off"
+        return self._scv_mode != "off"
         self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
         self.transfer_event = torch.cuda.Event()
         self.sampled_token_ids_pinned_cpu = torch.empty(

From 2999f91764e11316c3fb2511fc340c748b0a7e6f Mon Sep 17 00:00:00 2001
From: yuz207
Date: Tue, 14 Oct 2025 22:51:54 +0000
Subject: [PATCH 3/9] feat(gpu_model_runner): add SCV mode with
 CUDA-graph-optimized mask computation

Introduced SCV (speculative chunk verify) mode to GPUModelRunner to vectorize
acceptance-mask computation during decoding. Added the SCVGraphExecutor and
_SCVGraphEntry classes, which capture the mask computation in a CUDA graph
keyed by batch shape and replay it on subsequent steps. SCV supports the
"graph" and "adaptive" operation modes and falls back gracefully to the eager
vectorized path if CUDA graph execution fails.

This improves decoding performance by reusing captured CUDA graphs for the
mask computation in speculative decoding workflows.

Co-authored-by: terragon-labs[bot]
---
 vllm/v1/worker/gpu_model_runner.py | 212 +++++++++++++++++++++++++++++
 1 file changed, 212 insertions(+)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9ad5c21537ba..fae0fbc9ab9a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5,6 +5,7 @@
 import itertools
 import time
 from collections import defaultdict
+from dataclasses import dataclass
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
@@ -510,6 +511,7 @@ def __init__(
         self._deferred_write_manager = DeferredWriteManager(mode=envs.VLLM_NWOR_MODE)
         self._latest_nwor_window_metrics: dict[str, int | str] | None = None
         self._scv_mode = envs.VLLM_SCV_MODE.lower()
+        self._scv_graph_executor: SCVGraphExecutor | None = None

     def _scv_enabled(self) -> bool:
         if self._scv_mode not in ("off", "graph", "adaptive"):
@@ -2323,6 +2325,15 @@ def _build_nwor_acceptance_mask(
         target_device = spec_decode_metadata.draft_token_ids.device
         work_device = sampled_token_ids.device

+        if self._scv_enabled():
+            mask = self._scv_vectorized_mask(
+                spec_decode_metadata, sampled_token_ids, total_tokens, work_device
+            )
+            if mask is not None:
+                if mask.device != target_device:
+                    mask = mask.to(device=target_device)
+                return mask
+
         draft_ids = spec_decode_metadata.draft_token_ids
         if draft_ids.device != work_device:
             draft_ids = draft_ids.to(device=work_device)
@@ -2355,6 +2366,85 @@
             return mask_work
         return mask_work.to(device=target_device)

+    def _scv_vectorized_mask(
+        self,
+        spec_decode_metadata: SpecDecodeMetadata,
+        sampled_token_ids: torch.Tensor,
+        total_tokens: int,
+        device: torch.device,
+    ) -> torch.Tensor | None:
+        draft_ids = spec_decode_metadata.draft_token_ids
+        max_spec_len = spec_decode_metadata.max_spec_len
+        num_draft_tensor = torch.tensor(
+            spec_decode_metadata.num_draft_tokens,
+            device=device,
+            dtype=torch.int32,
+        )
+        if draft_ids.device != device:
+            draft_ids = draft_ids.to(device=device)
+
+        cu = spec_decode_metadata.cu_num_draft_tokens.to(device=device)
+
+        if self._scv_mode == "graph":
+            if self._scv_graph_executor is None:
+                self._scv_graph_executor = SCVGraphExecutor(device)
+            mask = self._scv_graph_executor.run(
+                spec_decode_metadata, sampled_token_ids, total_tokens
+            )
+            if mask is not None:
+                return mask
+
+        mask = self._scv_compute_mask(
+            draft_ids,
+            num_draft_tensor,
+            cu,
+            sampled_token_ids,
+            max_spec_len,
+            total_tokens,
+        )
+        return mask
+
+    @staticmethod
+    def _scv_compute_mask(
+        draft_ids: torch.Tensor,
+        num_draft_tokens: torch.Tensor,
+        cu_num_draft_tokens: torch.Tensor,
+        sampled_token_ids: torch.Tensor,
+        max_spec_len: int,
+        total_tokens: int,
+    ) -> torch.Tensor:
+        device = draft_ids.device
+        indices = torch.arange(total_tokens, device=device, dtype=torch.int32)
+        req_idx = torch.bucketize(indices, cu_num_draft_tokens, right=True)
+        prev_cu = torch.cat([cu_num_draft_tokens.new_zeros(1), cu_num_draft_tokens[:-1]])
+        pos_in_req = indices - prev_cu[req_idx]
+
+        gathered = sampled_token_ids[req_idx, pos_in_req]
+        comparison = gathered == draft_ids
+
+        max_val = max_spec_len + 1
+        values = torch.where(
+            ~comparison,
+            (pos_in_req + 1).to(torch.int32),
+            torch.full_like(pos_in_req, max_val, dtype=torch.int32),
+        )
+
+        accepted = torch.full(
+            (num_draft_tokens.numel(),),
+            max_val,
+            device=device,
+            dtype=torch.int32,
+        )
+        accepted.scatter_reduce_(0, req_idx, values, reduce="amin")
+        accepted = torch.where(
+            accepted == max_val,
+            num_draft_tokens,
+            accepted - 1,
+        )
+        accepted_broadcast = accepted[req_idx]
+        mask_flat = pos_in_req < accepted_broadcast
+        return mask_flat
+
     def _bookkeeping_sync(
         self,
         scheduler_output: "SchedulerOutput",
@@ -4836,3 +4926,125 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
         self.transfer_event.record()
         self.transfer_event.synchronize()
         return pinned.tolist()
+@dataclass
+class _SCVGraphEntry:
+    num_reqs: int
+    max_spec_len: int
+    total_tokens: int
+    sampled_shape: tuple[int, int]
+    sampled_dtype: torch.dtype
+    draft_dtype: torch.dtype
+    device: torch.device
+
+    def __post_init__(self):
+        self.sampled_buffer = torch.empty(
+            self.sampled_shape, device=self.device, dtype=self.sampled_dtype
+        )
+        self.draft_buffer = torch.empty(
+            (self.total_tokens,), device=self.device, dtype=self.draft_dtype
+        )
+        self.num_tokens_buffer = torch.empty(
+            (self.num_reqs,), device=self.device, dtype=torch.int32
+        )
+        self.cu_buffer = torch.empty(
+            (self.num_reqs,), device=self.device, dtype=torch.int32
+        )
+        self.mask_buffer = torch.empty(
+            (self.total_tokens,), device=self.device, dtype=torch.bool
+        )
+        self.graph = torch.cuda.CUDAGraph()
+        self._captured = False
+
+    def capture(self):
+        if self._captured:
+            return
+        mask = GPUModelRunner._scv_compute_mask(
+            self.draft_buffer,
+            self.num_tokens_buffer,
+            self.cu_buffer,
+            self.sampled_buffer,
+            self.max_spec_len,
+            self.total_tokens,
+        )
+        self.mask_buffer.copy_(mask)
+        torch.cuda.synchronize()
+        with torch.cuda.graph(self.graph):
+            mask = GPUModelRunner._scv_compute_mask(
+                self.draft_buffer,
+                self.num_tokens_buffer,
+                self.cu_buffer,
+                self.sampled_buffer,
+                self.max_spec_len,
+                self.total_tokens,
+            )
+            self.mask_buffer.copy_(mask)
+        self._captured = True
+
+    def run(self):
+        if not self._captured:
+            self.capture()
+        self.graph.replay()
+        return self.mask_buffer
+
+
+class SCVGraphExecutor:
+    def __init__(self, device: torch.device):
+        self.device = device
+        self.entries: dict[tuple[Any, ...], _SCVGraphEntry] = {}
+        self.enabled = torch.cuda.is_available()
+
+    def run(
+        self,
+        spec_decode_metadata: SpecDecodeMetadata,
+        sampled_token_ids: torch.Tensor,
+        total_tokens: int,
+    ) -> torch.Tensor | None:
+        if not self.enabled:
+            return None
+        num_reqs = len(spec_decode_metadata.num_draft_tokens)
+        max_spec_len = spec_decode_metadata.max_spec_len
+        key = (
+            num_reqs,
+            max_spec_len,
+            sampled_token_ids.shape[1],
+            total_tokens,
+            sampled_token_ids.dtype,
+        )
+        entry = self.entries.get(key)
+        need_capture = False
+        if entry is None:
+            entry = _SCVGraphEntry(
+                num_reqs=num_reqs,
+                max_spec_len=max_spec_len,
+                total_tokens=total_tokens,
+                sampled_shape=sampled_token_ids[:, :max_spec_len].shape,
+                sampled_dtype=sampled_token_ids.dtype,
+                draft_dtype=spec_decode_metadata.draft_token_ids.dtype,
+                device=self.device,
+            )
+            self.entries[key] = entry
+            need_capture = True
+        try:
+            sampled_view = sampled_token_ids[:, :max_spec_len]
+            entry.sampled_buffer.copy_(sampled_view)
+            draft_ids = spec_decode_metadata.draft_token_ids.to(self.device)
+            entry.draft_buffer.zero_()
+            entry.draft_buffer[: draft_ids.numel()].copy_(draft_ids)
+            num_tokens_tensor = torch.tensor(
+                spec_decode_metadata.num_draft_tokens,
+                device=self.device,
+                dtype=torch.int32,
+            )
+            entry.num_tokens_buffer.copy_(num_tokens_tensor)
+            cu_tensor = spec_decode_metadata.cu_num_draft_tokens.to(
+                device=self.device, dtype=torch.int32
+            )
+            entry.cu_buffer.copy_(cu_tensor)
+            if need_capture:
+                entry.capture()
+            return entry.run()
+        except RuntimeError as exc:
+            logger.warning("SCV graph execution disabled: %s", exc)
+            self.enabled = False
+            self.entries.clear()
+            return None

From 04a8f535917e20bbcc1d426186aed8897e32711b Mon Sep 17 00:00:00 2001
From: yuz207
Date: Tue, 14 Oct 2025 22:55:41 +0000
Subject: [PATCH 4/9] fix: guard SCV mode for test harness

---
 vllm/v1/worker/gpu_model_runner.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index fae0fbc9ab9a..9aec9bde4985 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -514,6 +514,8 @@ def __init__(
         self._scv_graph_executor: SCVGraphExecutor | None = None

     def _scv_enabled(self) -> bool:
+        if not hasattr(self, "_scv_mode"):
+            self._scv_mode = envs.VLLM_SCV_MODE.lower()
         if self._scv_mode not in ("off", "graph", "adaptive"):
             logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode)
             self._scv_mode = "off"
@@ -2385,10 +2387,12 @@ def _scv_vectorized_mask(

         cu = spec_decode_metadata.cu_num_draft_tokens.to(device=device)

-        if self._scv_mode == "graph":
-            if self._scv_graph_executor is None:
-                self._scv_graph_executor = SCVGraphExecutor(device)
-            mask = self._scv_graph_executor.run(
+        if hasattr(self, "_scv_mode") and self._scv_mode == "graph":
+            executor = getattr(self, "_scv_graph_executor", None)
+            if executor is None:
+                executor = SCVGraphExecutor(device)
+                self._scv_graph_executor = executor
+            mask = executor.run(
                 spec_decode_metadata, sampled_token_ids, total_tokens
             )
             if mask is not None:

From 54878097cd3b82483ab0ff30da6e3b19d5f8031d Mon Sep 17 00:00:00 2001
From: yuz207
Date: Tue, 14 Oct 2025 22:58:46 +0000
Subject: [PATCH 5/9] feat(gpu_model_runner): add adaptive SCV mode to
 dynamically adjust speculation tokens

Introduce an adaptive mode in the GPUModelRunner that computes the acceptance
mask via the vectorized SCV path and then adjusts the number of speculative
tokens based on recent acceptance ratios. This update adds the
`_scv_update_controller` method, which tracks an exponential moving average
(EMA) of the per-batch acceptance ratio and nudges `num_speculative_tokens` up
or down to hold a target acceptance ratio.
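
For example, with the defaults in this patch (target ratio 0.6, EMA weight
alpha = 0.2) and num_speculative_tokens = 4: a batch that accepts 1 of 4 draft
tokens (ratio 0.25) moves the EMA from its initial 0.6 to
0.8 * 0.6 + 0.2 * 0.25 = 0.53, which is still above the decrement threshold of
0.6 * 0.8 = 0.48, so k stays at 4; a second such batch lowers the EMA to
0.8 * 0.53 + 0.2 * 0.25 = 0.474 < 0.48, and k steps down to 3.
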
Co-authored-by: terragon-labs[bot] --- vllm/v1/worker/gpu_model_runner.py | 43 ++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9aec9bde4985..eeed8ee7a6c5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2398,6 +2398,18 @@ def _scv_vectorized_mask( if mask is not None: return mask + if hasattr(self, "_scv_mode") and self._scv_mode == "adaptive": + mask = self._scv_compute_mask( + draft_ids, + num_draft_tensor, + cu, + sampled_token_ids, + max_spec_len, + total_tokens, + ) + self._scv_update_controller(spec_decode_metadata, mask) + return mask + mask = self._scv_compute_mask( draft_ids, num_draft_tensor, @@ -2449,6 +2461,37 @@ def _scv_compute_mask( mask_flat = pos_in_req < accepted_broadcast return mask_flat + def _scv_update_controller( + self, + spec_decode_metadata: SpecDecodeMetadata, + mask: torch.Tensor, + ) -> None: + target_ratio = 0.6 + alpha = 0.2 + accepted = int(mask.sum().item()) + total = max(mask.numel(), 1) + ratio = accepted / total + prev = getattr(self, "_scv_accept_ratio", target_ratio) + new_ratio = (1 - alpha) * prev + alpha * ratio + self._scv_accept_ratio = new_ratio + + draft_model_config = getattr(self.speculative_config, "draft_model_config", None) + if draft_model_config is None or not hasattr(self.speculative_config, "num_speculative_tokens"): + return + + base_k = self.speculative_config.num_speculative_tokens + k_min = max(1, base_k // 4) + k_max = max(1, base_k * 2) + + if new_ratio < target_ratio * 0.8: + new_k = max(k_min, base_k - 1) + elif new_ratio > target_ratio * 1.2: + new_k = min(k_max, base_k + 1) + else: + new_k = base_k + + self.speculative_config.num_speculative_tokens = new_k + def _bookkeeping_sync( self, scheduler_output: "SchedulerOutput", From 132dc61b55e8337c27dbc0c2a15e880a4de03b82 Mon Sep 17 00:00:00 2001 From: yuz207 Date: Tue, 14 Oct 2025 23:01:40 +0000 Subject: [PATCH 6/9] test(deferred_writer): add test for _build_nwor_acceptance_mask in SCV adaptive mode Add a new unit test `test_scv_vectorized_mask_matches_reference` to validate the behavior of the `_build_nwor_acceptance_mask` method in the GPUModelRunner class configured with SCV adaptive mode. This test ensures the mask output matches the expected reference. 
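
The expected mask follows from the prefix-acceptance rule: the drafts
[1, 2, 3, 4] and the sampled row [1, 2, 0, 4] match at positions 0 and 1 and
first diverge at position 2 (draft 3 vs sampled 0), so exactly the first two
tokens are accepted and the mask is [True, True, False, False]. The later
match at position 3 is ignored because acceptance stops at the first mismatch.
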
Co-authored-by: terragon-labs[bot]
---
 tests/v1/test_deferred_writer.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tests/v1/test_deferred_writer.py b/tests/v1/test_deferred_writer.py
index 16b65a08b7bb..91496757fe69 100644
--- a/tests/v1/test_deferred_writer.py
+++ b/tests/v1/test_deferred_writer.py
@@ -196,6 +196,17 @@ def test_nwor_immediate_mode_skips_window():
     assert manager.get_mode() == "immediate"


+def test_scv_vectorized_mask_matches_reference():
+    metadata = _make_metadata([1, 2, 3, 4], [4])
+    sampled = torch.tensor([[1, 2, 0, 4]], dtype=torch.int32)
+
+    runner = GPUModelRunner.__new__(GPUModelRunner)
+    runner._scv_mode = "adaptive"
+
+    mask = runner._build_nwor_acceptance_mask(metadata, sampled)
+    assert mask.tolist() == [True, True, False, False]
+
+
 def test_commit_failure_triggers_fallback_metrics():
     manager = DeferredWriteManager()
     assert manager.begin_window([1])

From 078fc37a55b219bf21812fce322fe93f4858f7df Mon Sep 17 00:00:00 2001
From: yuz207
Date: Wed, 15 Oct 2025 00:44:24 +0000
Subject: [PATCH 7/9] fix: restore cached-outputs comment in GPUModelRunner
 __init__

---
 vllm/v1/worker/gpu_model_runner.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index eeed8ee7a6c5..921bf3ede886 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -520,6 +520,8 @@ def _scv_enabled(self) -> bool:
             logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode)
             self._scv_mode = "off"
         return self._scv_mode != "off"
+
+    # Cached outputs.
         self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
         self.transfer_event = torch.cuda.Event()
         self.sampled_token_ids_pinned_cpu = torch.empty(

From 87a0206b0bf0611c6b3c844c445e02d2f704b234 Mon Sep 17 00:00:00 2001
From: yuz207
Date: Wed, 15 Oct 2025 00:51:35 +0000
Subject: [PATCH 8/9] fix: guard SCV adaptive controller when
 speculative_config missing

---
 vllm/v1/worker/gpu_model_runner.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 921bf3ede886..93bf8d5eedd4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2477,11 +2477,11 @@ def _scv_update_controller(
         new_ratio = (1 - alpha) * prev + alpha * ratio
         self._scv_accept_ratio = new_ratio

-        draft_model_config = getattr(self.speculative_config, "draft_model_config", None)
-        if draft_model_config is None or not hasattr(self.speculative_config, "num_speculative_tokens"):
+        speculative_config = getattr(self, "speculative_config", None)
+        if speculative_config is None or not hasattr(speculative_config, "num_speculative_tokens"):
             return

-        base_k = self.speculative_config.num_speculative_tokens
+        base_k = speculative_config.num_speculative_tokens
         k_min = max(1, base_k // 4)
         k_max = max(1, base_k * 2)

@@ -2492,7 +2492,7 @@
         else:
             new_k = base_k

-        self.speculative_config.num_speculative_tokens = new_k
+        speculative_config.num_speculative_tokens = new_k

From 4fdd1a895b57407bc0e88f625d3999ad8d4f556f Mon Sep 17 00:00:00 2001
From: yuz207
Date: Wed, 15 Oct 2025 01:25:01 +0000
Subject: [PATCH 9/9] fix(gpu_model_runner): move _scv_enabled definition out
 of __init__

The _scv_enabled definition had been inserted in the middle of __init__, which
ended the constructor early: the cached-output assignments that followed it
were left as unreachable code after the method's return and never ran.
Relocate the definition to just after __init__ so those assignments execute
during construction again and the class reads top to bottom. A minimal sketch
of the hazard follows.
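
A minimal sketch of the hazard (illustrative toy code, not the vLLM class):

    class Example:
        def __init__(self):
            self.mode = "off"

        def helper(self) -> bool:  # accidentally defined mid-__init__
            return self.mode != "off"
            # Statements that used to belong to __init__ end up here,
            # after the return, and never execute.
            self.buffer = None

Moving the helper definition below the assignments restores them to __init__,
which is what this commit does for _scv_enabled.
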
Co-authored-by: terragon-labs[bot] --- vllm/v1/worker/gpu_model_runner.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 93bf8d5eedd4..b84256dec815 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -512,16 +512,6 @@ def __init__( self._latest_nwor_window_metrics: dict[str, int | str] | None = None self._scv_mode = envs.VLLM_SCV_MODE.lower() self._scv_graph_executor: SCVGraphExecutor | None = None - - def _scv_enabled(self) -> bool: - if not hasattr(self, "_scv_mode"): - self._scv_mode = envs.VLLM_SCV_MODE.lower() - if self._scv_mode not in ("off", "graph", "adaptive"): - logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode) - self._scv_mode = "off" - return self._scv_mode != "off" - - # Cached outputs. self._draft_token_ids: list[list[int]] | torch.Tensor | None = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( @@ -531,6 +521,14 @@ def _scv_enabled(self) -> bool: pin_memory=self.pin_memory, ) + def _scv_enabled(self) -> bool: + if not hasattr(self, "_scv_mode"): + self._scv_mode = envs.VLLM_SCV_MODE.lower() + if self._scv_mode not in ("off", "graph", "adaptive"): + logger.warning("SCV: unsupported mode '%s', disabling.", self._scv_mode) + self._scv_mode = "off" + return self._scv_mode != "off" + def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache()
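
Note (illustrative check, not part of any patch above): the cumulative-product
rewrite in PATCH 1/9 preserves the first-mismatch semantics of the loop it
replaced, because the running product stays 1 only while every comparison so
far has matched. A standalone sketch:

    import torch

    draft = torch.tensor([7, 8, 9, 10])
    sampled = torch.tensor([7, 8, 0, 10])

    # Old logic: find the first mismatch and accept everything before it.
    mismatches = torch.nonzero(sampled != draft, as_tuple=False)
    accepted = int(mismatches[0, 0]) if mismatches.numel() > 0 else draft.numel()

    # New logic: boolean prefix of matches via cumulative product.
    prefix = torch.cumprod((sampled == draft).to(torch.int32), dim=0).to(torch.bool)

    assert prefix.tolist() == [True, True, False, False]
    assert int(prefix.sum()) == accepted == 2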