From a8a9f38431226db2db5d8a66a4ff62357fb4ea1a Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Tue, 22 Jul 2025 02:02:22 -0700 Subject: [PATCH 01/22] feat: basic support of nano split. Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 3 +- vllm/compilation/nano_split.py | 249 +++++++++++++++++++++++++++++ vllm/forward_context.py | 5 + vllm/v1/worker/gpu_model_runner.py | 118 +++++++++++++- 4 files changed, 373 insertions(+), 2 deletions(-) create mode 100644 vllm/compilation/nano_split.py diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 673fb5866234..c8c457e743f3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -587,7 +587,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: - return self.split_gm + from vllm.compilation.nano_split import init_split_manager_and_get_callable + return init_split_manager_and_get_callable(self.split_gm) # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode diff --git a/vllm/compilation/nano_split.py b/vllm/compilation/nano_split.py new file mode 100644 index 000000000000..5f7eadcadafd --- /dev/null +++ b/vllm/compilation/nano_split.py @@ -0,0 +1,249 @@ +import copy +from dataclasses import dataclass +import torch +from typing import Callable, Dict, List, Tuple, Optional, Set + +@dataclass +class InputInfo: + batch_size: int + num_tokens: list[int] + cached_seqlens: list[int] + + +@dataclass +class NanoBatchSplitConfig: + split_indices: List[int] + batch_sizes: List[int] + + +class HookWrapper(torch.nn.Module): + def __init__(self, hook: Callable): + super().__init__() + self.hook = hook + + def forward(self, *args, **kwargs): + self.hook(*args, **kwargs) + + +class NanoBatchSplit: + def __init__(self): + self.input_splits: Dict[torch.fx.Node, List[torch.fx.Node]] = {} + self.node_splits: Dict[torch.fx.Node, List[torch.fx.Node]] = {} + self.weight_nodes: Set[torch.fx.Node] = set() + self.splittable_inputs: List[torch.fx.Node] = [] + self.graph_module: Optional[torch.fx.GraphModule] = None + self.original_graph: torch.fx.Graph + self.base_graph: Optional[torch.fx.Graph] = None + self.new_graph: Optional[torch.fx.Graph] = None + + def _init_placeholders(self) -> None: + batch_size: Optional[torch.SymInt] = None + assert self.base_graph is not None + self.base_graph.call_module("pre_forward_hook", args=()) + for node in self.original_graph.nodes: + # Skip computation nodes + if node.op != "placeholder": + continue + + # We assume the batch size is the first argument + if batch_size is None: + arg = node.meta["example_value"] + if not isinstance(arg, torch.SymInt): + raise ValueError("Batch size is not set") + batch_size = arg + else: + shape = node.meta["example_value"].shape + if shape[0] == batch_size: + self.splittable_inputs.append(node) + print(f"Found splittable input: {node.name} with shape {shape}") + else: + self.weight_nodes.add(node) + print(f"Found weight tensor: {node.name} with shape {shape}") + # Copy all placeholder nodes to the new graph + self.base_graph.node_copy(node, arg_transform=lambda n: n) + + def _init_input_splits(self, split_indices: List[int]) -> None: + num_splits = len(split_indices) - 1 + assert self.new_graph is not None + for node in self.splittable_inputs: + self.input_splits[node] = [] + for i in range(num_splits): + start_idx = split_indices[i] + end_idx = split_indices[i + 1] + slice_node = 
self.new_graph.call_function( + lambda x, start, end: x[start:end], + args=(node, start_idx, end_idx), + ) + self.input_splits[node].append(slice_node) + + def _replicate_computations(self, split_indices: List[int]) -> None: + num_splits = len(split_indices) - 1 + assert self.new_graph is not None + print(f"Replicating computations for {num_splits} splits") + for node in self.original_graph.nodes: + if node.op in ["placeholder", "output"]: + continue + print(f"Processing node: {node.name}, op: {node.op}, args: {len(node.args) if node.args else 0}") + splits = [] + for split_idx in range(num_splits): + new_args = self._get_split_args(node.args, split_idx, split_indices) + new_kwargs = self._get_split_kwargs(node.kwargs, split_idx) + orig_vals = list(node.args) + list(node.kwargs.values()) + new_vals = list(new_args) + list(new_kwargs.values()) + orig_to_new = {o: n for o, n in zip(orig_vals, new_vals)} + # Call pre_op_hook with proper arguments + self.new_graph.call_module( + "pre_op_hook", + args=(node.name, split_idx, new_args, new_kwargs), + ) + new_node = self.new_graph.node_copy( + node, arg_transform=lambda n: orig_to_new[n] + ) + splits.append(new_node) + + self.node_splits[node] = splits + print( + f"Replicated computation node {node.name} into {num_splits} parts" + ) + + def _handle_outputs(self) -> None: + """Handle output nodes by concatenating split outputs and cleaning up original computations.""" + assert self.new_graph is not None + output_nodes = [ + node for node in self.original_graph.nodes if node.op == "output" + ] + assert len(output_nodes) == 1, f"Expected 1 output node, found {len(output_nodes)}" + output_node = output_nodes[0] + + # Find the original computation that feeds into this output + if not output_node.args: + raise ValueError("Output node has no arguments") + original_outputs = output_node.args[0] + is_tuple = isinstance(original_outputs, tuple) + if not isinstance(original_outputs, tuple): + original_outputs = (original_outputs,) + new_outputs = [] + + for original_output in original_outputs: + if original_output in self.node_splits: + # Get all split outputs + split_outputs = self.node_splits[original_output] + + # Create concatenation node + if len(split_outputs) == 1: + # If there's only one split, no need to concatenate + concat_node = split_outputs[0] + else: + # Create concatenation node + concat_node = self.new_graph.call_function( + torch.cat, + args=(split_outputs, 0), # Concatenate along first dimension + ) + + new_outputs.append(concat_node) + print(f"Concatenated {len(split_outputs)} output splits") + else: + raise ValueError( + f"Original output {original_output} not found in node_splits" + ) + + self.new_graph.output(tuple(new_outputs) if is_tuple else new_outputs[0]) + + def _get_split_args(self, args: Tuple, split_idx: int, split_indices: List[int]) -> Tuple: + """Get arguments for a specific split.""" + new_args = [] + + for arg in args: + if isinstance(arg, torch.fx.Node): + if arg in self.input_splits: + new_args.append(self.input_splits[arg][split_idx]) + elif arg in self.node_splits: + new_args.append(self.node_splits[arg][split_idx]) + elif arg in self.weight_nodes: + # Weight tensors are shared across splits + new_args.append(arg) + elif isinstance(arg.meta["example_value"], torch.SymInt): + new_args.append(split_indices[split_idx + 1] - split_indices[split_idx]) + else: + new_args.append(arg) + else: + new_args.append(arg) + + return tuple(new_args) + + def _get_split_kwargs(self, kwargs: Dict, split_idx: int) -> Dict: + """Get 
keyword arguments for a specific split.""" + new_kwargs = {} + + for key, value in kwargs.items(): + if isinstance(value, torch.fx.Node): + if value in self.input_splits: + new_kwargs[key] = self.input_splits[value][split_idx] + elif value in self.node_splits: + new_kwargs[key] = self.node_splits[value][split_idx] + elif value in self.weight_nodes: + # Weight tensors are shared across splits + new_kwargs[key] = value + else: + new_kwargs[key] = value + else: + new_kwargs[key] = value + + return new_kwargs + + def auto_search_and_split( + self, + input_info: InputInfo, + ) -> list[int]: + total_batch_size = input_info.batch_size + if total_batch_size == 1: + batch_sizes = [1] + split_indices = [0, input_info.num_tokens[0]] + else: + batch_sizes = [1, total_batch_size - 1] + split_indices = [0, input_info.num_tokens[0], sum(input_info.num_tokens)] + assert self.base_graph is not None + self.new_graph = copy.deepcopy(self.base_graph) + self._init_input_splits(split_indices) + self._replicate_computations(split_indices) + self._handle_outputs() + assert self.graph_module is not None + self.graph_module.graph = self.new_graph + print(self.graph_module.code) + setattr(self.graph_module, "cached_config", NanoBatchSplitConfig(split_indices, batch_sizes)) + return batch_sizes + + def init_callable(self, graph_module: torch.fx.GraphModule) -> Callable: + self.base_graph = torch.fx.Graph() + self.graph_module = graph_module + self.original_graph = graph_module.graph + self._init_placeholders() + return self.graph_module + + +_split_manager = None + + +def init_split_manager_and_get_callable(graph_module: torch.fx.GraphModule) -> Callable: + global _split_manager + if _split_manager is None: + _split_manager = NanoBatchSplit() + return _split_manager.init_callable(graph_module) + + +def auto_search_and_split(input_info: InputInfo) -> list[int]: + global _split_manager + if _split_manager is None: + raise ValueError("Split manager not initialized") + return _split_manager.auto_search_and_split(input_info) + + +def set_forward_hook( + pre_forward_hook: Callable[[], None], + pre_op_hook: Callable[[str, int, Tuple, Dict], None] +) -> None: + global _split_manager + if _split_manager is None: + raise ValueError("Split manager not initialized") + setattr(_split_manager.graph_module, "pre_forward_hook", HookWrapper(pre_forward_hook)) + setattr(_split_manager.graph_module, "pre_op_hook", HookWrapper(pre_op_hook)) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19feeaf..dddb3ef4c81d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -108,6 +108,11 @@ def get_forward_context() -> ForwardContext: return _forward_context +def override_forward_context(forward_context: ForwardContext) -> None: + global _forward_context + _forward_context = forward_context + + @contextmanager def set_forward_context( attn_metadata: Any, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 670e653929ce..5baf7ad92681 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -18,6 +18,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter +from vllm.compilation.nano_split import InputInfo, auto_search_and_split, set_forward_hook from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState @@ -27,7 +28,7 @@ from 
vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, get_forward_context, +from vllm.forward_context import (DPMetadata, ForwardContext, get_forward_context, override_forward_context, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase @@ -789,6 +790,120 @@ def _prepare_inputs( return (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata) + + def _split_inputs_for_nano_split( + self, + scheduler_output: "SchedulerOutput", + ) -> None: + """ + For each group of requests (as defined by split_indices), generate + separate attention metadata. + + Args: + scheduler_output: SchedulerOutput for the whole batch. + nano_batch_size: List of nano batch sizes. + + Returns: + List of attention_metadata dicts, one per group. + """ + req_ids = self.input_batch.req_ids + num_reqs = len(req_ids) + num_tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] + cached_seqlens = self.input_batch.num_computed_tokens_cpu[:num_reqs].tolist() + batch_sizes = auto_search_and_split( + InputInfo( + batch_size=num_reqs, + num_tokens=num_tokens, + cached_seqlens=cached_seqlens, + ) + ) + attn_metadatas = [] + + start_req_idx = 0 + end_req_idx = 0 + for nano_batch_size in batch_sizes: + start_req_idx = end_req_idx + end_req_idx = start_req_idx + nano_batch_size + nano_batch_req_ids = req_ids[start_req_idx:end_req_idx] + + # Gather per-request info for this group + nano_batch_num_scheduled_tokens = np.array( + [scheduler_output.num_scheduled_tokens[rid] for rid in nano_batch_req_ids], + dtype=np.int32 + ) + nano_batch_cu_num_tokens, nano_batch_arange = self._get_cumsum_and_arange(nano_batch_num_scheduled_tokens) + nano_batch_total_tokens = int(nano_batch_cu_num_tokens[-1]) + nano_batch_req_indices = np.repeat(np.arange(len(nano_batch_req_ids)), nano_batch_num_scheduled_tokens) + + # Compute positions for this group + nano_batch_positions_np = np.empty(nano_batch_total_tokens, dtype=np.int64) + np.add( + self.input_batch.num_computed_tokens_cpu[start_req_idx:end_req_idx][nano_batch_req_indices], + nano_batch_arange, + out=nano_batch_positions_np + ) + + # Prepare attention metadata for each KV cache group + nano_batch_attn_metadata = {} + for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups): + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[start_req_idx:end_req_idx] + slot_mapping = blk_table.slot_mapping[:nano_batch_total_tokens] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[start_req_idx:end_req_idx + 1], + query_start_loc_cpu=self.query_start_loc_cpu[start_req_idx:end_req_idx + 1], + seq_lens=self.seq_lens[start_req_idx:end_req_idx], + seq_lens_cpu=self.seq_lens_cpu[start_req_idx:end_req_idx], + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[start_req_idx:end_req_idx], + num_reqs=nano_batch_size, + num_actual_tokens=nano_batch_total_tokens, + max_query_len=int(max(nano_batch_num_scheduled_tokens)), + block_table_tensor=blk_table_tensor, + slot_mapping=slot_mapping, + ) + + if isinstance(kv_cache_group_spec.kv_cache_spec, ChunkedLocalAttentionSpec): + common_attn_metadata = make_local_attention_virtual_batches( + 
kv_cache_group_spec.kv_cache_spec.attention_chunk_size, + common_attn_metadata, self.cache_config.block_size) + + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + builder = self.attn_metadata_builders[kv_cache_group_id] + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + nano_batch_num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks[kv_cache_group_id], + kv_cache_group_spec.kv_cache_spec, + builder, + ) + + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) + + for layer_name in kv_cache_group_spec.layer_names: + nano_batch_attn_metadata[layer_name] = attn_metadata_i + + attn_metadatas.append(nano_batch_attn_metadata) + + assert end_req_idx == num_reqs, f"invalid nano batch size: {batch_sizes}" + forward_contexts = [ + ForwardContext( + no_compile_layers=self.vllm_config.compilation_config.static_forward_context, + virtual_engine=0, + attn_metadata=attn_metadata, + dp_metadata=None, + skip_cuda_graphs=True, + ) for attn_metadata in attn_metadatas + ] + def pre_forward_hook() -> None: + pass + def pre_op_hook(node_name: str, idx: int, args: tuple, kwargs: dict) -> None: + override_forward_context(forward_contexts[idx]) + set_forward_hook(pre_forward_hook, pre_op_hook) def _compute_cascade_attn_prefix_len( self, @@ -1396,6 +1511,7 @@ def execute_model( # If attention doesn't support CUDA Graphs for this batch, but we # compiled with full CUDA graphs, we have to skip them entirely. skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + self._split_inputs_for_nano_split(scheduler_output) # Run the model. # Use persistent buffers for CUDA graphs. From ad5fbcdd0bd126f12426901533fe8fd9d923b9cf Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Thu, 24 Jul 2025 02:42:34 -0700 Subject: [PATCH 02/22] finish compute comm overlap Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 3 + vllm/compilation/nano_split.py | 188 ++++++++++++++++++++--------- vllm/config.py | 1 + vllm/v1/worker/gpu_model_runner.py | 24 ++-- 4 files changed, 152 insertions(+), 64 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c8c457e743f3..175976c8bcad 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -234,6 +234,7 @@ def split_graph(graph: fx.GraphModule, # split graph by ops subgraph_id = 0 node_to_subgraph_id = {} + subgraph_to_tag = {} split_op_graphs = [] for node in graph.graph.nodes: if node.op in ("output", "placeholder"): @@ -242,6 +243,7 @@ def split_graph(graph: fx.GraphModule, subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) + subgraph_to_tag[subgraph_id] = str(node.target) subgraph_id += 1 else: node_to_subgraph_id[node] = subgraph_id @@ -268,6 +270,7 @@ def split_graph(graph: fx.GraphModule, module = getattr(split_gm, name) graph_id = int(name.replace("submod_", "")) + setattr(module, "tag", subgraph_to_tag.get(graph_id, "")) outputs.append( SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) diff --git a/vllm/compilation/nano_split.py b/vllm/compilation/nano_split.py index 5f7eadcadafd..dcb1223c6aa9 100644 --- a/vllm/compilation/nano_split.py +++ b/vllm/compilation/nano_split.py @@ -1,7 +1,8 @@ +import contextlib import copy from dataclasses import dataclass import torch -from typing import Callable, Dict, List, Tuple, Optional, Set +from typing import Callable, ContextManager, Dict, List, Tuple, Optional, Set 
@dataclass class InputInfo: @@ -16,30 +17,43 @@ class NanoBatchSplitConfig: batch_sizes: List[int] -class HookWrapper(torch.nn.Module): - def __init__(self, hook: Callable): +@dataclass +class NanoOpInfo: + gm: torch.fx.GraphModule + submod_name: str + tag: str + idx: int + args: tuple + kwargs: dict + + +class NanoOpWrapper(torch.nn.Module): + def __init__(self, gm: torch.fx.GraphModule, hook: List[Callable[[NanoOpInfo], ContextManager[None]]]): super().__init__() + self.gm = gm self.hook = hook - - def forward(self, *args, **kwargs): - self.hook(*args, **kwargs) - - -class NanoBatchSplit: - def __init__(self): - self.input_splits: Dict[torch.fx.Node, List[torch.fx.Node]] = {} - self.node_splits: Dict[torch.fx.Node, List[torch.fx.Node]] = {} - self.weight_nodes: Set[torch.fx.Node] = set() - self.splittable_inputs: List[torch.fx.Node] = [] - self.graph_module: Optional[torch.fx.GraphModule] = None - self.original_graph: torch.fx.Graph - self.base_graph: Optional[torch.fx.Graph] = None - self.new_graph: Optional[torch.fx.Graph] = None - def _init_placeholders(self) -> None: + def forward(self, submod_name: str, idx: int, args: tuple, kwargs: dict): + module = getattr(self.gm, submod_name) + tag = getattr(module, "tag", "") + with contextlib.ExitStack() as stack: + for hook in self.hook: + stack.enter_context(hook(NanoOpInfo(self.gm, submod_name, tag, idx, args, kwargs))) + output = module(*args, **kwargs) + return output + + +class NanoSplitManager: + def __init__(self, graph_module: torch.fx.GraphModule) -> None: + self.graph_module = graph_module + self.original_graph = graph_module.graph + self.base_graph = torch.fx.Graph() + + # Initialize the base graph batch_size: Optional[torch.SymInt] = None - assert self.base_graph is not None - self.base_graph.call_module("pre_forward_hook", args=()) + weight_nodes = set() + splittable_inputs = [] + base_graph = torch.fx.Graph() for node in self.original_graph.nodes: # Skip computation nodes if node.op != "placeholder": @@ -51,16 +65,40 @@ def _init_placeholders(self) -> None: if not isinstance(arg, torch.SymInt): raise ValueError("Batch size is not set") batch_size = arg - else: - shape = node.meta["example_value"].shape + elif isinstance(input_tensor := node.meta["example_value"], torch.Tensor): + shape = input_tensor.shape if shape[0] == batch_size: - self.splittable_inputs.append(node) + splittable_inputs.append(node) print(f"Found splittable input: {node.name} with shape {shape}") else: - self.weight_nodes.add(node) + weight_nodes.add(node) print(f"Found weight tensor: {node.name} with shape {shape}") # Copy all placeholder nodes to the new graph - self.base_graph.node_copy(node, arg_transform=lambda n: n) + base_graph.node_copy(node, arg_transform=lambda n: n) + self.base_graph = base_graph + self.splittable_inputs: List[torch.fx.Node] = splittable_inputs + self.weight_nodes: Set[torch.fx.Node] = weight_nodes + + # Nano split preparation + self.new_graph: Optional[torch.fx.Graph] = None + self.input_splits = {} + self.node_splits = {} + self.op_wrapper = NanoOpWrapper(self.graph_module, []) + + # Runtime preparation + self.cached_config: Optional[NanoBatchSplitConfig] = None + self.comm_stream = torch.cuda.Stream() + self.comp_stream = torch.cuda.Stream() + + + def get_callable(self) -> Callable: + def _forward(*args, **kwargs): + assert self.op_wrapper is not None + setattr(self.graph_module, "op_wrapper", self.op_wrapper) + output = self.graph_module(*args, **kwargs) + delattr(self.graph_module, "op_wrapper") + return output + return 
_forward def _init_input_splits(self, split_indices: List[int]) -> None: num_splits = len(split_indices) - 1 @@ -91,14 +129,16 @@ def _replicate_computations(self, split_indices: List[int]) -> None: orig_vals = list(node.args) + list(node.kwargs.values()) new_vals = list(new_args) + list(new_kwargs.values()) orig_to_new = {o: n for o, n in zip(orig_vals, new_vals)} - # Call pre_op_hook with proper arguments - self.new_graph.call_module( - "pre_op_hook", - args=(node.name, split_idx, new_args, new_kwargs), - ) - new_node = self.new_graph.node_copy( - node, arg_transform=lambda n: orig_to_new[n] - ) + + if node.op == "call_module": + new_node = self.new_graph.call_module( + "op_wrapper", + args=(str(node.target), split_idx, new_args, new_kwargs), + ) + else: + new_node = self.new_graph.node_copy( + node, arg_transform=lambda n: orig_to_new[n] + ) splits.append(new_node) self.node_splits[node] = splits @@ -191,7 +231,7 @@ def _get_split_kwargs(self, kwargs: Dict, split_idx: int) -> Dict: return new_kwargs - def auto_search_and_split( + def prepare_split( self, input_info: InputInfo, ) -> list[int]: @@ -200,8 +240,8 @@ def auto_search_and_split( batch_sizes = [1] split_indices = [0, input_info.num_tokens[0]] else: - batch_sizes = [1, total_batch_size - 1] - split_indices = [0, input_info.num_tokens[0], sum(input_info.num_tokens)] + batch_sizes = [total_batch_size // 2, total_batch_size - total_batch_size // 2] + split_indices = [0, sum(input_info.num_tokens[:total_batch_size // 2]), sum(input_info.num_tokens)] assert self.base_graph is not None self.new_graph = copy.deepcopy(self.base_graph) self._init_input_splits(split_indices) @@ -209,16 +249,50 @@ def auto_search_and_split( self._handle_outputs() assert self.graph_module is not None self.graph_module.graph = self.new_graph - print(self.graph_module.code) - setattr(self.graph_module, "cached_config", NanoBatchSplitConfig(split_indices, batch_sizes)) + self.cached_config = NanoBatchSplitConfig(split_indices, batch_sizes) + from torch._dynamo.utils import lazy_format_graph_code # type: ignore + print(lazy_format_graph_code("after nano split", self.graph_module)) return batch_sizes - - def init_callable(self, graph_module: torch.fx.GraphModule) -> Callable: - self.base_graph = torch.fx.Graph() - self.graph_module = graph_module - self.original_graph = graph_module.graph - self._init_placeholders() - return self.graph_module + + def prepare_runtime( + self, + *, + forward_hook: Optional[Callable] = None, + op_hook: Optional[Callable[[NanoOpInfo], ContextManager[None]]] = None, + ) -> None: + assert self.cached_config is not None + batch_sizes = self.cached_config.batch_sizes + comm_finished = [None for _ in range(len(batch_sizes))] + + @contextlib.contextmanager + def set_stream(op_info: NanoOpInfo): + if op_info.tag == "vllm.all_reduce": + torch.cuda.set_stream(self.comm_stream) # type: ignore + comm_finished[op_info.idx] = torch.cuda.Event() # type: ignore + else: + torch.cuda.set_stream(self.comp_stream) # type: ignore + if comm_finished[op_info.idx] is not None: + comm_finished[op_info.idx].wait() # type: ignore + comm_finished[op_info.idx] = None + try: + yield + finally: + if op_info.tag == "vllm.all_reduce": + comm_finished[op_info.idx].record() # type: ignore + + @contextlib.contextmanager + def nvtx_mark(op_info: NanoOpInfo): + try: + with torch.cuda.nvtx.range(f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}"): + yield + except Exception as e: + print(f"Error in nvtx_mark: {e}") + raise e + + hooks = [] if op_hook is None else 
[op_hook] + self.op_wrapper = NanoOpWrapper(self.graph_module, [set_stream, nvtx_mark] + hooks) + if forward_hook is not None: + self.graph_module.register_forward_hook(forward_hook) _split_manager = None @@ -227,23 +301,27 @@ def init_callable(self, graph_module: torch.fx.GraphModule) -> Callable: def init_split_manager_and_get_callable(graph_module: torch.fx.GraphModule) -> Callable: global _split_manager if _split_manager is None: - _split_manager = NanoBatchSplit() - return _split_manager.init_callable(graph_module) + _split_manager = NanoSplitManager(graph_module) + return _split_manager.get_callable() + -def auto_search_and_split(input_info: InputInfo) -> list[int]: +def prepare_split(input_info: InputInfo) -> list[int]: global _split_manager if _split_manager is None: raise ValueError("Split manager not initialized") - return _split_manager.auto_search_and_split(input_info) + return _split_manager.prepare_split(input_info) -def set_forward_hook( - pre_forward_hook: Callable[[], None], - pre_op_hook: Callable[[str, int, Tuple, Dict], None] +def prepare_runtime( + *, + forward_hook: Optional[Callable] = None, + op_hook: Optional[Callable[[NanoOpInfo], ContextManager[None]]] = None, ) -> None: global _split_manager if _split_manager is None: raise ValueError("Split manager not initialized") - setattr(_split_manager.graph_module, "pre_forward_hook", HookWrapper(pre_forward_hook)) - setattr(_split_manager.graph_module, "pre_op_hook", HookWrapper(pre_op_hook)) + _split_manager.prepare_runtime( + forward_hook=forward_hook, + op_hook=op_hook, + ) diff --git a/vllm/config.py b/vllm/config.py index 44106dd279b6..e44030ee56a7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4343,6 +4343,7 @@ def set_splitting_ops_for_v1(self): "vllm.unified_attention", "vllm.unified_attention_with_output", "vllm.mamba_mixer2", + "vllm.all_reduce" ] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5baf7ad92681..00c0fdbcfffa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -18,7 +18,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter -from vllm.compilation.nano_split import InputInfo, auto_search_and_split, set_forward_hook +from vllm.compilation.nano_split import InputInfo, NanoOpInfo, prepare_split, prepare_runtime from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState @@ -810,7 +810,7 @@ def _split_inputs_for_nano_split( num_reqs = len(req_ids) num_tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] cached_seqlens = self.input_batch.num_computed_tokens_cpu[:num_reqs].tolist() - batch_sizes = auto_search_and_split( + batch_sizes = prepare_split( InputInfo( batch_size=num_reqs, num_tokens=num_tokens, @@ -851,8 +851,8 @@ def _split_inputs_for_nano_split( slot_mapping = blk_table.slot_mapping[:nano_batch_total_tokens] common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[start_req_idx:end_req_idx + 1], - query_start_loc_cpu=self.query_start_loc_cpu[start_req_idx:end_req_idx + 1], + query_start_loc=self.query_start_loc[start_req_idx:end_req_idx + 1] - self.query_start_loc[start_req_idx], + query_start_loc_cpu=self.query_start_loc_cpu[start_req_idx:end_req_idx + 1] - self.query_start_loc_cpu[start_req_idx], seq_lens=self.seq_lens[start_req_idx:end_req_idx], 
seq_lens_cpu=self.seq_lens_cpu[start_req_idx:end_req_idx], num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[start_req_idx:end_req_idx], @@ -899,11 +899,17 @@ def _split_inputs_for_nano_split( skip_cuda_graphs=True, ) for attn_metadata in attn_metadatas ] - def pre_forward_hook() -> None: - pass - def pre_op_hook(node_name: str, idx: int, args: tuple, kwargs: dict) -> None: - override_forward_context(forward_contexts[idx]) - set_forward_hook(pre_forward_hook, pre_op_hook) + + @contextmanager + def op_hook(op_info: NanoOpInfo): + previous_context = get_forward_context() + override_forward_context(forward_contexts[op_info.idx]) + try: + yield + finally: + override_forward_context(previous_context) + + prepare_runtime(forward_hook=None, op_hook=op_hook) def _compute_cascade_attn_prefix_len( self, From 8712a266c9cfc1ef28d91cd2fc2070edeee42c84 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Mon, 28 Jul 2025 00:27:20 -0700 Subject: [PATCH 03/22] update Signed-off-by: Yi Pan --- vllm/compilation/nano_split.py | 15 +++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 8 ++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/compilation/nano_split.py b/vllm/compilation/nano_split.py index dcb1223c6aa9..ae408ceb776d 100644 --- a/vllm/compilation/nano_split.py +++ b/vllm/compilation/nano_split.py @@ -2,6 +2,8 @@ import copy from dataclasses import dataclass import torch +from flashinfer import green_ctx +import nvmath from typing import Callable, ContextManager, Dict, List, Tuple, Optional, Set @dataclass @@ -89,6 +91,12 @@ def __init__(self, graph_module: torch.fx.GraphModule) -> None: self.cached_config: Optional[NanoBatchSplitConfig] = None self.comm_stream = torch.cuda.Stream() self.comp_stream = torch.cuda.Stream() + self.comp_stream, _ = green_ctx.split_device_green_ctx_by_sm_count( # type: ignore + dev=torch.device(f"cuda:{torch.cuda.current_device()}"), + sm_counts=[120] + )[0] + handle = torch.cuda.current_blas_handle() + nvmath.bindings.cublas.set_sm_count_target(handle, 120) # type: ignore def get_callable(self) -> Callable: @@ -263,14 +271,19 @@ def prepare_runtime( assert self.cached_config is not None batch_sizes = self.cached_config.batch_sizes comm_finished = [None for _ in range(len(batch_sizes))] + comp_finished = [None for _ in range(len(batch_sizes))] @contextlib.contextmanager def set_stream(op_info: NanoOpInfo): if op_info.tag == "vllm.all_reduce": torch.cuda.set_stream(self.comm_stream) # type: ignore comm_finished[op_info.idx] = torch.cuda.Event() # type: ignore + if comp_finished[op_info.idx] is not None: + comp_finished[op_info.idx].wait() # type: ignore + comp_finished[op_info.idx] = None else: torch.cuda.set_stream(self.comp_stream) # type: ignore + comp_finished[op_info.idx] = torch.cuda.Event() # type: ignore if comm_finished[op_info.idx] is not None: comm_finished[op_info.idx].wait() # type: ignore comm_finished[op_info.idx] = None @@ -279,6 +292,8 @@ def set_stream(op_info: NanoOpInfo): finally: if op_info.tag == "vllm.all_reduce": comm_finished[op_info.idx].record() # type: ignore + else: + comp_finished[op_info.idx].record() # type: ignore @contextlib.contextmanager def nvtx_mark(op_info: NanoOpInfo): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 00c0fdbcfffa..11350d921dd4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -796,15 +796,11 @@ def _split_inputs_for_nano_split( scheduler_output: "SchedulerOutput", ) -> None: """ - For each group 
of requests (as defined by split_indices), generate - separate attention metadata. + Split the input batch into nano batches and prepare + attention metadata for each nano batch. Args: scheduler_output: SchedulerOutput for the whole batch. - nano_batch_size: List of nano batch sizes. - - Returns: - List of attention_metadata dicts, one per group. """ req_ids = self.input_batch.req_ids num_reqs = len(req_ids) From c240236482ae9fb1161a6f50a3d2b33dd5f7eb06 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Fri, 1 Aug 2025 15:23:04 -0700 Subject: [PATCH 04/22] fix cpu overhead Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 19 +- vllm/compilation/nano_manager.py | 223 +++++++++++++++ vllm/compilation/nano_split.py | 457 +++++++++---------------------- vllm/compilation/nano_utils.py | 130 +++++++++ vllm/config.py | 3 +- 5 files changed, 490 insertions(+), 342 deletions(-) create mode 100644 vllm/compilation/nano_manager.py create mode 100644 vllm/compilation/nano_utils.py diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 175976c8bcad..3240b837f9f5 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -14,6 +14,7 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher +from vllm.compilation.nano_utils import tag_graph import vllm.envs as envs from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger @@ -234,7 +235,6 @@ def split_graph(graph: fx.GraphModule, # split graph by ops subgraph_id = 0 node_to_subgraph_id = {} - subgraph_to_tag = {} split_op_graphs = [] for node in graph.graph.nodes: if node.op in ("output", "placeholder"): @@ -243,7 +243,6 @@ def split_graph(graph: fx.GraphModule, subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) - subgraph_to_tag[subgraph_id] = str(node.target) subgraph_id += 1 else: node_to_subgraph_id[node] = subgraph_id @@ -270,7 +269,6 @@ def split_graph(graph: fx.GraphModule, module = getattr(split_gm, name) graph_id = int(name.replace("submod_", "")) - setattr(module, "tag", subgraph_to_tag.get(graph_id, "")) outputs.append( SplitItem(name, graph_id, (graph_id in split_op_graphs), module)) @@ -551,8 +549,12 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.graph = graph self.configure_post_pass() - self.split_gm, self.piecewise_graphs = split_graph( - graph, self.compilation_config.splitting_ops) + if self.vllm_config.model_config.enable_nano_split: + self.split_gm, self.piecewise_graphs = split_graph(graph, [ + "vllm.all_reduce" + ]) + else: + self.split_gm, self.piecewise_graphs = split_graph(graph, self.compilation_config.splitting_ops) from torch._dynamo.utils import lazy_format_graph_code @@ -590,8 +592,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: - from vllm.compilation.nano_split import init_split_manager_and_get_callable - return init_split_manager_and_get_callable(self.split_gm) + if self.vllm_config.model_config.enable_nano_split: + from vllm.compilation.nano_manager import init_split_manager_and_get_callable + return init_split_manager_and_get_callable(self.split_gm) + else: + return self.split_gm # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode diff --git a/vllm/compilation/nano_manager.py b/vllm/compilation/nano_manager.py new file mode 100644 index 000000000000..1c689feec509 --- /dev/null 
+++ b/vllm/compilation/nano_manager.py @@ -0,0 +1,223 @@ +import contextlib +import copy +import torch +from typing import Callable, ContextManager, List, Optional +from flashinfer.green_ctx import split_device_green_ctx_by_sm_count + +from vllm.compilation.nano_utils import ( + NanoOpInfo, + NanoSplitConfig, + FakeModule, + display_graph, + get_split_config, + tag_graph, +) +from vllm.compilation.nano_split import ( + analyze_graph, + concat_outputs, + split_computations, + split_inputs, +) + + +class NanoSplitManager: + def __init__( + self, graph_module: torch.fx.GraphModule, max_nano_splits: int = 2 + ) -> None: + self.graph_module = graph_module + self.original_graph = graph_module.graph + + # Nano split preparation + self.max_nano_splits = max_nano_splits + self.new_graphs = {1: self.original_graph} + + # Runtime preparation + self.cached_config: Optional[NanoSplitConfig] = None + self.comm_stream: Optional[torch.Stream] = None + self.comp_stream: Optional[torch.Stream] = None + self.comp_stream, self.comm_stream = split_device_green_ctx_by_sm_count( + dev=torch.device(f"cuda:{torch.cuda.current_device()}"), + sm_counts=[112] + )[0] + + # Initialize the base graph + tag_graph(self.graph_module, { + "vllm.unified_attention": "attention", + "vllm.unified_attention_with_output": "attention", + "vllm.all_reduce": "all_reduce", + }) + splittable_inputs, base_graph = analyze_graph(self.original_graph) + for num_splits in range(2, max_nano_splits + 1): + new_graph = copy.deepcopy(base_graph) + nano_batch_sizes = [] + for i in range(num_splits): + nano_batch_sizes.append( + new_graph.call_module( + "get_batch_size", + args=(i,), + kwargs={}, + ) + ) + mapping = split_inputs( + new_graph, splittable_inputs, "split_input", num_splits + ) + split_computations( + self.original_graph, + new_graph, + mapping, + nano_batch_sizes, + "op_wrapper", + num_splits, + ) + concat_outputs( + self.original_graph, + new_graph, + mapping, + ) + self.new_graphs[num_splits] = new_graph + print(new_graph) + self.graph_module.graph = new_graph + display_graph(self.graph_module, f"after nano split {num_splits}") + + @staticmethod + def get_batch_size(idx: int, cached_config: NanoSplitConfig): + return cached_config.num_tokens[idx] + + @staticmethod + def split_input(x: torch.Tensor, idx: int, cached_config: NanoSplitConfig): + return x[ + cached_config.split_indices[idx] : cached_config.split_indices[idx + 1] + ] + + @staticmethod + def op_wrapper( + submod_name: str, + idx: int, + args: tuple, + kwargs: dict, + gm: torch.fx.GraphModule, + hooks: List[Callable[[NanoOpInfo], ContextManager[None]]], + ): + module = getattr(gm, submod_name) + tag = getattr(module, "tag", "") + with contextlib.ExitStack() as stack: + for hook in hooks: + stack.enter_context( + hook(NanoOpInfo(submod_name, tag, idx, args, kwargs)) + ) + output = module(*args, **kwargs) + return output + + def get_callable(self) -> Callable: + def _forward(*args, **kwargs): + if self.cached_config is None: + self.graph_module.graph = self.original_graph + return self.graph_module(*args, **kwargs) + + num_nano_batches = self.cached_config.num_nano_batches + # NOTE(yi): This can be time consuming + if self.graph_module.graph != self.new_graphs[num_nano_batches]: + self.graph_module.graph = self.new_graphs[num_nano_batches] + comm_finished = [None for _ in range(num_nano_batches)] + comp_finished = [None for _ in range(num_nano_batches)] + + @contextlib.contextmanager + def set_stream(op_info: NanoOpInfo): + if op_info.tag == "all_reduce": + 
torch.cuda.set_stream(self.comm_stream) # type: ignore + comm_finished[op_info.idx] = torch.cuda.Event() # type: ignore + if comp_finished[op_info.idx] is not None: + comp_finished[op_info.idx].wait() # type: ignore + comp_finished[op_info.idx] = None + else: + torch.cuda.set_stream(self.comp_stream) # type: ignore + comp_finished[op_info.idx] = torch.cuda.Event() # type: ignore + if comm_finished[op_info.idx] is not None: + comm_finished[op_info.idx].wait() # type: ignore + comm_finished[op_info.idx] = None + try: + yield + finally: + if op_info.tag == "all_reduce": + comm_finished[op_info.idx].record() # type: ignore + else: + comp_finished[op_info.idx].record() # type: ignore + + @contextlib.contextmanager + def nvtx_mark(op_info: NanoOpInfo): + try: + with torch.cuda.nvtx.range( + f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}" + ): + yield + except Exception as e: + print(f"Error in nvtx_mark: {e}") + raise e + + # Register fake modules + assert self.hook is not None + op_wrapper = FakeModule( + NanoSplitManager.op_wrapper, + gm=self.graph_module, + hooks=[ + set_stream, + nvtx_mark, + self.hook, + ], + ) + setattr(self.graph_module, "op_wrapper", op_wrapper) + get_batch_size = FakeModule( + NanoSplitManager.get_batch_size, + cached_config=self.cached_config, + ) + setattr(self.graph_module, "get_batch_size", get_batch_size) + split_input = FakeModule( + NanoSplitManager.split_input, + cached_config=self.cached_config, + ) + setattr(self.graph_module, "split_input", split_input) + output = self.graph_module(*args, **kwargs) + delattr(self.graph_module, "op_wrapper") + delattr(self.graph_module, "get_batch_size") + delattr(self.graph_module, "split_input") + return output + + return _forward + + def prepare( + self, + batch_size: int, + num_tokens: List[int], + cached_seqlens: List[int], + get_hooks_fn: Callable[..., Callable[[NanoOpInfo], ContextManager[None]]], + *args, + **kwargs, + ) -> None: + self.cached_config = get_split_config(batch_size, num_tokens, cached_seqlens) + self.hook = get_hooks_fn(self.cached_config, *args, **kwargs) + + +_split_manager = None + + +def init_split_manager_and_get_callable(graph_module: torch.fx.GraphModule) -> Callable: + global _split_manager + if _split_manager is None: + _split_manager = NanoSplitManager(graph_module) + return _split_manager.get_callable() + + +def prepare_nano_split( + batch_size: int, + num_tokens: List[int], + cached_seqlens: List[int], + get_hooks_fn: Callable[..., Callable[[NanoOpInfo], ContextManager[None]]], + *args, + **kwargs, +) -> None: + global _split_manager + if _split_manager is None: + raise ValueError("Split manager not initialized") + _split_manager.prepare( + batch_size, num_tokens, cached_seqlens, get_hooks_fn, *args, **kwargs + ) diff --git a/vllm/compilation/nano_split.py b/vllm/compilation/nano_split.py index ae408ceb776d..07c13e464c35 100644 --- a/vllm/compilation/nano_split.py +++ b/vllm/compilation/nano_split.py @@ -1,342 +1,131 @@ -import contextlib -import copy -from dataclasses import dataclass import torch -from flashinfer import green_ctx -import nvmath -from typing import Callable, ContextManager, Dict, List, Tuple, Optional, Set - -@dataclass -class InputInfo: - batch_size: int - num_tokens: list[int] - cached_seqlens: list[int] - - -@dataclass -class NanoBatchSplitConfig: - split_indices: List[int] - batch_sizes: List[int] - - -@dataclass -class NanoOpInfo: - gm: torch.fx.GraphModule - submod_name: str - tag: str - idx: int - args: tuple - kwargs: dict - - -class 
NanoOpWrapper(torch.nn.Module): - def __init__(self, gm: torch.fx.GraphModule, hook: List[Callable[[NanoOpInfo], ContextManager[None]]]): - super().__init__() - self.gm = gm - self.hook = hook - - def forward(self, submod_name: str, idx: int, args: tuple, kwargs: dict): - module = getattr(self.gm, submod_name) - tag = getattr(module, "tag", "") - with contextlib.ExitStack() as stack: - for hook in self.hook: - stack.enter_context(hook(NanoOpInfo(self.gm, submod_name, tag, idx, args, kwargs))) - output = module(*args, **kwargs) - return output - - -class NanoSplitManager: - def __init__(self, graph_module: torch.fx.GraphModule) -> None: - self.graph_module = graph_module - self.original_graph = graph_module.graph - self.base_graph = torch.fx.Graph() - - # Initialize the base graph - batch_size: Optional[torch.SymInt] = None - weight_nodes = set() - splittable_inputs = [] - base_graph = torch.fx.Graph() - for node in self.original_graph.nodes: - # Skip computation nodes - if node.op != "placeholder": - continue - - # We assume the batch size is the first argument - if batch_size is None: - arg = node.meta["example_value"] - if not isinstance(arg, torch.SymInt): - raise ValueError("Batch size is not set") - batch_size = arg - elif isinstance(input_tensor := node.meta["example_value"], torch.Tensor): - shape = input_tensor.shape - if shape[0] == batch_size: - splittable_inputs.append(node) - print(f"Found splittable input: {node.name} with shape {shape}") - else: - weight_nodes.add(node) - print(f"Found weight tensor: {node.name} with shape {shape}") - # Copy all placeholder nodes to the new graph - base_graph.node_copy(node, arg_transform=lambda n: n) - self.base_graph = base_graph - self.splittable_inputs: List[torch.fx.Node] = splittable_inputs - self.weight_nodes: Set[torch.fx.Node] = weight_nodes - - # Nano split preparation - self.new_graph: Optional[torch.fx.Graph] = None - self.input_splits = {} - self.node_splits = {} - self.op_wrapper = NanoOpWrapper(self.graph_module, []) - - # Runtime preparation - self.cached_config: Optional[NanoBatchSplitConfig] = None - self.comm_stream = torch.cuda.Stream() - self.comp_stream = torch.cuda.Stream() - self.comp_stream, _ = green_ctx.split_device_green_ctx_by_sm_count( # type: ignore - dev=torch.device(f"cuda:{torch.cuda.current_device()}"), - sm_counts=[120] - )[0] - handle = torch.cuda.current_blas_handle() - nvmath.bindings.cublas.set_sm_count_target(handle, 120) # type: ignore - - - def get_callable(self) -> Callable: - def _forward(*args, **kwargs): - assert self.op_wrapper is not None - setattr(self.graph_module, "op_wrapper", self.op_wrapper) - output = self.graph_module(*args, **kwargs) - delattr(self.graph_module, "op_wrapper") - return output - return _forward - - def _init_input_splits(self, split_indices: List[int]) -> None: - num_splits = len(split_indices) - 1 - assert self.new_graph is not None - for node in self.splittable_inputs: - self.input_splits[node] = [] - for i in range(num_splits): - start_idx = split_indices[i] - end_idx = split_indices[i + 1] - slice_node = self.new_graph.call_function( - lambda x, start, end: x[start:end], - args=(node, start_idx, end_idx), - ) - self.input_splits[node].append(slice_node) - - def _replicate_computations(self, split_indices: List[int]) -> None: - num_splits = len(split_indices) - 1 - assert self.new_graph is not None - print(f"Replicating computations for {num_splits} splits") - for node in self.original_graph.nodes: - if node.op in ["placeholder", "output"]: - continue - 
print(f"Processing node: {node.name}, op: {node.op}, args: {len(node.args) if node.args else 0}") - splits = [] - for split_idx in range(num_splits): - new_args = self._get_split_args(node.args, split_idx, split_indices) - new_kwargs = self._get_split_kwargs(node.kwargs, split_idx) - orig_vals = list(node.args) + list(node.kwargs.values()) - new_vals = list(new_args) + list(new_kwargs.values()) - orig_to_new = {o: n for o, n in zip(orig_vals, new_vals)} - - if node.op == "call_module": - new_node = self.new_graph.call_module( - "op_wrapper", - args=(str(node.target), split_idx, new_args, new_kwargs), - ) - else: - new_node = self.new_graph.node_copy( - node, arg_transform=lambda n: orig_to_new[n] - ) - splits.append(new_node) - - self.node_splits[node] = splits - print( - f"Replicated computation node {node.name} into {num_splits} parts" - ) - - def _handle_outputs(self) -> None: - """Handle output nodes by concatenating split outputs and cleaning up original computations.""" - assert self.new_graph is not None - output_nodes = [ - node for node in self.original_graph.nodes if node.op == "output" - ] - assert len(output_nodes) == 1, f"Expected 1 output node, found {len(output_nodes)}" - output_node = output_nodes[0] - - # Find the original computation that feeds into this output - if not output_node.args: - raise ValueError("Output node has no arguments") - original_outputs = output_node.args[0] - is_tuple = isinstance(original_outputs, tuple) - if not isinstance(original_outputs, tuple): - original_outputs = (original_outputs,) - new_outputs = [] - - for original_output in original_outputs: - if original_output in self.node_splits: - # Get all split outputs - split_outputs = self.node_splits[original_output] - - # Create concatenation node - if len(split_outputs) == 1: - # If there's only one split, no need to concatenate - concat_node = split_outputs[0] - else: - # Create concatenation node - concat_node = self.new_graph.call_function( - torch.cat, - args=(split_outputs, 0), # Concatenate along first dimension - ) - - new_outputs.append(concat_node) - print(f"Concatenated {len(split_outputs)} output splits") +from typing import Any, Dict, List, Set, Tuple, Union + + +def analyze_graph( + graph: torch.fx.Graph, batch_size: Union[int, torch.SymInt, None] = None +) -> Tuple[List[torch.fx.Node], torch.fx.Graph]: + weight_nodes = set() + splittable_inputs = [] + base_graph = torch.fx.Graph() + for node in graph.nodes: + # Skip computation nodes + if node.op != "placeholder": + continue + + # We assume the batch size is the first argument + if batch_size is None: + arg = node.meta["example_value"] + if not isinstance(arg, torch.SymInt): + raise ValueError("Batch size is not set") + batch_size = arg + elif isinstance(input_tensor := node.meta["example_value"], torch.Tensor): + shape = input_tensor.shape + if shape[0] == batch_size: + splittable_inputs.append(node) + print(f"Found splittable input: {node.name} with shape {shape}") else: - raise ValueError( - f"Original output {original_output} not found in node_splits" + weight_nodes.add(node) + print(f"Found weight tensor: {node.name} with shape {shape}") + # Copy all placeholder nodes to the new graph + base_graph.node_copy(node, arg_transform=lambda n: n) + return splittable_inputs, base_graph + + +def split_inputs( + graph: torch.fx.Graph, + splittable_inputs: List[torch.fx.Node], + split_module: str, + num_splits: int, +) -> Dict[torch.fx.Node, List[torch.fx.Node]]: + mapping = {} + for node in splittable_inputs: + mapping[node] = [] + for 
i in range(num_splits): + slice_node = graph.call_module( + split_module, + args=(node, i), + ) + mapping[node].append(slice_node) + return mapping + + +def split_computations( + org_graph: torch.fx.Graph, + new_graph: torch.fx.Graph, + mapping: Dict[torch.fx.Node, List[torch.fx.Node]], + nano_batch_sizes: List[torch.fx.Node], + wrapper_module: str, + num_splits: int, +): + def _transform(idx, n) -> torch.fx.Node: + if n in mapping: + return mapping[n][idx] + if isinstance(getattr(n, "meta", {}).get("example_value", None), torch.SymInt): + return nano_batch_sizes[idx] + return n + + for node in org_graph.nodes: + if node.op in ["placeholder", "output"]: + continue + splits = [] + for split_idx in range(num_splits): + if node.op == "call_module": + new_args = [_transform(split_idx, arg) for arg in node.args] + new_kwargs = { + k: _transform(split_idx, v) for k, v in node.kwargs.items() + } + new_node = new_graph.call_module( + wrapper_module, + args=(str(node.target), split_idx, new_args, new_kwargs), ) - - self.new_graph.output(tuple(new_outputs) if is_tuple else new_outputs[0]) - - def _get_split_args(self, args: Tuple, split_idx: int, split_indices: List[int]) -> Tuple: - """Get arguments for a specific split.""" - new_args = [] - - for arg in args: - if isinstance(arg, torch.fx.Node): - if arg in self.input_splits: - new_args.append(self.input_splits[arg][split_idx]) - elif arg in self.node_splits: - new_args.append(self.node_splits[arg][split_idx]) - elif arg in self.weight_nodes: - # Weight tensors are shared across splits - new_args.append(arg) - elif isinstance(arg.meta["example_value"], torch.SymInt): - new_args.append(split_indices[split_idx + 1] - split_indices[split_idx]) - else: - new_args.append(arg) else: - new_args.append(arg) - - return tuple(new_args) - - def _get_split_kwargs(self, kwargs: Dict, split_idx: int) -> Dict: - """Get keyword arguments for a specific split.""" - new_kwargs = {} - - for key, value in kwargs.items(): - if isinstance(value, torch.fx.Node): - if value in self.input_splits: - new_kwargs[key] = self.input_splits[value][split_idx] - elif value in self.node_splits: - new_kwargs[key] = self.node_splits[value][split_idx] - elif value in self.weight_nodes: - # Weight tensors are shared across splits - new_kwargs[key] = value - else: - new_kwargs[key] = value + new_node = new_graph.node_copy( + node, arg_transform=lambda n: _transform(split_idx, n) + ) + splits.append(new_node) + mapping[node] = splits + return mapping + + +def concat_outputs( + org_graph: torch.fx.Graph, + new_graph: torch.fx.Graph, + mapping: Dict[torch.fx.Node, List[torch.fx.Node]], +): + output_nodes = [node for node in org_graph.nodes if node.op == "output"] + assert len(output_nodes) == 1, f"Expected 1 output node, found {len(output_nodes)}" + output_node = output_nodes[0] + + if not output_node.args: + raise ValueError("Output node has no arguments") + original_outputs = output_node.args[0] + is_tuple = isinstance(original_outputs, tuple) + if not isinstance(original_outputs, tuple): + original_outputs = (original_outputs,) + new_outputs = [] + + for original_output in original_outputs: + if original_output in mapping: + # Get all split outputs + split_outputs = mapping[original_output] + + # Create concatenation node + if len(split_outputs) == 1: + # If there's only one split, no need to concatenate + concat_node = split_outputs[0] else: - new_kwargs[key] = value - - return new_kwargs + # Create concatenation node + concat_node = new_graph.call_function( + torch.cat, + 
args=(split_outputs, 0), # Concatenate along first dimension + ) - def prepare_split( - self, - input_info: InputInfo, - ) -> list[int]: - total_batch_size = input_info.batch_size - if total_batch_size == 1: - batch_sizes = [1] - split_indices = [0, input_info.num_tokens[0]] + new_outputs.append(concat_node) + print(f"Concatenated {len(split_outputs)} output splits") else: - batch_sizes = [total_batch_size // 2, total_batch_size - total_batch_size // 2] - split_indices = [0, sum(input_info.num_tokens[:total_batch_size // 2]), sum(input_info.num_tokens)] - assert self.base_graph is not None - self.new_graph = copy.deepcopy(self.base_graph) - self._init_input_splits(split_indices) - self._replicate_computations(split_indices) - self._handle_outputs() - assert self.graph_module is not None - self.graph_module.graph = self.new_graph - self.cached_config = NanoBatchSplitConfig(split_indices, batch_sizes) - from torch._dynamo.utils import lazy_format_graph_code # type: ignore - print(lazy_format_graph_code("after nano split", self.graph_module)) - return batch_sizes - - def prepare_runtime( - self, - *, - forward_hook: Optional[Callable] = None, - op_hook: Optional[Callable[[NanoOpInfo], ContextManager[None]]] = None, - ) -> None: - assert self.cached_config is not None - batch_sizes = self.cached_config.batch_sizes - comm_finished = [None for _ in range(len(batch_sizes))] - comp_finished = [None for _ in range(len(batch_sizes))] - - @contextlib.contextmanager - def set_stream(op_info: NanoOpInfo): - if op_info.tag == "vllm.all_reduce": - torch.cuda.set_stream(self.comm_stream) # type: ignore - comm_finished[op_info.idx] = torch.cuda.Event() # type: ignore - if comp_finished[op_info.idx] is not None: - comp_finished[op_info.idx].wait() # type: ignore - comp_finished[op_info.idx] = None - else: - torch.cuda.set_stream(self.comp_stream) # type: ignore - comp_finished[op_info.idx] = torch.cuda.Event() # type: ignore - if comm_finished[op_info.idx] is not None: - comm_finished[op_info.idx].wait() # type: ignore - comm_finished[op_info.idx] = None - try: - yield - finally: - if op_info.tag == "vllm.all_reduce": - comm_finished[op_info.idx].record() # type: ignore - else: - comp_finished[op_info.idx].record() # type: ignore - - @contextlib.contextmanager - def nvtx_mark(op_info: NanoOpInfo): - try: - with torch.cuda.nvtx.range(f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}"): - yield - except Exception as e: - print(f"Error in nvtx_mark: {e}") - raise e - - hooks = [] if op_hook is None else [op_hook] - self.op_wrapper = NanoOpWrapper(self.graph_module, [set_stream, nvtx_mark] + hooks) - if forward_hook is not None: - self.graph_module.register_forward_hook(forward_hook) - - -_split_manager = None - - -def init_split_manager_and_get_callable(graph_module: torch.fx.GraphModule) -> Callable: - global _split_manager - if _split_manager is None: - _split_manager = NanoSplitManager(graph_module) - return _split_manager.get_callable() - - - -def prepare_split(input_info: InputInfo) -> list[int]: - global _split_manager - if _split_manager is None: - raise ValueError("Split manager not initialized") - return _split_manager.prepare_split(input_info) - + raise ValueError( + f"Original output {original_output} not found in node_splits" + ) -def prepare_runtime( - *, - forward_hook: Optional[Callable] = None, - op_hook: Optional[Callable[[NanoOpInfo], ContextManager[None]]] = None, -) -> None: - global _split_manager - if _split_manager is None: - raise ValueError("Split manager not initialized") - 
_split_manager.prepare_runtime( - forward_hook=forward_hook, - op_hook=op_hook, - ) + new_graph.output(tuple(new_outputs) if is_tuple else new_outputs[0]) diff --git a/vllm/compilation/nano_utils.py b/vllm/compilation/nano_utils.py new file mode 100644 index 000000000000..a9a1283468ec --- /dev/null +++ b/vllm/compilation/nano_utils.py @@ -0,0 +1,130 @@ +import dataclasses +from contextlib import contextmanager +from typing import Callable, ContextManager, List + +import torch + +@dataclasses.dataclass +class NanoOpInfo: + submod_name: str + tag: str + idx: int + args: tuple + kwargs: dict + + +@dataclasses.dataclass +class NanoSplitConfig: + num_nano_batches: int + # Request level information + batch_sizes: List[int] + batch_indices: List[int] + # Token level information + num_tokens: List[int] + split_indices: List[int] # start/end indices of each nano batch + + +class FakeModule(torch.nn.Module): + def __init__(self, fn: Callable, **kwargs): + super().__init__() + self.fn = fn + self.kwargs = kwargs + + def forward(self, *args, **kwargs): + return self.fn(*args, **self.kwargs, **kwargs) + + +def get_split_config( + batch_size: int, + num_tokens: List[int], + cached_seqlens: List[int], +) -> NanoSplitConfig: + if batch_size == 1: + nano_batch_sizes = [1] + nano_batch_indices = [0, 1] + nano_batch_num_tokens = num_tokens.copy() + nano_batch_split_indices = [0, num_tokens[0]] + else: + nano_batch_sizes = [batch_size // 2, batch_size - batch_size // 2] + nano_batch_indices = [0, batch_size // 2, batch_size] + nano_batch_num_tokens = [ + sum(num_tokens[: batch_size // 2]), + sum(num_tokens[batch_size // 2 :]), + ] + nano_batch_split_indices = [ + 0, + nano_batch_num_tokens[0], + sum(nano_batch_num_tokens), + ] + + return NanoSplitConfig( + num_nano_batches=len(nano_batch_sizes), + batch_sizes=nano_batch_sizes, + batch_indices=nano_batch_indices, + num_tokens=nano_batch_num_tokens, + split_indices=nano_batch_split_indices, + ) + +def display_graph(graph_module: torch.fx.GraphModule, name: str): + from torch._dynamo.utils import lazy_format_graph_code # type: ignore + print(lazy_format_graph_code(name, graph_module)) + + +def split_graph_with_tags( + graph: torch.fx.GraphModule, + split_ops: list[str], + op_tags: dict[str, str], +) -> tuple[torch.fx.GraphModule]: + # split graph by ops + subgraph_id = 0 + node_to_subgraph_id = {} + subgraph_to_tag = {} + split_op_graphs = [] + for node in graph.graph.nodes: + if node.op in ("output", "placeholder"): + continue + if node.op == 'call_function' and str(node.target) in split_ops: + subgraph_id += 1 + node_to_subgraph_id[node] = subgraph_id + split_op_graphs.append(subgraph_id) + subgraph_id += 1 + else: + node_to_subgraph_id[node] = subgraph_id + if (tag := op_tags.get(node.op)) is not None: + assert subgraph_to_tag[subgraph_id] is None or subgraph_to_tag[subgraph_id] == tag, \ + f"tag mismatch: {subgraph_to_tag[subgraph_id]} != {tag}" + subgraph_to_tag[subgraph_id] = tag + + split_gm = torch.fx.passes.split_module.split_module( + graph, + None, + lambda node: node_to_subgraph_id[node], + keep_original_order=True) + + names = [name for (name, _) in split_gm.named_modules()] + for name in names: + if "." 
in name or name == "": + continue + module = getattr(split_gm, name) + graph_id = int(name.replace("submod_", "")) + setattr(module, "tag", subgraph_to_tag.get(graph_id, "")) + + return split_gm + +def tag_graph(gm: torch.fx.GraphModule, op_tags: dict[str, str]): + submodules = [ + (name, module) + for (name, module) in gm.named_modules() + if hasattr(module, "graph") + ] + for _, module in submodules: + for node in module.graph.nodes: + if ( + node.op == "call_function" + and (tag := op_tags.get(str(node.target))) is not None + ): + assert ( + getattr(module, "tag", None) is None + or getattr(module, "tag") == tag + ), f"tag mismatch: {getattr(module, 'tag')} != {tag}" + setattr(module, "tag", tag) diff --git a/vllm/config.py b/vllm/config.py index e44030ee56a7..48a382f34fe5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -415,6 +415,8 @@ class ModelConfig: - "transformers" will use the Transformers model implementation.""" override_attention_dtype: Optional[str] = None """Override dtype for attention""" + enable_nano_split: bool = False + """Enable nano split for the model""" def compute_hash(self) -> str: """ @@ -4343,7 +4345,6 @@ def set_splitting_ops_for_v1(self): "vllm.unified_attention", "vllm.unified_attention_with_output", "vllm.mamba_mixer2", - "vllm.all_reduce" ] From 49269f285aa3c2b0ceb580ba1ab47c2972988e50 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Fri, 1 Aug 2025 16:24:22 -0700 Subject: [PATCH 05/22] update model runner Signed-off-by: Yi Pan --- vllm/v1/worker/gpu_model_runner.py | 128 +++-------------------------- 1 file changed, 12 insertions(+), 116 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 11350d921dd4..83929d078098 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,12 +13,12 @@ import torch.nn as nn from tqdm import tqdm +from vllm.compilation.nano_manager import prepare_nano_split import vllm.envs as envs from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter -from vllm.compilation.nano_split import InputInfo, NanoOpInfo, prepare_split, prepare_runtime from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState @@ -791,121 +791,16 @@ def _prepare_inputs( spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata) - def _split_inputs_for_nano_split( - self, - scheduler_output: "SchedulerOutput", - ) -> None: - """ - Split the input batch into nano batches and prepare - attention metadata for each nano batch. - - Args: - scheduler_output: SchedulerOutput for the whole batch. 
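# --- illustrative sketch, not part of the patch ---
# What tag_graph above leaves behind: each piecewise submodule containing a tagged
# op ends up with a `tag` attribute ("attention" or "all_reduce"), and the runtime
# later keys its stream choice off that tag. A small inspection helper
# (split_gm is assumed to be the already-split GraphModule):
import torch

def describe_tags(split_gm: torch.fx.GraphModule) -> None:
    for name, submod in split_gm.named_modules():
        if "." in name or name == "":
            continue
        tag = getattr(submod, "tag", "")
        stream = "comm" if tag == "all_reduce" else "comp"
        print(f"{name}: tag={tag!r} -> {stream} stream")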
- """ - req_ids = self.input_batch.req_ids - num_reqs = len(req_ids) - num_tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] - cached_seqlens = self.input_batch.num_computed_tokens_cpu[:num_reqs].tolist() - batch_sizes = prepare_split( - InputInfo( - batch_size=num_reqs, - num_tokens=num_tokens, - cached_seqlens=cached_seqlens, - ) + def _prepare_nano_split(self, scheduler_output: "SchedulerOutput"): + from vllm.utils.nanoinfer import get_context_hooks + prepare_nano_split( + batch_size=len(self.input_batch.req_ids), + num_tokens=[scheduler_output.num_scheduled_tokens[rid] for rid in self.input_batch.req_ids], + cached_seqlens=self.input_batch.num_computed_tokens_cpu[:len(self.input_batch.req_ids)].tolist(), + get_hooks_fn=get_context_hooks, + gpu_model_runner=self, + scheduler_output=scheduler_output, ) - attn_metadatas = [] - - start_req_idx = 0 - end_req_idx = 0 - for nano_batch_size in batch_sizes: - start_req_idx = end_req_idx - end_req_idx = start_req_idx + nano_batch_size - nano_batch_req_ids = req_ids[start_req_idx:end_req_idx] - - # Gather per-request info for this group - nano_batch_num_scheduled_tokens = np.array( - [scheduler_output.num_scheduled_tokens[rid] for rid in nano_batch_req_ids], - dtype=np.int32 - ) - nano_batch_cu_num_tokens, nano_batch_arange = self._get_cumsum_and_arange(nano_batch_num_scheduled_tokens) - nano_batch_total_tokens = int(nano_batch_cu_num_tokens[-1]) - nano_batch_req_indices = np.repeat(np.arange(len(nano_batch_req_ids)), nano_batch_num_scheduled_tokens) - - # Compute positions for this group - nano_batch_positions_np = np.empty(nano_batch_total_tokens, dtype=np.int64) - np.add( - self.input_batch.num_computed_tokens_cpu[start_req_idx:end_req_idx][nano_batch_req_indices], - nano_batch_arange, - out=nano_batch_positions_np - ) - - # Prepare attention metadata for each KV cache group - nano_batch_attn_metadata = {} - for kv_cache_group_id, kv_cache_group_spec in enumerate(self.kv_cache_config.kv_cache_groups): - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[start_req_idx:end_req_idx] - slot_mapping = blk_table.slot_mapping[:nano_batch_total_tokens] - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc[start_req_idx:end_req_idx + 1] - self.query_start_loc[start_req_idx], - query_start_loc_cpu=self.query_start_loc_cpu[start_req_idx:end_req_idx + 1] - self.query_start_loc_cpu[start_req_idx], - seq_lens=self.seq_lens[start_req_idx:end_req_idx], - seq_lens_cpu=self.seq_lens_cpu[start_req_idx:end_req_idx], - num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[start_req_idx:end_req_idx], - num_reqs=nano_batch_size, - num_actual_tokens=nano_batch_total_tokens, - max_query_len=int(max(nano_batch_num_scheduled_tokens)), - block_table_tensor=blk_table_tensor, - slot_mapping=slot_mapping, - ) - - if isinstance(kv_cache_group_spec.kv_cache_spec, ChunkedLocalAttentionSpec): - common_attn_metadata = make_local_attention_virtual_batches( - kv_cache_group_spec.kv_cache_spec.attention_chunk_size, - common_attn_metadata, self.cache_config.block_size) - - # Prepare for cascade attention if enabled & beneficial. 
- common_prefix_len = 0 - builder = self.attn_metadata_builders[kv_cache_group_id] - if self.cascade_attn_enabled: - common_prefix_len = self._compute_cascade_attn_prefix_len( - nano_batch_num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks[kv_cache_group_id], - kv_cache_group_spec.kv_cache_spec, - builder, - ) - - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - ) - - for layer_name in kv_cache_group_spec.layer_names: - nano_batch_attn_metadata[layer_name] = attn_metadata_i - - attn_metadatas.append(nano_batch_attn_metadata) - - assert end_req_idx == num_reqs, f"invalid nano batch size: {batch_sizes}" - forward_contexts = [ - ForwardContext( - no_compile_layers=self.vllm_config.compilation_config.static_forward_context, - virtual_engine=0, - attn_metadata=attn_metadata, - dp_metadata=None, - skip_cuda_graphs=True, - ) for attn_metadata in attn_metadatas - ] - - @contextmanager - def op_hook(op_info: NanoOpInfo): - previous_context = get_forward_context() - override_forward_context(forward_contexts[op_info.idx]) - try: - yield - finally: - override_forward_context(previous_context) - - prepare_runtime(forward_hook=None, op_hook=op_hook) def _compute_cascade_attn_prefix_len( self, @@ -1513,7 +1408,8 @@ def execute_model( # If attention doesn't support CUDA Graphs for this batch, but we # compiled with full CUDA graphs, we have to skip them entirely. skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs - self._split_inputs_for_nano_split(scheduler_output) + if self.vllm_config.model_config.enable_nano_split: + self._prepare_nano_split(scheduler_output) # Run the model. # Use persistent buffers for CUDA graphs. From 39b878be58b1a511532a965d269c758c3d665687 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Sun, 3 Aug 2025 14:44:21 -0700 Subject: [PATCH 06/22] refine interface Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 1 - vllm/compilation/nano_manager.py | 26 +++--- vllm/engine/arg_utils.py | 5 +- vllm/utils/nano_split.py | 137 +++++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 14 ++- 5 files changed, 161 insertions(+), 22 deletions(-) create mode 100644 vllm/utils/nano_split.py diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3240b837f9f5..da94597ab012 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -14,7 +14,6 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher -from vllm.compilation.nano_utils import tag_graph import vllm.envs as envs from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger diff --git a/vllm/compilation/nano_manager.py b/vllm/compilation/nano_manager.py index 1c689feec509..1ce7ae654dad 100644 --- a/vllm/compilation/nano_manager.py +++ b/vllm/compilation/nano_manager.py @@ -39,6 +39,7 @@ def __init__( dev=torch.device(f"cuda:{torch.cuda.current_device()}"), sm_counts=[112] )[0] + self.hook: Optional[Callable[[NanoOpInfo], ContextManager[None]]] = None # Initialize the base graph tag_graph(self.graph_module, { @@ -189,12 +190,12 @@ def prepare( batch_size: int, num_tokens: List[int], cached_seqlens: List[int], - get_hooks_fn: Callable[..., Callable[[NanoOpInfo], ContextManager[None]]], - *args, - **kwargs, - ) -> None: + ) -> NanoSplitConfig: self.cached_config = get_split_config(batch_size, num_tokens, cached_seqlens) - self.hook = get_hooks_fn(self.cached_config, *args, **kwargs) + return self.cached_config + + def 
set_hooks(self, op_hook: Callable[[NanoOpInfo], ContextManager[None]]): + self.hook = op_hook _split_manager = None @@ -211,13 +212,16 @@ def prepare_nano_split( batch_size: int, num_tokens: List[int], cached_seqlens: List[int], - get_hooks_fn: Callable[..., Callable[[NanoOpInfo], ContextManager[None]]], - *args, - **kwargs, -) -> None: +) -> NanoSplitConfig: global _split_manager if _split_manager is None: raise ValueError("Split manager not initialized") - _split_manager.prepare( - batch_size, num_tokens, cached_seqlens, get_hooks_fn, *args, **kwargs + return _split_manager.prepare( + batch_size, num_tokens, cached_seqlens ) + +def set_op_hook(op_hook: Callable[[NanoOpInfo], ContextManager[None]]): + global _split_manager + if _split_manager is None: + raise ValueError("Split manager not initialized") + _split_manager.set_hooks(op_hook) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 019ff033eda2..c972a0b20e0f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -424,6 +424,7 @@ class EngineArgs: get_field(ModelConfig, "override_generation_config") model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype + enable_nano_split: bool = ModelConfig.enable_nano_split calculate_kv_scales: bool = CacheConfig.calculate_kv_scales @@ -540,7 +541,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **model_kwargs["model_impl"]) model_group.add_argument("--override-attention-dtype", **model_kwargs["override_attention_dtype"]) - + model_group.add_argument("--enable-nano-split", + **model_kwargs["enable_nano_split"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) load_group = parser.add_argument_group( @@ -933,6 +935,7 @@ def create_model_config(self) -> ModelConfig: enable_sleep_mode=self.enable_sleep_mode, model_impl=self.model_impl, override_attention_dtype=self.override_attention_dtype, + enable_nano_split=self.enable_nano_split, ) def validate_tensorizer_args(self): diff --git a/vllm/utils/nano_split.py b/vllm/utils/nano_split.py new file mode 100644 index 000000000000..e20799807df8 --- /dev/null +++ b/vllm/utils/nano_split.py @@ -0,0 +1,137 @@ +from contextlib import contextmanager + +import numpy as np +from vllm.compilation.nano_utils import NanoOpInfo +from vllm.forward_context import ( + ForwardContext, + get_forward_context, + override_forward_context, +) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + make_local_attention_virtual_batches, +) +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec +from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.compilation import nano_manager + + +def prepare_nano_split_and_set_hooks( + gpu_model_runner: "GPUModelRunner", + scheduler_output: "SchedulerOutput", +) -> None: + input_batch = gpu_model_runner.input_batch + req_ids = input_batch.req_ids + batch_size = len(req_ids) + num_tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] + cached_seqlens = input_batch.num_computed_tokens_cpu[ + :batch_size + ].tolist() + split_config = nano_manager.prepare_nano_split(batch_size, num_tokens, cached_seqlens) + + attn_metadatas = [] + start_req_idx = 0 + end_req_idx = 0 + for nano_batch_size in split_config.batch_sizes: + start_req_idx = end_req_idx + end_req_idx = start_req_idx + nano_batch_size + nano_batch_req_ids = req_ids[start_req_idx:end_req_idx] + + # Gather per-request info for this 
group + nano_batch_num_scheduled_tokens = np.array( + [scheduler_output.num_scheduled_tokens[rid] for rid in nano_batch_req_ids], + dtype=np.int32, + ) + nano_batch_cu_num_tokens, nano_batch_arange = ( + gpu_model_runner._get_cumsum_and_arange(nano_batch_num_scheduled_tokens) + ) + nano_batch_total_tokens = int(nano_batch_cu_num_tokens[-1]) + nano_batch_req_indices = np.repeat( + np.arange(len(nano_batch_req_ids)), nano_batch_num_scheduled_tokens + ) + + # Compute positions for this group + nano_batch_positions_np = np.empty(nano_batch_total_tokens, dtype=np.int64) + np.add( + input_batch.num_computed_tokens_cpu[ + start_req_idx:end_req_idx + ][nano_batch_req_indices], + nano_batch_arange, + out=nano_batch_positions_np, + ) + + # Prepare attention metadata for each KV cache group + nano_batch_attn_metadata = {} + for kv_cache_group_id, kv_cache_group_spec in enumerate( + gpu_model_runner.kv_cache_config.kv_cache_groups + ): + blk_table = input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[start_req_idx:end_req_idx] + slot_mapping = blk_table.slot_mapping[:nano_batch_total_tokens] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=gpu_model_runner.query_start_loc[ + start_req_idx : end_req_idx + 1 + ] + - gpu_model_runner.query_start_loc[start_req_idx], + query_start_loc_cpu=gpu_model_runner.query_start_loc_cpu[ + start_req_idx : end_req_idx + 1 + ] + - gpu_model_runner.query_start_loc_cpu[start_req_idx], + seq_lens=gpu_model_runner.seq_lens[start_req_idx:end_req_idx], + seq_lens_cpu=gpu_model_runner.seq_lens_cpu[start_req_idx:end_req_idx], + num_computed_tokens_cpu=input_batch.num_computed_tokens_cpu_tensor[ + start_req_idx:end_req_idx + ], + num_reqs=nano_batch_size, + num_actual_tokens=nano_batch_total_tokens, + max_query_len=int(max(nano_batch_num_scheduled_tokens)), + block_table_tensor=blk_table_tensor, + slot_mapping=slot_mapping, + ) + + if isinstance(kv_cache_group_spec.kv_cache_spec, ChunkedLocalAttentionSpec): + common_attn_metadata = make_local_attention_virtual_batches( + kv_cache_group_spec.kv_cache_spec.attention_chunk_size, + common_attn_metadata, + gpu_model_runner.cache_config.block_size, + ) + + # NOTE(yi): does not support cascade attention + common_prefix_len = 0 + builder = gpu_model_runner.attn_metadata_builders[kv_cache_group_id] + attn_metadata_i = builder.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) + + for layer_name in kv_cache_group_spec.layer_names: + nano_batch_attn_metadata[layer_name] = attn_metadata_i + + attn_metadatas.append(nano_batch_attn_metadata) + + assert ( + end_req_idx == batch_size + ), f"invalid nano batch size: {split_config.batch_sizes}" + forward_contexts = [ + ForwardContext( + no_compile_layers=gpu_model_runner.vllm_config.compilation_config.static_forward_context, + virtual_engine=0, + attn_metadata=attn_metadata, + dp_metadata=None, + skip_cuda_graphs=True, + ) + for attn_metadata in attn_metadatas + ] + + @contextmanager + def op_hook(op_info: NanoOpInfo): + previous_context = get_forward_context() + override_forward_context(forward_contexts[op_info.idx]) + try: + yield + finally: + override_forward_context(previous_context) + + nano_manager.set_op_hook(op_hook) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 83929d078098..631afc776453 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,7 +13,6 @@ import torch.nn as nn from tqdm import tqdm 
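# --- illustrative example, not part of the patch ---
# How the per-nano-batch query_start_loc is re-based in the code above: the full
# batch keeps one cumulative offset tensor, and each nano-batch takes its slice
# shifted so that it starts at zero (values are made up for illustration).
num_scheduled_tokens = [5, 3, 8, 2]                # 4 requests
query_start_loc = [0, 5, 8, 16, 18]                # cumsum with a leading 0
start_req, end_req = 2, 4                          # second nano-batch: requests 2..3
nano_query_start_loc = [
    x - query_start_loc[start_req]
    for x in query_start_loc[start_req:end_req + 1]
]
assert nano_query_start_loc == [0, 8, 10]          # local offsets within the nano-batch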
-from vllm.compilation.nano_manager import prepare_nano_split import vllm.envs as envs from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import AttentionBackend @@ -28,7 +27,7 @@ from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, ForwardContext, get_forward_context, override_forward_context, +from vllm.forward_context import (DPMetadata, get_forward_context, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase @@ -221,7 +220,8 @@ def __init__( self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and self.vllm_config.compilation_config.use_cudagraph - and not self.model_config.enforce_eager) + and not self.model_config.enforce_eager + and not self.model_config.enable_nano_split) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -792,12 +792,8 @@ def _prepare_inputs( spec_decode_common_attn_metadata) def _prepare_nano_split(self, scheduler_output: "SchedulerOutput"): - from vllm.utils.nanoinfer import get_context_hooks - prepare_nano_split( - batch_size=len(self.input_batch.req_ids), - num_tokens=[scheduler_output.num_scheduled_tokens[rid] for rid in self.input_batch.req_ids], - cached_seqlens=self.input_batch.num_computed_tokens_cpu[:len(self.input_batch.req_ids)].tolist(), - get_hooks_fn=get_context_hooks, + from vllm.utils.nano_split import prepare_nano_split_and_set_hooks + prepare_nano_split_and_set_hooks( gpu_model_runner=self, scheduler_output=scheduler_output, ) From 4970a808c5521c1ec4053aeac57b684850b663a5 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Sun, 3 Aug 2025 15:14:46 -0700 Subject: [PATCH 07/22] refine Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 4 ++-- vllm/compilation/nano_manager.py | 11 +++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index da94597ab012..f95e9d9c8650 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -14,6 +14,7 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher +from vllm.compilation import nano_manager import vllm.envs as envs from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger @@ -592,8 +593,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: if self.vllm_config.model_config.enable_nano_split: - from vllm.compilation.nano_manager import init_split_manager_and_get_callable - return init_split_manager_and_get_callable(self.split_gm) + return nano_manager.get_callable(self.split_gm) else: return self.split_gm diff --git a/vllm/compilation/nano_manager.py b/vllm/compilation/nano_manager.py index 1ce7ae654dad..9cd10a9c7e33 100644 --- a/vllm/compilation/nano_manager.py +++ b/vllm/compilation/nano_manager.py @@ -2,7 +2,6 @@ import copy import torch from typing import Callable, ContextManager, List, Optional -from flashinfer.green_ctx import split_device_green_ctx_by_sm_count from vllm.compilation.nano_utils import ( NanoOpInfo, @@ -33,12 +32,8 @@ def __init__( # Runtime preparation self.cached_config: Optional[NanoSplitConfig] = None - self.comm_stream: 
Optional[torch.Stream] = None - self.comp_stream: Optional[torch.Stream] = None - self.comp_stream, self.comm_stream = split_device_green_ctx_by_sm_count( - dev=torch.device(f"cuda:{torch.cuda.current_device()}"), - sm_counts=[112] - )[0] + self.comm_stream = torch.cuda.Stream() + self.comp_stream = torch.cuda.Stream() self.hook: Optional[Callable[[NanoOpInfo], ContextManager[None]]] = None # Initialize the base graph @@ -201,7 +196,7 @@ def set_hooks(self, op_hook: Callable[[NanoOpInfo], ContextManager[None]]): _split_manager = None -def init_split_manager_and_get_callable(graph_module: torch.fx.GraphModule) -> Callable: +def get_callable(graph_module: torch.fx.GraphModule) -> Callable: global _split_manager if _split_manager is None: _split_manager = NanoSplitManager(graph_module) From d00c4af3ce6edcf2c5072a347a36290e0e30d809 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Mon, 4 Aug 2025 01:28:45 -0700 Subject: [PATCH 08/22] separate nanoflow logic Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 7 +- vllm/compilation/nano_utils.py | 130 --------------- .../{nano_manager.py => nanoflow/manager.py} | 116 ++++++-------- .../split_utils.py} | 151 ++++++++++++++---- vllm/config.py | 1 + vllm/utils/nano_split.py | 96 ++++++----- vllm/v1/worker/gpu_model_runner.py | 10 +- 7 files changed, 241 insertions(+), 270 deletions(-) delete mode 100644 vllm/compilation/nano_utils.py rename vllm/compilation/{nano_manager.py => nanoflow/manager.py} (67%) rename vllm/compilation/{nano_split.py => nanoflow/split_utils.py} (51%) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index f95e9d9c8650..1285d622f812 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -14,7 +14,7 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher -from vllm.compilation import nano_manager +from vllm.compilation.nanoflow import manager as nano_manager import vllm.envs as envs from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger @@ -593,7 +593,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: if self.vllm_config.model_config.enable_nano_split: - return nano_manager.get_callable(self.split_gm) + return nano_manager.get_callable( + self.split_gm, + compilation_config=self.compilation_config + ) else: return self.split_gm diff --git a/vllm/compilation/nano_utils.py b/vllm/compilation/nano_utils.py deleted file mode 100644 index a9a1283468ec..000000000000 --- a/vllm/compilation/nano_utils.py +++ /dev/null @@ -1,130 +0,0 @@ -import dataclasses -from contextlib import contextmanager -from typing import Callable, ContextManager, List - -import torch - -@dataclasses.dataclass -class NanoOpInfo: - submod_name: str - tag: str - idx: int - args: tuple - kwargs: dict - - -@dataclasses.dataclass -class NanoSplitConfig: - num_nano_batches: int - # Request level information - batch_sizes: List[int] - batch_indices: List[int] - # Token level information - num_tokens: List[int] - split_indices: List[int] # start/end indices of each nano batch - - -class FakeModule(torch.nn.Module): - def __init__(self, fn: Callable, **kwargs): - super().__init__() - self.fn = fn - self.kwargs = kwargs - - def forward(self, *args, **kwargs): - return self.fn(*args, **self.kwargs, **kwargs) - - -def get_split_config( - batch_size: int, - num_tokens: List[int], - cached_seqlens: List[int], -) -> NanoSplitConfig: - if 
batch_size == 1: - nano_batch_sizes = [1] - nano_batch_indices = [0, 1] - nano_batch_num_tokens = num_tokens.copy() - nano_batch_split_indices = [0, num_tokens[0]] - else: - nano_batch_sizes = [batch_size // 2, batch_size - batch_size // 2] - nano_batch_indices = [0, batch_size // 2, batch_size] - nano_batch_num_tokens = [ - sum(num_tokens[: batch_size // 2]), - sum(num_tokens[batch_size // 2 :]), - ] - nano_batch_split_indices = [ - 0, - nano_batch_num_tokens[0], - sum(nano_batch_num_tokens), - ] - - return NanoSplitConfig( - num_nano_batches=len(nano_batch_sizes), - batch_sizes=nano_batch_sizes, - batch_indices=nano_batch_indices, - num_tokens=nano_batch_num_tokens, - split_indices=nano_batch_split_indices, - ) - -def display_graph(graph_module: torch.fx.GraphModule, name: str): - from torch._dynamo.utils import lazy_format_graph_code # type: ignore - print(lazy_format_graph_code(name, graph_module)) - - -def split_graph_with_tags( - graph: torch.fx.GraphModule, - split_ops: list[str], - op_tags: dict[str, str], -) -> tuple[torch.fx.GraphModule]: - # split graph by ops - subgraph_id = 0 - node_to_subgraph_id = {} - subgraph_to_tag = {} - split_op_graphs = [] - for node in graph.graph.nodes: - if node.op in ("output", "placeholder"): - continue - if node.op == 'call_function' and str(node.target) in split_ops: - subgraph_id += 1 - node_to_subgraph_id[node] = subgraph_id - split_op_graphs.append(subgraph_id) - subgraph_id += 1 - else: - node_to_subgraph_id[node] = subgraph_id - if (tag := op_tags.get(node.op)) is not None: - assert subgraph_to_tag[subgraph_id] is None or subgraph_to_tag[subgraph_id] == tag, \ - f"tag mismatch: {subgraph_to_tag[subgraph_id]} != {tag}" - subgraph_to_tag[subgraph_id] = tag - - split_gm = torch.fx.passes.split_module.split_module( - graph, - None, - lambda node: node_to_subgraph_id[node], - keep_original_order=True) - - names = [name for (name, _) in split_gm.named_modules()] - for name in names: - if "." 
in name or name == "": - continue - module = getattr(split_gm, name) - graph_id = int(name.replace("submod_", "")) - setattr(module, "tag", subgraph_to_tag.get(graph_id, "")) - - return split_gm - -def tag_graph(gm: torch.fx.GraphModule, op_tags: dict[str, str]): - submodules = [ - (name, module) - for (name, module) in gm.named_modules() - if hasattr(module, "graph") - ] - for _, module in submodules: - for node in module.graph.nodes: - if ( - node.op == "call_function" - and (tag := op_tags.get(str(node.target))) is not None - ): - assert ( - getattr(module, "tag", None) is None - or getattr(module, "tag") == tag - ), f"tag mismatch: {getattr(module, 'tag')} != {tag}" - setattr(module, "tag", tag) diff --git a/vllm/compilation/nano_manager.py b/vllm/compilation/nanoflow/manager.py similarity index 67% rename from vllm/compilation/nano_manager.py rename to vllm/compilation/nanoflow/manager.py index 9cd10a9c7e33..2c6f9c227f4a 100644 --- a/vllm/compilation/nano_manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -3,7 +3,11 @@ import torch from typing import Callable, ContextManager, List, Optional -from vllm.compilation.nano_utils import ( +import torch.fx.graph_module + +from vllm.compilation.nanoflow.split_utils import ( + analyze_graph, + split_graph, NanoOpInfo, NanoSplitConfig, FakeModule, @@ -11,69 +15,61 @@ get_split_config, tag_graph, ) -from vllm.compilation.nano_split import ( - analyze_graph, - concat_outputs, - split_computations, - split_inputs, -) +from vllm.config import CompilationConfig class NanoSplitManager: def __init__( - self, graph_module: torch.fx.GraphModule, max_nano_splits: int = 2 + self, graph_module: torch.fx.GraphModule, compilation_config: CompilationConfig, ) -> None: - self.graph_module = graph_module + self.original_graph_module = graph_module self.original_graph = graph_module.graph # Nano split preparation - self.max_nano_splits = max_nano_splits - self.new_graphs = {1: self.original_graph} + # NOTE(yi): move this to compilation config + self.max_nano_splits = 2 + # Initialize the base graph + tag_graph( + self.original_graph_module, + { + "vllm.unified_attention": "attention", + "vllm.unified_attention_with_output": "attention", + "vllm.all_reduce": "all_reduce", + }, + ) + self.graph_modules = {1: self.original_graph_module} # Runtime preparation self.cached_config: Optional[NanoSplitConfig] = None self.comm_stream = torch.cuda.Stream() self.comp_stream = torch.cuda.Stream() self.hook: Optional[Callable[[NanoOpInfo], ContextManager[None]]] = None + self.get_bs_fn = "get_batch_size" + self.split_fn = "split_input" + self.wrapper_fn = "op_wrapper" + setattr(self.original_graph_module, self.get_bs_fn, None) + setattr(self.original_graph_module, self.split_fn, None) + setattr(self.original_graph_module, self.wrapper_fn, None) - # Initialize the base graph - tag_graph(self.graph_module, { - "vllm.unified_attention": "attention", - "vllm.unified_attention_with_output": "attention", - "vllm.all_reduce": "all_reduce", - }) splittable_inputs, base_graph = analyze_graph(self.original_graph) - for num_splits in range(2, max_nano_splits + 1): + for num_splits in range(2, self.max_nano_splits + 1): new_graph = copy.deepcopy(base_graph) - nano_batch_sizes = [] - for i in range(num_splits): - nano_batch_sizes.append( - new_graph.call_module( - "get_batch_size", - args=(i,), - kwargs={}, - ) - ) - mapping = split_inputs( - new_graph, splittable_inputs, "split_input", num_splits - ) - split_computations( - self.original_graph, - new_graph, - mapping, - 
nano_batch_sizes, - "op_wrapper", - num_splits, - ) - concat_outputs( + split_graph( self.original_graph, - new_graph, - mapping, + out=new_graph, + splittable_inputs=splittable_inputs, + num_splits=num_splits, + get_bs_fn=self.get_bs_fn, + split_fn=self.split_fn, + wrapper_fn=self.wrapper_fn, ) - self.new_graphs[num_splits] = new_graph - print(new_graph) - self.graph_module.graph = new_graph - display_graph(self.graph_module, f"after nano split {num_splits}") + new_graph_module = torch.fx.GraphModule(self.original_graph_module, new_graph) + for name, _ in self.original_graph_module.named_modules(): + if "." in name or name == "": + continue + torch.fx.graph_module._copy_attr(self.original_graph_module, new_graph_module, name) + self.graph_modules[num_splits] = new_graph_module + @staticmethod def get_batch_size(idx: int, cached_config: NanoSplitConfig): @@ -107,13 +103,9 @@ def op_wrapper( def get_callable(self) -> Callable: def _forward(*args, **kwargs): if self.cached_config is None: - self.graph_module.graph = self.original_graph - return self.graph_module(*args, **kwargs) + return self.original_graph_module(*args, **kwargs) num_nano_batches = self.cached_config.num_nano_batches - # NOTE(yi): This can be time consuming - if self.graph_module.graph != self.new_graphs[num_nano_batches]: - self.graph_module.graph = self.new_graphs[num_nano_batches] comm_finished = [None for _ in range(num_nano_batches)] comp_finished = [None for _ in range(num_nano_batches)] @@ -154,28 +146,25 @@ def nvtx_mark(op_info: NanoOpInfo): assert self.hook is not None op_wrapper = FakeModule( NanoSplitManager.op_wrapper, - gm=self.graph_module, + gm=self.graph_modules[num_nano_batches], hooks=[ set_stream, nvtx_mark, self.hook, ], ) - setattr(self.graph_module, "op_wrapper", op_wrapper) get_batch_size = FakeModule( NanoSplitManager.get_batch_size, cached_config=self.cached_config, ) - setattr(self.graph_module, "get_batch_size", get_batch_size) split_input = FakeModule( NanoSplitManager.split_input, cached_config=self.cached_config, ) - setattr(self.graph_module, "split_input", split_input) - output = self.graph_module(*args, **kwargs) - delattr(self.graph_module, "op_wrapper") - delattr(self.graph_module, "get_batch_size") - delattr(self.graph_module, "split_input") + setattr(self.graph_modules[num_nano_batches], self.wrapper_fn, op_wrapper) + setattr(self.graph_modules[num_nano_batches], self.get_bs_fn, get_batch_size) + setattr(self.graph_modules[num_nano_batches], self.split_fn, split_input) + output = self.graph_modules[num_nano_batches](*args, **kwargs) return output return _forward @@ -188,7 +177,7 @@ def prepare( ) -> NanoSplitConfig: self.cached_config = get_split_config(batch_size, num_tokens, cached_seqlens) return self.cached_config - + def set_hooks(self, op_hook: Callable[[NanoOpInfo], ContextManager[None]]): self.hook = op_hook @@ -196,10 +185,10 @@ def set_hooks(self, op_hook: Callable[[NanoOpInfo], ContextManager[None]]): _split_manager = None -def get_callable(graph_module: torch.fx.GraphModule) -> Callable: +def get_callable(graph_module: torch.fx.GraphModule, compilation_config: CompilationConfig) -> Callable: global _split_manager if _split_manager is None: - _split_manager = NanoSplitManager(graph_module) + _split_manager = NanoSplitManager(graph_module, compilation_config) return _split_manager.get_callable() @@ -211,9 +200,8 @@ def prepare_nano_split( global _split_manager if _split_manager is None: raise ValueError("Split manager not initialized") - return _split_manager.prepare( - 
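# --- illustrative sketch, not part of the patch ---
# Why the rewritten graph can call "op_wrapper", "get_batch_size" and "split_input"
# by name: those are plain call_module nodes, and FakeModule lets the manager bind a
# fresh callable (closing over the current NanoSplitConfig and hooks) under each name
# with setattr right before running the graph. A reduced version of that binding:
import torch

class FakeModule(torch.nn.Module):
    def __init__(self, fn, **bound):
        super().__init__()
        self.fn = fn
        self.bound = bound

    def forward(self, *args, **kwargs):
        return self.fn(*args, **self.bound, **kwargs)

def split_input(x: torch.Tensor, idx: int, *, cached_config) -> torch.Tensor:
    # Presumed behaviour: slice the token dimension by the cached split indices.
    lo, hi = cached_config.split_indices[idx], cached_config.split_indices[idx + 1]
    return x[lo:hi]

# Per forward pass, with `gm` the pre-built split GraphModule and `cfg` the config:
#   gm.split_input = FakeModule(split_input, cached_config=cfg)
#   gm.get_batch_size = FakeModule(lambda idx, cached_config: cached_config.num_tokens[idx],
#                                  cached_config=cfg)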
batch_size, num_tokens, cached_seqlens - ) + return _split_manager.prepare(batch_size, num_tokens, cached_seqlens) + def set_op_hook(op_hook: Callable[[NanoOpInfo], ContextManager[None]]): global _split_manager diff --git a/vllm/compilation/nano_split.py b/vllm/compilation/nanoflow/split_utils.py similarity index 51% rename from vllm/compilation/nano_split.py rename to vllm/compilation/nanoflow/split_utils.py index 07c13e464c35..ad6d61919f9a 100644 --- a/vllm/compilation/nano_split.py +++ b/vllm/compilation/nanoflow/split_utils.py @@ -1,5 +1,93 @@ import torch -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Callable, List, Tuple, Union +import dataclasses + + +@dataclasses.dataclass +class NanoOpInfo: + submod_name: str + tag: str + idx: int + args: tuple + kwargs: dict + + +@dataclasses.dataclass +class NanoSplitConfig: + num_nano_batches: int + # Request level information + batch_sizes: List[int] + batch_indices: List[int] + # Token level information + num_tokens: List[int] + split_indices: List[int] # start/end indices of each nano batch + + +class FakeModule(torch.nn.Module): + def __init__(self, fn: Callable, **kwargs): + super().__init__() + self.fn = fn + self.kwargs = kwargs + + def forward(self, *args, **kwargs): + return self.fn(*args, **self.kwargs, **kwargs) + + +def get_split_config( + batch_size: int, + num_tokens: List[int], + cached_seqlens: List[int], +) -> NanoSplitConfig: + if batch_size == 1: + nano_batch_sizes = [1] + nano_batch_indices = [0, 1] + nano_batch_num_tokens = num_tokens.copy() + nano_batch_split_indices = [0, num_tokens[0]] + else: + nano_batch_sizes = [batch_size // 2, batch_size - batch_size // 2] + nano_batch_indices = [0, batch_size // 2, batch_size] + nano_batch_num_tokens = [ + sum(num_tokens[: batch_size // 2]), + sum(num_tokens[batch_size // 2 :]), + ] + nano_batch_split_indices = [ + 0, + nano_batch_num_tokens[0], + sum(nano_batch_num_tokens), + ] + + return NanoSplitConfig( + num_nano_batches=len(nano_batch_sizes), + batch_sizes=nano_batch_sizes, + batch_indices=nano_batch_indices, + num_tokens=nano_batch_num_tokens, + split_indices=nano_batch_split_indices, + ) + + +def display_graph(graph_module: torch.fx.GraphModule, name: str): + from torch._dynamo.utils import lazy_format_graph_code # type: ignore + + print(lazy_format_graph_code(name, graph_module)) + + +def tag_graph(gm: torch.fx.GraphModule, op_tags: dict[str, str]): + submodules = [ + (name, module) + for (name, module) in gm.named_modules() + if hasattr(module, "graph") + ] + for _, module in submodules: + for node in module.graph.nodes: + if ( + node.op == "call_function" + and (tag := op_tags.get(str(node.target))) is not None + ): + assert ( + getattr(module, "tag", None) is None + or getattr(module, "tag") == tag + ), f"tag mismatch: {getattr(module, 'tag')} != {tag}" + setattr(module, "tag", tag) def analyze_graph( @@ -32,32 +120,37 @@ def analyze_graph( return splittable_inputs, base_graph -def split_inputs( +def split_graph( graph: torch.fx.Graph, + *, + out: torch.fx.Graph, splittable_inputs: List[torch.fx.Node], - split_module: str, num_splits: int, -) -> Dict[torch.fx.Node, List[torch.fx.Node]]: + get_bs_fn: str, + split_fn: str, + wrapper_fn: str, +) -> torch.fx.Graph: mapping = {} + nano_batch_sizes = [] + + # Step 1: Get nano batch sizes and split inputs + for i in range(num_splits): + nano_batch_sizes.append( + out.call_module( + get_bs_fn, + args=(i,), + ) + ) for node in splittable_inputs: mapping[node] = [] for i in range(num_splits): - 
slice_node = graph.call_module( - split_module, + slice_node = out.call_module( + split_fn, args=(node, i), ) mapping[node].append(slice_node) - return mapping - - -def split_computations( - org_graph: torch.fx.Graph, - new_graph: torch.fx.Graph, - mapping: Dict[torch.fx.Node, List[torch.fx.Node]], - nano_batch_sizes: List[torch.fx.Node], - wrapper_module: str, - num_splits: int, -): + + # Step 2: Split computation nodes def _transform(idx, n) -> torch.fx.Node: if n in mapping: return mapping[n][idx] @@ -65,7 +158,7 @@ def _transform(idx, n) -> torch.fx.Node: return nano_batch_sizes[idx] return n - for node in org_graph.nodes: + for node in graph.nodes: if node.op in ["placeholder", "output"]: continue splits = [] @@ -75,28 +168,21 @@ def _transform(idx, n) -> torch.fx.Node: new_kwargs = { k: _transform(split_idx, v) for k, v in node.kwargs.items() } - new_node = new_graph.call_module( - wrapper_module, + new_node = out.call_module( + wrapper_fn, args=(str(node.target), split_idx, new_args, new_kwargs), ) else: - new_node = new_graph.node_copy( + new_node = out.node_copy( node, arg_transform=lambda n: _transform(split_idx, n) ) splits.append(new_node) mapping[node] = splits - return mapping - -def concat_outputs( - org_graph: torch.fx.Graph, - new_graph: torch.fx.Graph, - mapping: Dict[torch.fx.Node, List[torch.fx.Node]], -): - output_nodes = [node for node in org_graph.nodes if node.op == "output"] + # Step 3: Concatenate outputs + output_nodes = [node for node in graph.nodes if node.op == "output"] assert len(output_nodes) == 1, f"Expected 1 output node, found {len(output_nodes)}" output_node = output_nodes[0] - if not output_node.args: raise ValueError("Output node has no arguments") original_outputs = output_node.args[0] @@ -116,7 +202,7 @@ def concat_outputs( concat_node = split_outputs[0] else: # Create concatenation node - concat_node = new_graph.call_function( + concat_node = out.call_function( torch.cat, args=(split_outputs, 0), # Concatenate along first dimension ) @@ -128,4 +214,5 @@ def concat_outputs( f"Original output {original_output} not found in node_splits" ) - new_graph.output(tuple(new_outputs) if is_tuple else new_outputs[0]) + out.output(tuple(new_outputs) if is_tuple else new_outputs[0]) + return out diff --git a/vllm/config.py b/vllm/config.py index 48a382f34fe5..a51118ba6c02 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -445,6 +445,7 @@ def compute_hash(self) -> str: factors.append(self.override_generation_config) factors.append(self.rope_scaling) factors.append(self.rope_theta) + factors.append(self.enable_nano_split) # hf_config can control how the model looks! 
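# --- illustrative sketch, not part of the patch ---
# Roughly what split_graph above produces for num_splits == 2, written out as plain
# Python instead of fx nodes. `x` and `positions` stand for the splittable inputs,
# `num_tokens` for the SymInt placeholder, and submodule names are illustrative;
# only tagged submodules (attention / all_reduce) are routed through op_wrapper.
import torch

def transformed_forward(gm, x, positions, num_tokens):
    n0, n1 = gm.get_batch_size(0), gm.get_batch_size(1)      # per-nano-batch token counts
    x0, x1 = gm.split_input(x, 0), gm.split_input(x, 1)      # slice every splittable input
    p0, p1 = gm.split_input(positions, 0), gm.split_input(positions, 1)

    h0 = gm.submod_0(x0, p0, n0)                             # untagged node: copied per split
    h1 = gm.submod_0(x1, p1, n1)
    h0 = gm.op_wrapper("submod_1", 0, (h0, n0), {})          # tagged node: wrapped, idx = 0
    h1 = gm.op_wrapper("submod_1", 1, (h1, n1), {})          # tagged node: wrapped, idx = 1
    # ... repeated for every node in the original graph ...
    return torch.cat([h0, h1], 0)                            # outputs re-joined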
factors.append(self.hf_config.to_json_string()) str_factors = str(factors) diff --git a/vllm/utils/nano_split.py b/vllm/utils/nano_split.py index e20799807df8..e1bd7bd32f42 100644 --- a/vllm/utils/nano_split.py +++ b/vllm/utils/nano_split.py @@ -1,34 +1,61 @@ from contextlib import contextmanager +from typing import Optional import numpy as np -from vllm.compilation.nano_utils import NanoOpInfo +import torch +from vllm.compilation.nanoflow.split_utils import NanoOpInfo from vllm.forward_context import ( ForwardContext, get_forward_context, override_forward_context, ) from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, - make_local_attention_virtual_batches, ) from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec -from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from vllm.compilation import nano_manager +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm.compilation.nanoflow import manager as nano_manager + + +def _get_cumsum_and_arange( + num_tokens: np.ndarray, + cumsum_dtype: Optional[np.dtype] = None, +) -> tuple[np.ndarray, np.ndarray]: + cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) + arange = np.arange(total_num_tokens) - cumsums_offsets + return cu_num_tokens, arange def prepare_nano_split_and_set_hooks( - gpu_model_runner: "GPUModelRunner", - scheduler_output: "SchedulerOutput", + scheduler_output: SchedulerOutput, + input_batch: InputBatch, + attn_metadata_builders: list[AttentionMetadataBuilder], + kv_cache_config: KVCacheConfig, ) -> None: - input_batch = gpu_model_runner.input_batch + prev_forward_context = get_forward_context() req_ids = input_batch.req_ids batch_size = len(req_ids) - num_tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] - cached_seqlens = input_batch.num_computed_tokens_cpu[ - :batch_size - ].tolist() - split_config = nano_manager.prepare_nano_split(batch_size, num_tokens, cached_seqlens) + tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] + num_scheduled_tokens = torch.tensor(tokens, dtype=torch.int32) + cached_seqlens = input_batch.num_computed_tokens_cpu[:batch_size].tolist() + split_config = nano_manager.prepare_nano_split( + batch_size, tokens, cached_seqlens + ) + + cu_num_tokens = torch.cumsum(num_scheduled_tokens, dim=0, dtype=torch.int32) + query_start_loc_cpu = torch.zeros(batch_size + 1, dtype=torch.int32) + query_start_loc_cpu[0] = 0 + query_start_loc_cpu[1:batch_size + 1] = cu_num_tokens + seq_lens_cpu = torch.zeros(batch_size, dtype=torch.int32) + seq_lens_cpu[:batch_size] = ( + input_batch.num_computed_tokens_cpu_tensor[:batch_size] + num_scheduled_tokens + ) + query_start_loc = query_start_loc_cpu.to(input_batch.device) + seq_lens = seq_lens_cpu.to(input_batch.device) attn_metadatas = [] start_req_idx = 0 @@ -43,8 +70,8 @@ def prepare_nano_split_and_set_hooks( [scheduler_output.num_scheduled_tokens[rid] for rid in nano_batch_req_ids], dtype=np.int32, ) - nano_batch_cu_num_tokens, nano_batch_arange = ( - gpu_model_runner._get_cumsum_and_arange(nano_batch_num_scheduled_tokens) + nano_batch_cu_num_tokens, nano_batch_arange = _get_cumsum_and_arange( + nano_batch_num_scheduled_tokens ) nano_batch_total_tokens = int(nano_batch_cu_num_tokens[-1]) nano_batch_req_indices = np.repeat( @@ -54,9 +81,9 
@@ def prepare_nano_split_and_set_hooks( # Compute positions for this group nano_batch_positions_np = np.empty(nano_batch_total_tokens, dtype=np.int64) np.add( - input_batch.num_computed_tokens_cpu[ - start_req_idx:end_req_idx - ][nano_batch_req_indices], + input_batch.num_computed_tokens_cpu[start_req_idx:end_req_idx][ + nano_batch_req_indices + ], nano_batch_arange, out=nano_batch_positions_np, ) @@ -64,23 +91,23 @@ def prepare_nano_split_and_set_hooks( # Prepare attention metadata for each KV cache group nano_batch_attn_metadata = {} for kv_cache_group_id, kv_cache_group_spec in enumerate( - gpu_model_runner.kv_cache_config.kv_cache_groups + kv_cache_config.kv_cache_groups ): blk_table = input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor()[start_req_idx:end_req_idx] slot_mapping = blk_table.slot_mapping[:nano_batch_total_tokens] common_attn_metadata = CommonAttentionMetadata( - query_start_loc=gpu_model_runner.query_start_loc[ + query_start_loc=query_start_loc[ start_req_idx : end_req_idx + 1 ] - - gpu_model_runner.query_start_loc[start_req_idx], - query_start_loc_cpu=gpu_model_runner.query_start_loc_cpu[ + - query_start_loc[start_req_idx], + query_start_loc_cpu=query_start_loc_cpu[ start_req_idx : end_req_idx + 1 ] - - gpu_model_runner.query_start_loc_cpu[start_req_idx], - seq_lens=gpu_model_runner.seq_lens[start_req_idx:end_req_idx], - seq_lens_cpu=gpu_model_runner.seq_lens_cpu[start_req_idx:end_req_idx], + - query_start_loc_cpu[start_req_idx], + seq_lens=seq_lens[start_req_idx:end_req_idx], + seq_lens_cpu=seq_lens_cpu[start_req_idx:end_req_idx], num_computed_tokens_cpu=input_batch.num_computed_tokens_cpu_tensor[ start_req_idx:end_req_idx ], @@ -91,16 +118,9 @@ def prepare_nano_split_and_set_hooks( slot_mapping=slot_mapping, ) - if isinstance(kv_cache_group_spec.kv_cache_spec, ChunkedLocalAttentionSpec): - common_attn_metadata = make_local_attention_virtual_batches( - kv_cache_group_spec.kv_cache_spec.attention_chunk_size, - common_attn_metadata, - gpu_model_runner.cache_config.block_size, - ) - - # NOTE(yi): does not support cascade attention + # NOTE(yi): does not support chunked local attention or cascade attention common_prefix_len = 0 - builder = gpu_model_runner.attn_metadata_builders[kv_cache_group_id] + builder = attn_metadata_builders[kv_cache_group_id] attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, @@ -116,11 +136,11 @@ def prepare_nano_split_and_set_hooks( ), f"invalid nano batch size: {split_config.batch_sizes}" forward_contexts = [ ForwardContext( - no_compile_layers=gpu_model_runner.vllm_config.compilation_config.static_forward_context, - virtual_engine=0, + no_compile_layers=prev_forward_context.no_compile_layers, attn_metadata=attn_metadata, - dp_metadata=None, - skip_cuda_graphs=True, + virtual_engine=prev_forward_context.virtual_engine, + dp_metadata=prev_forward_context.dp_metadata, + skip_cuda_graphs=prev_forward_context.skip_cuda_graphs, ) for attn_metadata in attn_metadatas ] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 631afc776453..3d430f59d53a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -45,6 +45,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) +from vllm.utils.nano_split import prepare_nano_split_and_set_hooks from 
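# --- illustrative example, not part of the patch ---
# What the _get_cumsum_and_arange helper above computes, on a 3-request toy batch:
import numpy as np

num_tokens = np.array([3, 1, 2], dtype=np.int32)
cu_num_tokens = np.cumsum(num_tokens)                          # [3, 4, 6]
offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)    # [0, 0, 0, 3, 4, 4]
arange = np.arange(cu_num_tokens[-1]) - offsets                # [0, 1, 2, 0, 0, 1]
# `arange` is each token's position within its own request; added to the cached
# sequence lengths it gives the absolute positions for a nano-batch.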
vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, @@ -792,10 +793,11 @@ def _prepare_inputs( spec_decode_common_attn_metadata) def _prepare_nano_split(self, scheduler_output: "SchedulerOutput"): - from vllm.utils.nano_split import prepare_nano_split_and_set_hooks prepare_nano_split_and_set_hooks( - gpu_model_runner=self, scheduler_output=scheduler_output, + input_batch=self.input_batch, + attn_metadata_builders=self.attn_metadata_builders, + kv_cache_config=self.kv_cache_config ) def _compute_cascade_attn_prefix_len( @@ -1404,8 +1406,6 @@ def execute_model( # If attention doesn't support CUDA Graphs for this batch, but we # compiled with full CUDA graphs, we have to skip them entirely. skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs - if self.vllm_config.model_config.enable_nano_split: - self._prepare_nano_split(scheduler_output) # Run the model. # Use persistent buffers for CUDA graphs. @@ -1417,6 +1417,8 @@ def execute_model( skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) + if self.vllm_config.model_config.enable_nano_split: + self._prepare_nano_split(scheduler_output) model_output = self.model( input_ids=input_ids, From 99306045e992ea09c19bec41ddfa993713c4cfc4 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Thu, 7 Aug 2025 14:38:58 -0700 Subject: [PATCH 09/22] implement auto-on/off logic and fix attn metadata split Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 6 +-- vllm/compilation/nanoflow/manager.py | 23 ++++++----- vllm/compilation/nanoflow/split_utils.py | 51 ++++++++++++++---------- vllm/config.py | 12 ++++-- vllm/engine/arg_utils.py | 16 ++++++-- vllm/utils/nano_split.py | 12 +++--- vllm/v1/worker/gpu_model_runner.py | 4 +- 7 files changed, 76 insertions(+), 48 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1285d622f812..b3eda946458a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -549,7 +549,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.graph = graph self.configure_post_pass() - if self.vllm_config.model_config.enable_nano_split: + if self.vllm_config.model_config.enable_nano_batch_split: self.split_gm, self.piecewise_graphs = split_graph(graph, [ "vllm.all_reduce" ]) @@ -592,10 +592,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: - if self.vllm_config.model_config.enable_nano_split: + if self.vllm_config.model_config.enable_nano_batch_split: return nano_manager.get_callable( self.split_gm, - compilation_config=self.compilation_config + vllm_config=self.vllm_config ) else: return self.split_gm diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py index 2c6f9c227f4a..00dd0246b25c 100644 --- a/vllm/compilation/nanoflow/manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -11,23 +11,22 @@ NanoOpInfo, NanoSplitConfig, FakeModule, - display_graph, get_split_config, tag_graph, ) -from vllm.config import CompilationConfig +from vllm.config import VllmConfig class NanoSplitManager: def __init__( - self, graph_module: torch.fx.GraphModule, compilation_config: CompilationConfig, + self, graph_module: torch.fx.GraphModule, vllm_config: VllmConfig, ) -> None: self.original_graph_module = graph_module self.original_graph = 
graph_module.graph # Nano split preparation - # NOTE(yi): move this to compilation config - self.max_nano_splits = 2 + self.min_nano_split_tokens = vllm_config.model_config.min_nano_split_tokens + self.max_num_nano_batches = vllm_config.model_config.max_num_nano_batches # Initialize the base graph tag_graph( self.original_graph_module, @@ -52,7 +51,7 @@ def __init__( setattr(self.original_graph_module, self.wrapper_fn, None) splittable_inputs, base_graph = analyze_graph(self.original_graph) - for num_splits in range(2, self.max_nano_splits + 1): + for num_splits in range(2, self.max_num_nano_batches + 1): new_graph = copy.deepcopy(base_graph) split_graph( self.original_graph, @@ -175,7 +174,13 @@ def prepare( num_tokens: List[int], cached_seqlens: List[int], ) -> NanoSplitConfig: - self.cached_config = get_split_config(batch_size, num_tokens, cached_seqlens) + self.cached_config = get_split_config( + batch_size, + num_tokens, + cached_seqlens, + self.max_num_nano_batches, + self.min_nano_split_tokens, + ) return self.cached_config def set_hooks(self, op_hook: Callable[[NanoOpInfo], ContextManager[None]]): @@ -185,10 +190,10 @@ def set_hooks(self, op_hook: Callable[[NanoOpInfo], ContextManager[None]]): _split_manager = None -def get_callable(graph_module: torch.fx.GraphModule, compilation_config: CompilationConfig) -> Callable: +def get_callable(graph_module: torch.fx.GraphModule, vllm_config: VllmConfig) -> Callable: global _split_manager if _split_manager is None: - _split_manager = NanoSplitManager(graph_module, compilation_config) + _split_manager = NanoSplitManager(graph_module, vllm_config) return _split_manager.get_callable() diff --git a/vllm/compilation/nanoflow/split_utils.py b/vllm/compilation/nanoflow/split_utils.py index ad6d61919f9a..1c7cec1550a4 100644 --- a/vllm/compilation/nanoflow/split_utils.py +++ b/vllm/compilation/nanoflow/split_utils.py @@ -1,3 +1,4 @@ +import itertools import torch from typing import Callable, List, Tuple, Union import dataclasses @@ -37,27 +38,36 @@ def get_split_config( batch_size: int, num_tokens: List[int], cached_seqlens: List[int], + max_num_nano_batches: int, + min_nano_split_tokens: int, ) -> NanoSplitConfig: - if batch_size == 1: - nano_batch_sizes = [1] - nano_batch_indices = [0, 1] - nano_batch_num_tokens = num_tokens.copy() - nano_batch_split_indices = [0, num_tokens[0]] - else: - nano_batch_sizes = [batch_size // 2, batch_size - batch_size // 2] - nano_batch_indices = [0, batch_size // 2, batch_size] - nano_batch_num_tokens = [ - sum(num_tokens[: batch_size // 2]), - sum(num_tokens[batch_size // 2 :]), - ] - nano_batch_split_indices = [ - 0, - nano_batch_num_tokens[0], - sum(nano_batch_num_tokens), - ] + token_batch_size = sum(num_tokens) + num_nano_batches = min( + batch_size, + (token_batch_size + min_nano_split_tokens - 1) // min_nano_split_tokens, + max_num_nano_batches, + ) + nano_batch_split_indices = [0] + nano_batch_indices = [0] + nano_batch_sizes = [] + nano_batch_num_tokens = [] + prefix_sum = [0] + list(itertools.accumulate(num_tokens)) + remaining_batches = num_nano_batches - len(nano_batch_split_indices) + 1 + remaining_tokens = token_batch_size + while remaining_batches > 0: + next_index = min( + range(nano_batch_indices[-1], batch_size - remaining_batches + 1), + key=lambda x: abs(prefix_sum[x + 1] - prefix_sum[nano_batch_indices[-1]] - remaining_tokens / remaining_batches) + ) + nano_batch_indices.append(next_index + 1) + nano_batch_split_indices.append(prefix_sum[next_index + 1]) + 
nano_batch_sizes.append(nano_batch_indices[-1] - nano_batch_indices[-2]) + nano_batch_num_tokens.append(nano_batch_split_indices[-1] - nano_batch_split_indices[-2]) + remaining_tokens = token_batch_size - prefix_sum[next_index + 1] + remaining_batches -= 1 return NanoSplitConfig( - num_nano_batches=len(nano_batch_sizes), + num_nano_batches=num_nano_batches, batch_sizes=nano_batch_sizes, batch_indices=nano_batch_indices, num_tokens=nano_batch_num_tokens, @@ -111,10 +121,8 @@ def analyze_graph( shape = input_tensor.shape if shape[0] == batch_size: splittable_inputs.append(node) - print(f"Found splittable input: {node.name} with shape {shape}") else: weight_nodes.add(node) - print(f"Found weight tensor: {node.name} with shape {shape}") # Copy all placeholder nodes to the new graph base_graph.node_copy(node, arg_transform=lambda n: n) return splittable_inputs, base_graph @@ -149,7 +157,7 @@ def split_graph( args=(node, i), ) mapping[node].append(slice_node) - + # Step 2: Split computation nodes def _transform(idx, n) -> torch.fx.Node: if n in mapping: @@ -208,7 +216,6 @@ def _transform(idx, n) -> torch.fx.Node: ) new_outputs.append(concat_node) - print(f"Concatenated {len(split_outputs)} output splits") else: raise ValueError( f"Original output {original_output} not found in node_splits" diff --git a/vllm/config.py b/vllm/config.py index a51118ba6c02..dd2dad522b14 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -415,8 +415,12 @@ class ModelConfig: - "transformers" will use the Transformers model implementation.""" override_attention_dtype: Optional[str] = None """Override dtype for attention""" - enable_nano_split: bool = False - """Enable nano split for the model""" + enable_nano_batch_split: bool = False + """Enable spliting the input batch into nano-batches for intra-device parallelism""" + max_num_nano_batches: int = 2 + """Maximum number of nano-batches to split the input batch into""" + min_nano_split_tokens: int = 512 + """Minimum number of tokens to split the input batch""" def compute_hash(self) -> str: """ @@ -445,7 +449,9 @@ def compute_hash(self) -> str: factors.append(self.override_generation_config) factors.append(self.rope_scaling) factors.append(self.rope_theta) - factors.append(self.enable_nano_split) + factors.append(self.enable_nano_batch_split) + factors.append(self.max_num_nano_batches) + factors.append(self.min_nano_split_tokens) # hf_config can control how the model looks! 
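# --- illustrative example, not part of the patch ---
# How the greedy splitter above behaves on an uneven batch, with
# max_num_nano_batches = 2 and min_nano_split_tokens = 512:
#   num_tokens = [700, 100, 100, 300]  -> prefix_sum = [0, 700, 800, 900, 1200]
#   num_nano_batches = min(4, ceil(1200 / 512), 2) = 2
#   1st cut: target 1200 / 2 = 600 tokens -> cut after request 0 (700 is closest)
#   2nd cut: target the remaining 500     -> cut after request 3 (exactly 500 left)
#   result: batch_indices = [0, 1, 4], split_indices = [0, 700, 1200]
# i.e. one nano-batch holding the single long request and one holding the three
# short ones, keeping the token counts roughly balanced.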
factors.append(self.hf_config.to_json_string()) str_factors = str(factors) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c972a0b20e0f..b8a94743581b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -424,7 +424,9 @@ class EngineArgs: get_field(ModelConfig, "override_generation_config") model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype - enable_nano_split: bool = ModelConfig.enable_nano_split + enable_nano_batch_split: bool = ModelConfig.enable_nano_batch_split + max_num_nano_batches: int = ModelConfig.max_num_nano_batches + min_nano_split_tokens: int = ModelConfig.min_nano_split_tokens calculate_kv_scales: bool = CacheConfig.calculate_kv_scales @@ -541,8 +543,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **model_kwargs["model_impl"]) model_group.add_argument("--override-attention-dtype", **model_kwargs["override_attention_dtype"]) - model_group.add_argument("--enable-nano-split", - **model_kwargs["enable_nano_split"]) + model_group.add_argument("--enable-nano-batch-split", + **model_kwargs["enable_nano_batch_split"]) + model_group.add_argument("--max-num-nano-batches", + **model_kwargs["max_num_nano_batches"]) + model_group.add_argument("--min-nano-split-tokens", + **model_kwargs["min_nano_split_tokens"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) load_group = parser.add_argument_group( @@ -935,7 +941,9 @@ def create_model_config(self) -> ModelConfig: enable_sleep_mode=self.enable_sleep_mode, model_impl=self.model_impl, override_attention_dtype=self.override_attention_dtype, - enable_nano_split=self.enable_nano_split, + enable_nano_batch_split=self.enable_nano_batch_split, + max_num_nano_batches=self.max_num_nano_batches, + min_nano_split_tokens=self.min_nano_split_tokens, ) def validate_tensorizer_args(self): diff --git a/vllm/utils/nano_split.py b/vllm/utils/nano_split.py index e1bd7bd32f42..fc8986423e07 100644 --- a/vllm/utils/nano_split.py +++ b/vllm/utils/nano_split.py @@ -60,10 +60,12 @@ def prepare_nano_split_and_set_hooks( attn_metadatas = [] start_req_idx = 0 end_req_idx = 0 - for nano_batch_size in split_config.batch_sizes: - start_req_idx = end_req_idx - end_req_idx = start_req_idx + nano_batch_size + for nano_batch_idx in range(split_config.num_nano_batches): + start_req_idx = split_config.batch_indices[nano_batch_idx] + end_req_idx = split_config.batch_indices[nano_batch_idx + 1] nano_batch_req_ids = req_ids[start_req_idx:end_req_idx] + start_token_idx = split_config.split_indices[nano_batch_idx] + end_token_idx = split_config.split_indices[nano_batch_idx + 1] # Gather per-request info for this group nano_batch_num_scheduled_tokens = np.array( @@ -95,7 +97,7 @@ def prepare_nano_split_and_set_hooks( ): blk_table = input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor()[start_req_idx:end_req_idx] - slot_mapping = blk_table.slot_mapping[:nano_batch_total_tokens] + slot_mapping = blk_table.slot_mapping[start_token_idx:end_token_idx] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc[ @@ -111,7 +113,7 @@ def prepare_nano_split_and_set_hooks( num_computed_tokens_cpu=input_batch.num_computed_tokens_cpu_tensor[ start_req_idx:end_req_idx ], - num_reqs=nano_batch_size, + num_reqs=split_config.batch_sizes[nano_batch_idx], num_actual_tokens=nano_batch_total_tokens, max_query_len=int(max(nano_batch_num_scheduled_tokens)), block_table_tensor=blk_table_tensor, diff 
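# --- illustrative usage, not part of the patch ---
# With the EngineArgs fields above exposed, nano-batch splitting would be enabled
# roughly like this (model name and values are placeholders; this assumes the usual
# forwarding of EngineArgs fields as keyword arguments to the LLM entrypoint):
from vllm import LLM

llm = LLM(
    model="<your-model>",
    enable_nano_batch_split=True,   # turn on intra-device nano-batch overlap
    max_num_nano_batches=2,         # at most two nano-batches per step
    min_nano_split_tokens=512,      # fall back to a single batch below this size
)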
--git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3d430f59d53a..5778adf18fdb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -222,7 +222,7 @@ def __init__( == CompilationLevel.PIECEWISE and self.vllm_config.compilation_config.use_cudagraph and not self.model_config.enforce_eager - and not self.model_config.enable_nano_split) + and not self.model_config.enable_nano_batch_split) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -1417,7 +1417,7 @@ def execute_model( skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) - if self.vllm_config.model_config.enable_nano_split: + if self.vllm_config.model_config.enable_nano_batch_split: self._prepare_nano_split(scheduler_output) model_output = self.model( From 651792046d9235aaa727e4b7a0d820e134d2f1cd Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Mon, 25 Aug 2025 12:55:18 -0700 Subject: [PATCH 10/22] update Signed-off-by: Yi Pan --- vllm/compilation/nanoflow/manager.py | 2 +- vllm/compilation/nanoflow/split_utils.py | 50 +++++++++++------------- vllm/config.py | 6 +++ vllm/utils/nano_split.py | 18 ++++----- 4 files changed, 37 insertions(+), 39 deletions(-) diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py index 00dd0246b25c..2ac6543c4038 100644 --- a/vllm/compilation/nanoflow/manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -101,7 +101,7 @@ def op_wrapper( def get_callable(self) -> Callable: def _forward(*args, **kwargs): - if self.cached_config is None: + if self.cached_config is None or self.cached_config.num_nano_batches == 1: return self.original_graph_module(*args, **kwargs) num_nano_batches = self.cached_config.num_nano_batches diff --git a/vllm/compilation/nanoflow/split_utils.py b/vllm/compilation/nanoflow/split_utils.py index 1c7cec1550a4..cdf8071e3d7b 100644 --- a/vllm/compilation/nanoflow/split_utils.py +++ b/vllm/compilation/nanoflow/split_utils.py @@ -41,37 +41,33 @@ def get_split_config( max_num_nano_batches: int, min_nano_split_tokens: int, ) -> NanoSplitConfig: - token_batch_size = sum(num_tokens) - num_nano_batches = min( - batch_size, - (token_batch_size + min_nano_split_tokens - 1) // min_nano_split_tokens, - max_num_nano_batches, - ) - nano_batch_split_indices = [0] - nano_batch_indices = [0] - nano_batch_sizes = [] - nano_batch_num_tokens = [] + num_nano_batches = 0 + nano_batch_token_indices = [0] + nano_batch_req_indices = [0] + nano_batch_req_sizes = [] + nano_batch_token_sizes = [] prefix_sum = [0] + list(itertools.accumulate(num_tokens)) - remaining_batches = num_nano_batches - len(nano_batch_split_indices) + 1 - remaining_tokens = token_batch_size - while remaining_batches > 0: - next_index = min( - range(nano_batch_indices[-1], batch_size - remaining_batches + 1), - key=lambda x: abs(prefix_sum[x + 1] - prefix_sum[nano_batch_indices[-1]] - remaining_tokens / remaining_batches) - ) - nano_batch_indices.append(next_index + 1) - nano_batch_split_indices.append(prefix_sum[next_index + 1]) - nano_batch_sizes.append(nano_batch_indices[-1] - nano_batch_indices[-2]) - nano_batch_num_tokens.append(nano_batch_split_indices[-1] - nano_batch_split_indices[-2]) - remaining_tokens = token_batch_size - prefix_sum[next_index + 1] - remaining_batches -= 1 + # Find the mid point of the tokens + mid = min(range(len(prefix_sum)), key=lambda i: abs(prefix_sum[i] - 
(prefix_sum[-1] - prefix_sum[i]))) + if prefix_sum[mid] < min_nano_split_tokens or (prefix_sum[-1] - prefix_sum[mid]) < min_nano_split_tokens: + num_nano_batches = 1 + nano_batch_req_indices.append(batch_size) + nano_batch_token_indices.append(prefix_sum[-1]) + nano_batch_req_sizes.append(batch_size) + nano_batch_token_sizes.append(prefix_sum[-1]) + else: + num_nano_batches = 2 + nano_batch_req_indices.extend([mid, batch_size]) + nano_batch_token_indices.extend([prefix_sum[mid], prefix_sum[-1]]) + nano_batch_req_sizes.extend([mid, batch_size - mid]) + nano_batch_token_sizes.extend([prefix_sum[mid], prefix_sum[-1] - prefix_sum[mid]]) return NanoSplitConfig( num_nano_batches=num_nano_batches, - batch_sizes=nano_batch_sizes, - batch_indices=nano_batch_indices, - num_tokens=nano_batch_num_tokens, - split_indices=nano_batch_split_indices, + batch_sizes=nano_batch_req_sizes, + batch_indices=nano_batch_req_indices, + num_tokens=nano_batch_token_sizes, + split_indices=nano_batch_token_indices, ) diff --git a/vllm/config.py b/vllm/config.py index dd2dad522b14..f07a0f9d3da7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4649,6 +4649,12 @@ def __post_init__(self): logger.info("full_cuda_graph is not supported with " "cascade attention. Disabling cascade attention.") self.model_config.disable_cascade_attn = True + + if self.compilation_config.full_cuda_graph and \ + self.model_config.enable_nano_batch_split: + logger.info("full_cuda_graph is not supported with " + "nano batch split. Disabling nano batch split.") + self.model_config.enable_nano_batch_split = False disable_chunked_prefill_reasons: list[str] = [] diff --git a/vllm/utils/nano_split.py b/vllm/utils/nano_split.py index fc8986423e07..755542244604 100644 --- a/vllm/utils/nano_split.py +++ b/vllm/utils/nano_split.py @@ -42,14 +42,14 @@ def prepare_nano_split_and_set_hooks( tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] num_scheduled_tokens = torch.tensor(tokens, dtype=torch.int32) cached_seqlens = input_batch.num_computed_tokens_cpu[:batch_size].tolist() - split_config = nano_manager.prepare_nano_split( - batch_size, tokens, cached_seqlens - ) + split_config = nano_manager.prepare_nano_split(batch_size, tokens, cached_seqlens) + if split_config.num_nano_batches == 1: + return cu_num_tokens = torch.cumsum(num_scheduled_tokens, dim=0, dtype=torch.int32) query_start_loc_cpu = torch.zeros(batch_size + 1, dtype=torch.int32) query_start_loc_cpu[0] = 0 - query_start_loc_cpu[1:batch_size + 1] = cu_num_tokens + query_start_loc_cpu[1 : batch_size + 1] = cu_num_tokens seq_lens_cpu = torch.zeros(batch_size, dtype=torch.int32) seq_lens_cpu[:batch_size] = ( input_batch.num_computed_tokens_cpu_tensor[:batch_size] + num_scheduled_tokens @@ -100,13 +100,9 @@ def prepare_nano_split_and_set_hooks( slot_mapping = blk_table.slot_mapping[start_token_idx:end_token_idx] common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc[ - start_req_idx : end_req_idx + 1 - ] + query_start_loc=query_start_loc[start_req_idx : end_req_idx + 1] - query_start_loc[start_req_idx], - query_start_loc_cpu=query_start_loc_cpu[ - start_req_idx : end_req_idx + 1 - ] + query_start_loc_cpu=query_start_loc_cpu[start_req_idx : end_req_idx + 1] - query_start_loc_cpu[start_req_idx], seq_lens=seq_lens[start_req_idx:end_req_idx], seq_lens_cpu=seq_lens_cpu[start_req_idx:end_req_idx], @@ -142,7 +138,7 @@ def prepare_nano_split_and_set_hooks( attn_metadata=attn_metadata, virtual_engine=prev_forward_context.virtual_engine, 
dp_metadata=prev_forward_context.dp_metadata, - skip_cuda_graphs=prev_forward_context.skip_cuda_graphs, + skip_cuda_graphs=True, ) for attn_metadata in attn_metadatas ] From d638a0a1b8e8bab8214b61deb271a3df63cc8efd Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Mon, 25 Aug 2025 13:45:01 -0700 Subject: [PATCH 11/22] clean the impl Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 12 ++---------- vllm/config.py | 22 ++++++++++++++++------ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index b3eda946458a..e46de50ab8d1 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -549,12 +549,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.graph = graph self.configure_post_pass() - if self.vllm_config.model_config.enable_nano_batch_split: - self.split_gm, self.piecewise_graphs = split_graph(graph, [ - "vllm.all_reduce" - ]) - else: - self.split_gm, self.piecewise_graphs = split_graph(graph, self.compilation_config.splitting_ops) + self.split_gm, self.piecewise_graphs = split_graph(graph, self.compilation_config.splitting_ops) from torch._dynamo.utils import lazy_format_graph_code @@ -593,10 +588,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: if self.vllm_config.model_config.enable_nano_batch_split: - return nano_manager.get_callable( - self.split_gm, - vllm_config=self.vllm_config - ) + return nano_manager.get_callable(self.split_gm, self.vllm_config) else: return self.split_gm diff --git a/vllm/config.py b/vllm/config.py index f07a0f9d3da7..77a41bf0d179 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4611,6 +4611,22 @@ def __post_init__(self): "To workaround this limitation, vLLM will set 'ieee' input " "precision for chunked prefill triton kernels.") + if self.model_config.enable_nano_batch_split: + if self.model_config.enforce_eager: + logger.info("nano batch split is not supported with " + "enforce_eager. Disabling nano batch split.") + self.model_config.enable_nano_batch_split = False + elif self.compilation_config.full_cuda_graph: + logger.info("full_cuda_graph is not supported with " + "nano batch split. Disabling nano batch split.") + self.model_config.enable_nano_batch_split = False + elif self.compilation_config.splitting_ops: + logger.info("splitting_ops is not supported with " + "nano batch split. Disabling nano batch split.") + self.model_config.enable_nano_batch_split = False + else: + self.compilation_config.splitting_ops = ["vllm.all_reduce"] + # async tp is built on top of sequence parallelism # and requires it to be enabled. if self.compilation_config.pass_config.enable_async_tp: @@ -4649,12 +4665,6 @@ def __post_init__(self): logger.info("full_cuda_graph is not supported with " "cascade attention. Disabling cascade attention.") self.model_config.disable_cascade_attn = True - - if self.compilation_config.full_cuda_graph and \ - self.model_config.enable_nano_batch_split: - logger.info("full_cuda_graph is not supported with " - "nano batch split. 
Disabling nano batch split.") - self.model_config.enable_nano_batch_split = False disable_chunked_prefill_reasons: list[str] = [] From 7bf426adb0a4e990ca83842dca7ff9e608d8fc7b Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Mon, 25 Aug 2025 13:59:24 -0700 Subject: [PATCH 12/22] fix config Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 3 ++- vllm/config.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index e46de50ab8d1..4183812bb7b2 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -549,7 +549,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.graph = graph self.configure_post_pass() - self.split_gm, self.piecewise_graphs = split_graph(graph, self.compilation_config.splitting_ops) + self.split_gm, self.piecewise_graphs = split_graph( + graph, self.compilation_config.splitting_ops) from torch._dynamo.utils import lazy_format_graph_code diff --git a/vllm/config.py b/vllm/config.py index 77a41bf0d179..c51a83a59e6c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4616,6 +4616,10 @@ def __post_init__(self): logger.info("nano batch split is not supported with " "enforce_eager. Disabling nano batch split.") self.model_config.enable_nano_batch_split = False + elif self.compilation_config.use_cudagraph: + logger.info("nano batch split is currently not supported with " + "cudagraph. Disabling nano batch split.") + self.model_config.enable_nano_batch_split = False elif self.compilation_config.full_cuda_graph: logger.info("full_cuda_graph is not supported with " "nano batch split. Disabling nano batch split.") From a575716d342a702486cead3c22a4c5082e573d52 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Mon, 25 Aug 2025 14:00:22 -0700 Subject: [PATCH 13/22] update min split tokens Signed-off-by: Yi Pan --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index c51a83a59e6c..62cf34dbb8b9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -419,7 +419,7 @@ class ModelConfig: """Enable spliting the input batch into nano-batches for intra-device parallelism""" max_num_nano_batches: int = 2 """Maximum number of nano-batches to split the input batch into""" - min_nano_split_tokens: int = 512 + min_nano_split_tokens: int = 1024 """Minimum number of tokens to split the input batch""" def compute_hash(self) -> str: From 6453b8c7687e396a3698a20362c00fcff74fd6a0 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Mon, 25 Aug 2025 14:32:01 -0700 Subject: [PATCH 14/22] format Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 5 +- vllm/compilation/nanoflow/manager.py | 112 ++++++++++++++--------- vllm/compilation/nanoflow/split_utils.py | 99 ++++++++++---------- vllm/config.py | 3 +- vllm/utils/nano_split.py | 85 ++++++++--------- vllm/v1/worker/gpu_model_runner.py | 5 +- 6 files changed, 171 insertions(+), 138 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4183812bb7b2..d0a64536a996 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -14,8 +14,8 @@ import torch.fx as fx from torch._dispatch.python import enable_python_dispatcher -from vllm.compilation.nanoflow import manager as nano_manager import vllm.envs as envs +from vllm.compilation.nanoflow import manager as nano_manager from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -589,7 +589,8 @@ def 
__call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: if self.vllm_config.model_config.enable_nano_batch_split: - return nano_manager.get_callable(self.split_gm, self.vllm_config) + return nano_manager.get_callable(self.split_gm, + self.vllm_config) else: return self.split_gm diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py index 2ac6543c4038..2e9ee3386231 100644 --- a/vllm/compilation/nanoflow/manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -1,32 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import contextlib import copy -import torch -from typing import Callable, ContextManager, List, Optional +from typing import Callable +import torch import torch.fx.graph_module -from vllm.compilation.nanoflow.split_utils import ( - analyze_graph, - split_graph, - NanoOpInfo, - NanoSplitConfig, - FakeModule, - get_split_config, - tag_graph, -) +from vllm.compilation.nanoflow.split_utils import (FakeModule, NanoOpInfo, + NanoSplitConfig, + analyze_graph, + get_split_config, + split_graph, tag_graph) from vllm.config import VllmConfig class NanoSplitManager: + def __init__( - self, graph_module: torch.fx.GraphModule, vllm_config: VllmConfig, + self, + graph_module: torch.fx.GraphModule, + vllm_config: VllmConfig, ) -> None: self.original_graph_module = graph_module self.original_graph = graph_module.graph # Nano split preparation - self.min_nano_split_tokens = vllm_config.model_config.min_nano_split_tokens - self.max_num_nano_batches = vllm_config.model_config.max_num_nano_batches + self.min_nano_split_tokens = \ + vllm_config.model_config.min_nano_split_tokens + self.max_num_nano_batches = \ + vllm_config.model_config.max_num_nano_batches # Initialize the base graph tag_graph( self.original_graph_module, @@ -39,10 +43,11 @@ def __init__( self.graph_modules = {1: self.original_graph_module} # Runtime preparation - self.cached_config: Optional[NanoSplitConfig] = None + self.cached_config: NanoSplitConfig | None = None self.comm_stream = torch.cuda.Stream() self.comp_stream = torch.cuda.Stream() - self.hook: Optional[Callable[[NanoOpInfo], ContextManager[None]]] = None + self.hook: Callable[[NanoOpInfo], contextlib. + AbstractContextManager[None]] | None = None self.get_bs_fn = "get_batch_size" self.split_fn = "split_input" self.wrapper_fn = "op_wrapper" @@ -62,23 +67,23 @@ def __init__( split_fn=self.split_fn, wrapper_fn=self.wrapper_fn, ) - new_graph_module = torch.fx.GraphModule(self.original_graph_module, new_graph) + new_graph_module = torch.fx.GraphModule(self.original_graph_module, + new_graph) for name, _ in self.original_graph_module.named_modules(): if "." in name or name == "": continue - torch.fx.graph_module._copy_attr(self.original_graph_module, new_graph_module, name) + torch.fx.graph_module._copy_attr(self.original_graph_module, + new_graph_module, name) self.graph_modules[num_splits] = new_graph_module - @staticmethod def get_batch_size(idx: int, cached_config: NanoSplitConfig): return cached_config.num_tokens[idx] @staticmethod def split_input(x: torch.Tensor, idx: int, cached_config: NanoSplitConfig): - return x[ - cached_config.split_indices[idx] : cached_config.split_indices[idx + 1] - ] + return x[cached_config.split_indices[idx]:cached_config. 
+ split_indices[idx + 1]] @staticmethod def op_wrapper( @@ -87,21 +92,23 @@ def op_wrapper( args: tuple, kwargs: dict, gm: torch.fx.GraphModule, - hooks: List[Callable[[NanoOpInfo], ContextManager[None]]], + hooks: list[Callable[[NanoOpInfo], + contextlib.AbstractContextManager[None]]], ): module = getattr(gm, submod_name) tag = getattr(module, "tag", "") with contextlib.ExitStack() as stack: for hook in hooks: stack.enter_context( - hook(NanoOpInfo(submod_name, tag, idx, args, kwargs)) - ) + hook(NanoOpInfo(submod_name, tag, idx, args, kwargs))) output = module(*args, **kwargs) return output def get_callable(self) -> Callable: + def _forward(*args, **kwargs): - if self.cached_config is None or self.cached_config.num_nano_batches == 1: + if (self.cached_config is None + or self.cached_config.num_nano_batches == 1): return self.original_graph_module(*args, **kwargs) num_nano_batches = self.cached_config.num_nano_batches @@ -111,30 +118,38 @@ def _forward(*args, **kwargs): @contextlib.contextmanager def set_stream(op_info: NanoOpInfo): if op_info.tag == "all_reduce": - torch.cuda.set_stream(self.comm_stream) # type: ignore - comm_finished[op_info.idx] = torch.cuda.Event() # type: ignore + torch.cuda.set_stream(self.comm_stream) + comm_finished[op_info.idx] = torch.cuda.Event() if comp_finished[op_info.idx] is not None: - comp_finished[op_info.idx].wait() # type: ignore + assert isinstance(comp_finished[op_info.idx], + torch.cuda.Event) + comp_finished[op_info.idx].wait() comp_finished[op_info.idx] = None else: - torch.cuda.set_stream(self.comp_stream) # type: ignore - comp_finished[op_info.idx] = torch.cuda.Event() # type: ignore + torch.cuda.set_stream(self.comp_stream) + comp_finished[op_info.idx] = torch.cuda.Event() if comm_finished[op_info.idx] is not None: - comm_finished[op_info.idx].wait() # type: ignore + assert isinstance(comm_finished[op_info.idx], + torch.cuda.Event) + comm_finished[op_info.idx].wait() comm_finished[op_info.idx] = None try: yield finally: if op_info.tag == "all_reduce": - comm_finished[op_info.idx].record() # type: ignore + assert isinstance(comm_finished[op_info.idx], + torch.cuda.Event) + comm_finished[op_info.idx].record() else: - comp_finished[op_info.idx].record() # type: ignore + assert isinstance(comp_finished[op_info.idx], + torch.cuda.Event) + comp_finished[op_info.idx].record() @contextlib.contextmanager def nvtx_mark(op_info: NanoOpInfo): try: with torch.cuda.nvtx.range( - f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}" + f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}" ): yield except Exception as e: @@ -160,9 +175,12 @@ def nvtx_mark(op_info: NanoOpInfo): NanoSplitManager.split_input, cached_config=self.cached_config, ) - setattr(self.graph_modules[num_nano_batches], self.wrapper_fn, op_wrapper) - setattr(self.graph_modules[num_nano_batches], self.get_bs_fn, get_batch_size) - setattr(self.graph_modules[num_nano_batches], self.split_fn, split_input) + setattr(self.graph_modules[num_nano_batches], self.wrapper_fn, + op_wrapper) + setattr(self.graph_modules[num_nano_batches], self.get_bs_fn, + get_batch_size) + setattr(self.graph_modules[num_nano_batches], self.split_fn, + split_input) output = self.graph_modules[num_nano_batches](*args, **kwargs) return output @@ -171,8 +189,8 @@ def nvtx_mark(op_info: NanoOpInfo): def prepare( self, batch_size: int, - num_tokens: List[int], - cached_seqlens: List[int], + num_tokens: list[int], + cached_seqlens: list[int], ) -> NanoSplitConfig: self.cached_config = get_split_config( batch_size, @@ 
-183,14 +201,17 @@ def prepare( ) return self.cached_config - def set_hooks(self, op_hook: Callable[[NanoOpInfo], ContextManager[None]]): + def set_hooks(self, + op_hook: Callable[[NanoOpInfo], + contextlib.AbstractContextManager[None]]): self.hook = op_hook _split_manager = None -def get_callable(graph_module: torch.fx.GraphModule, vllm_config: VllmConfig) -> Callable: +def get_callable(graph_module: torch.fx.GraphModule, + vllm_config: VllmConfig) -> Callable: global _split_manager if _split_manager is None: _split_manager = NanoSplitManager(graph_module, vllm_config) @@ -199,8 +220,8 @@ def get_callable(graph_module: torch.fx.GraphModule, vllm_config: VllmConfig) -> def prepare_nano_split( batch_size: int, - num_tokens: List[int], - cached_seqlens: List[int], + num_tokens: list[int], + cached_seqlens: list[int], ) -> NanoSplitConfig: global _split_manager if _split_manager is None: @@ -208,7 +229,8 @@ def prepare_nano_split( return _split_manager.prepare(batch_size, num_tokens, cached_seqlens) -def set_op_hook(op_hook: Callable[[NanoOpInfo], ContextManager[None]]): +def set_op_hook(op_hook: Callable[[NanoOpInfo], + contextlib.AbstractContextManager[None]]): global _split_manager if _split_manager is None: raise ValueError("Split manager not initialized") diff --git a/vllm/compilation/nanoflow/split_utils.py b/vllm/compilation/nanoflow/split_utils.py index cdf8071e3d7b..2abf91bbece8 100644 --- a/vllm/compilation/nanoflow/split_utils.py +++ b/vllm/compilation/nanoflow/split_utils.py @@ -1,7 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses import itertools +from typing import Callable + import torch -from typing import Callable, List, Tuple, Union -import dataclasses +from torch.fx.node import Argument as NodeArgument @dataclasses.dataclass @@ -17,14 +22,15 @@ class NanoOpInfo: class NanoSplitConfig: num_nano_batches: int # Request level information - batch_sizes: List[int] - batch_indices: List[int] + batch_sizes: list[int] + batch_indices: list[int] # Token level information - num_tokens: List[int] - split_indices: List[int] # start/end indices of each nano batch + num_tokens: list[int] + split_indices: list[int] # start/end indices of each nano batch class FakeModule(torch.nn.Module): + def __init__(self, fn: Callable, **kwargs): super().__init__() self.fn = fn @@ -36,8 +42,8 @@ def forward(self, *args, **kwargs): def get_split_config( batch_size: int, - num_tokens: List[int], - cached_seqlens: List[int], + num_tokens: list[int], + cached_seqlens: list[int], max_num_nano_batches: int, min_nano_split_tokens: int, ) -> NanoSplitConfig: @@ -48,8 +54,11 @@ def get_split_config( nano_batch_token_sizes = [] prefix_sum = [0] + list(itertools.accumulate(num_tokens)) # Find the mid point of the tokens - mid = min(range(len(prefix_sum)), key=lambda i: abs(prefix_sum[i] - (prefix_sum[-1] - prefix_sum[i]))) - if prefix_sum[mid] < min_nano_split_tokens or (prefix_sum[-1] - prefix_sum[mid]) < min_nano_split_tokens: + mid = min(range(len(prefix_sum)), + key=lambda i: abs(prefix_sum[i] - + (prefix_sum[-1] - prefix_sum[i]))) + if prefix_sum[mid] < min_nano_split_tokens or ( + prefix_sum[-1] - prefix_sum[mid]) < min_nano_split_tokens: num_nano_batches = 1 nano_batch_req_indices.append(batch_size) nano_batch_token_indices.append(prefix_sum[-1]) @@ -60,7 +69,8 @@ def get_split_config( nano_batch_req_indices.extend([mid, batch_size]) nano_batch_token_indices.extend([prefix_sum[mid], prefix_sum[-1]]) 
nano_batch_req_sizes.extend([mid, batch_size - mid]) - nano_batch_token_sizes.extend([prefix_sum[mid], prefix_sum[-1] - prefix_sum[mid]]) + nano_batch_token_sizes.extend( + [prefix_sum[mid], prefix_sum[-1] - prefix_sum[mid]]) return NanoSplitConfig( num_nano_batches=num_nano_batches, @@ -78,27 +88,21 @@ def display_graph(graph_module: torch.fx.GraphModule, name: str): def tag_graph(gm: torch.fx.GraphModule, op_tags: dict[str, str]): - submodules = [ - (name, module) - for (name, module) in gm.named_modules() - if hasattr(module, "graph") - ] + submodules = [(name, module) for (name, module) in gm.named_modules() + if hasattr(module, "graph")] for _, module in submodules: for node in module.graph.nodes: - if ( - node.op == "call_function" - and (tag := op_tags.get(str(node.target))) is not None - ): - assert ( - getattr(module, "tag", None) is None - or getattr(module, "tag") == tag - ), f"tag mismatch: {getattr(module, 'tag')} != {tag}" - setattr(module, "tag", tag) + if (node.op == "call_function" + and (tag := op_tags.get(str(node.target))) is not None): + assert (getattr(module, "tag", None) is None or module.tag + == tag), f"tag mismatch: {module.tag} != {tag}" + module.tag = tag def analyze_graph( - graph: torch.fx.Graph, batch_size: Union[int, torch.SymInt, None] = None -) -> Tuple[List[torch.fx.Node], torch.fx.Graph]: + graph: torch.fx.Graph, + batch_size: int | torch.SymInt | None = None +) -> tuple[list[torch.fx.Node], torch.fx.Graph]: weight_nodes = set() splittable_inputs = [] base_graph = torch.fx.Graph() @@ -113,7 +117,8 @@ def analyze_graph( if not isinstance(arg, torch.SymInt): raise ValueError("Batch size is not set") batch_size = arg - elif isinstance(input_tensor := node.meta["example_value"], torch.Tensor): + elif isinstance(input_tensor := node.meta["example_value"], + torch.Tensor): shape = input_tensor.shape if shape[0] == batch_size: splittable_inputs.append(node) @@ -128,23 +133,21 @@ def split_graph( graph: torch.fx.Graph, *, out: torch.fx.Graph, - splittable_inputs: List[torch.fx.Node], + splittable_inputs: list[torch.fx.Node], num_splits: int, get_bs_fn: str, split_fn: str, wrapper_fn: str, ) -> torch.fx.Graph: - mapping = {} + mapping: dict[NodeArgument, list[torch.fx.Node]] = {} nano_batch_sizes = [] # Step 1: Get nano batch sizes and split inputs for i in range(num_splits): - nano_batch_sizes.append( - out.call_module( - get_bs_fn, - args=(i,), - ) - ) + nano_batch_sizes.append(out.call_module( + get_bs_fn, + args=(i, ), + )) for node in splittable_inputs: mapping[node] = [] for i in range(num_splits): @@ -155,10 +158,12 @@ def split_graph( mapping[node].append(slice_node) # Step 2: Split computation nodes - def _transform(idx, n) -> torch.fx.Node: + def _transform(idx: int, n: NodeArgument) -> NodeArgument: if n in mapping: return mapping[n][idx] - if isinstance(getattr(n, "meta", {}).get("example_value", None), torch.SymInt): + if isinstance( + getattr(n, "meta", {}).get("example_value", None), + torch.SymInt): return nano_batch_sizes[idx] return n @@ -170,7 +175,8 @@ def _transform(idx, n) -> torch.fx.Node: if node.op == "call_module": new_args = [_transform(split_idx, arg) for arg in node.args] new_kwargs = { - k: _transform(split_idx, v) for k, v in node.kwargs.items() + k: _transform(split_idx, v) + for k, v in node.kwargs.items() } new_node = out.call_module( wrapper_fn, @@ -178,21 +184,22 @@ def _transform(idx, n) -> torch.fx.Node: ) else: new_node = out.node_copy( - node, arg_transform=lambda n: _transform(split_idx, n) - ) + node, + 
arg_transform=lambda n, idx=split_idx: _transform(idx, n)) splits.append(new_node) mapping[node] = splits # Step 3: Concatenate outputs output_nodes = [node for node in graph.nodes if node.op == "output"] - assert len(output_nodes) == 1, f"Expected 1 output node, found {len(output_nodes)}" + assert len(output_nodes + ) == 1, f"Expected 1 output node, found {len(output_nodes)}" output_node = output_nodes[0] if not output_node.args: raise ValueError("Output node has no arguments") original_outputs = output_node.args[0] is_tuple = isinstance(original_outputs, tuple) if not isinstance(original_outputs, tuple): - original_outputs = (original_outputs,) + original_outputs = (original_outputs, ) new_outputs = [] for original_output in original_outputs: @@ -208,14 +215,14 @@ def _transform(idx, n) -> torch.fx.Node: # Create concatenation node concat_node = out.call_function( torch.cat, - args=(split_outputs, 0), # Concatenate along first dimension + args=(split_outputs, + 0), # Concatenate along first dimension ) new_outputs.append(concat_node) else: raise ValueError( - f"Original output {original_output} not found in node_splits" - ) + f"Original output {original_output} not found in node_splits") out.output(tuple(new_outputs) if is_tuple else new_outputs[0]) return out diff --git a/vllm/config.py b/vllm/config.py index 62cf34dbb8b9..e39e2959fd0e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -416,7 +416,8 @@ class ModelConfig: override_attention_dtype: Optional[str] = None """Override dtype for attention""" enable_nano_batch_split: bool = False - """Enable spliting the input batch into nano-batches for intra-device parallelism""" + """Enable splitting the input batch into nano-batches for intra-device + parallelism""" max_num_nano_batches: int = 2 """Maximum number of nano-batches to split the input batch into""" min_nano_split_tokens: int = 1024 diff --git a/vllm/utils/nano_split.py b/vllm/utils/nano_split.py index 755542244604..a74706c6013e 100644 --- a/vllm/utils/nano_split.py +++ b/vllm/utils/nano_split.py @@ -1,22 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from contextlib import contextmanager from typing import Optional import numpy as np import torch + +from vllm.compilation.nanoflow import manager as nano_manager from vllm.compilation.nanoflow.split_utils import NanoOpInfo -from vllm.forward_context import ( - ForwardContext, - get_forward_context, - override_forward_context, -) -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, -) +from vllm.forward_context import (ForwardContext, get_forward_context, + override_forward_context) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu_input_batch import InputBatch -from vllm.compilation.nanoflow import manager as nano_manager def _get_cumsum_and_arange( @@ -42,18 +41,21 @@ def prepare_nano_split_and_set_hooks( tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] num_scheduled_tokens = torch.tensor(tokens, dtype=torch.int32) cached_seqlens = input_batch.num_computed_tokens_cpu[:batch_size].tolist() - split_config = nano_manager.prepare_nano_split(batch_size, tokens, cached_seqlens) + split_config = nano_manager.prepare_nano_split(batch_size, tokens, + cached_seqlens) if split_config.num_nano_batches == 1: return - 
cu_num_tokens = torch.cumsum(num_scheduled_tokens, dim=0, dtype=torch.int32) + cu_num_tokens = torch.cumsum(num_scheduled_tokens, + dim=0, + dtype=torch.int32) query_start_loc_cpu = torch.zeros(batch_size + 1, dtype=torch.int32) query_start_loc_cpu[0] = 0 - query_start_loc_cpu[1 : batch_size + 1] = cu_num_tokens + query_start_loc_cpu[1:batch_size + 1] = cu_num_tokens seq_lens_cpu = torch.zeros(batch_size, dtype=torch.int32) seq_lens_cpu[:batch_size] = ( - input_batch.num_computed_tokens_cpu_tensor[:batch_size] + num_scheduled_tokens - ) + input_batch.num_computed_tokens_cpu_tensor[:batch_size] + + num_scheduled_tokens) query_start_loc = query_start_loc_cpu.to(input_batch.device) seq_lens = seq_lens_cpu.to(input_batch.device) @@ -69,23 +71,24 @@ def prepare_nano_split_and_set_hooks( # Gather per-request info for this group nano_batch_num_scheduled_tokens = np.array( - [scheduler_output.num_scheduled_tokens[rid] for rid in nano_batch_req_ids], + [ + scheduler_output.num_scheduled_tokens[rid] + for rid in nano_batch_req_ids + ], dtype=np.int32, ) nano_batch_cu_num_tokens, nano_batch_arange = _get_cumsum_and_arange( - nano_batch_num_scheduled_tokens - ) + nano_batch_num_scheduled_tokens) nano_batch_total_tokens = int(nano_batch_cu_num_tokens[-1]) - nano_batch_req_indices = np.repeat( - np.arange(len(nano_batch_req_ids)), nano_batch_num_scheduled_tokens - ) + nano_batch_req_indices = np.repeat(np.arange(len(nano_batch_req_ids)), + nano_batch_num_scheduled_tokens) # Compute positions for this group - nano_batch_positions_np = np.empty(nano_batch_total_tokens, dtype=np.int64) + nano_batch_positions_np = np.empty(nano_batch_total_tokens, + dtype=np.int64) np.add( - input_batch.num_computed_tokens_cpu[start_req_idx:end_req_idx][ - nano_batch_req_indices - ], + input_batch.num_computed_tokens_cpu[start_req_idx:end_req_idx] + [nano_batch_req_indices], nano_batch_arange, out=nano_batch_positions_np, ) @@ -93,22 +96,23 @@ def prepare_nano_split_and_set_hooks( # Prepare attention metadata for each KV cache group nano_batch_attn_metadata = {} for kv_cache_group_id, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups - ): + kv_cache_config.kv_cache_groups): blk_table = input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[start_req_idx:end_req_idx] - slot_mapping = blk_table.slot_mapping[start_token_idx:end_token_idx] + blk_table_tensor = blk_table.get_device_tensor( + )[start_req_idx:end_req_idx] + slot_mapping = blk_table.slot_mapping[ + start_token_idx:end_token_idx] common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc[start_req_idx : end_req_idx + 1] + query_start_loc=query_start_loc[start_req_idx:end_req_idx + 1] - query_start_loc[start_req_idx], - query_start_loc_cpu=query_start_loc_cpu[start_req_idx : end_req_idx + 1] - - query_start_loc_cpu[start_req_idx], + query_start_loc_cpu=query_start_loc_cpu[ + start_req_idx:end_req_idx + 1] - + query_start_loc_cpu[start_req_idx], seq_lens=seq_lens[start_req_idx:end_req_idx], seq_lens_cpu=seq_lens_cpu[start_req_idx:end_req_idx], - num_computed_tokens_cpu=input_batch.num_computed_tokens_cpu_tensor[ - start_req_idx:end_req_idx - ], + num_computed_tokens_cpu=input_batch. 
+ num_computed_tokens_cpu_tensor[start_req_idx:end_req_idx], num_reqs=split_config.batch_sizes[nano_batch_idx], num_actual_tokens=nano_batch_total_tokens, max_query_len=int(max(nano_batch_num_scheduled_tokens)), @@ -116,7 +120,8 @@ def prepare_nano_split_and_set_hooks( slot_mapping=slot_mapping, ) - # NOTE(yi): does not support chunked local attention or cascade attention + # NOTE(yi): does not support chunked local attention or cascade + # attention common_prefix_len = 0 builder = attn_metadata_builders[kv_cache_group_id] attn_metadata_i = builder.build( @@ -129,9 +134,8 @@ def prepare_nano_split_and_set_hooks( attn_metadatas.append(nano_batch_attn_metadata) - assert ( - end_req_idx == batch_size - ), f"invalid nano batch size: {split_config.batch_sizes}" + assert (end_req_idx == batch_size + ), f"invalid nano batch size: {split_config.batch_sizes}" forward_contexts = [ ForwardContext( no_compile_layers=prev_forward_context.no_compile_layers, @@ -139,8 +143,7 @@ def prepare_nano_split_and_set_hooks( virtual_engine=prev_forward_context.virtual_engine, dp_metadata=prev_forward_context.dp_metadata, skip_cuda_graphs=True, - ) - for attn_metadata in attn_metadatas + ) for attn_metadata in attn_metadatas ] @contextmanager diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5778adf18fdb..b0006271f997 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -791,14 +791,13 @@ def _prepare_inputs( return (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata) - + def _prepare_nano_split(self, scheduler_output: "SchedulerOutput"): prepare_nano_split_and_set_hooks( scheduler_output=scheduler_output, input_batch=self.input_batch, attn_metadata_builders=self.attn_metadata_builders, - kv_cache_config=self.kv_cache_config - ) + kv_cache_config=self.kv_cache_config) def _compute_cascade_attn_prefix_len( self, From 483a727adf1cb590071038ab93b4c5a0ea44f156 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Tue, 26 Aug 2025 18:36:49 -0700 Subject: [PATCH 15/22] make mypy happy Signed-off-by: Yi Pan --- vllm/compilation/nanoflow/manager.py | 45 +++++++++++++----------- vllm/compilation/nanoflow/split_utils.py | 4 +-- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py index 2e9ee3386231..c96df8a3ebf7 100644 --- a/vllm/compilation/nanoflow/manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -3,7 +3,7 @@ import contextlib import copy -from typing import Callable +from typing import Callable, Optional import torch import torch.fx.graph_module @@ -43,11 +43,11 @@ def __init__( self.graph_modules = {1: self.original_graph_module} # Runtime preparation - self.cached_config: NanoSplitConfig | None = None - self.comm_stream = torch.cuda.Stream() - self.comp_stream = torch.cuda.Stream() - self.hook: Callable[[NanoOpInfo], contextlib. 
- AbstractContextManager[None]] | None = None + self.cached_config: Optional[NanoSplitConfig] = None + self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() + self.comp_stream: torch.cuda.Stream = torch.cuda.Stream() + self.hook: Optional[Callable[ + [NanoOpInfo], contextlib.AbstractContextManager[None]]] = None self.get_bs_fn = "get_batch_size" self.split_fn = "split_input" self.wrapper_fn = "op_wrapper" @@ -112,8 +112,12 @@ def _forward(*args, **kwargs): return self.original_graph_module(*args, **kwargs) num_nano_batches = self.cached_config.num_nano_batches - comm_finished = [None for _ in range(num_nano_batches)] - comp_finished = [None for _ in range(num_nano_batches)] + comm_finished: list[Optional[torch.cuda.Event]] = [ + None for _ in range(num_nano_batches) + ] + comp_finished: list[Optional[torch.cuda.Event]] = [ + None for _ in range(num_nano_batches) + ] @contextlib.contextmanager def set_stream(op_info: NanoOpInfo): @@ -121,29 +125,30 @@ def set_stream(op_info: NanoOpInfo): torch.cuda.set_stream(self.comm_stream) comm_finished[op_info.idx] = torch.cuda.Event() if comp_finished[op_info.idx] is not None: - assert isinstance(comp_finished[op_info.idx], - torch.cuda.Event) - comp_finished[op_info.idx].wait() + # NOTE(yi): this is to make mypy happy + comp_finished_event = comp_finished[op_info.idx] + assert comp_finished_event is not None + comp_finished_event.wait() comp_finished[op_info.idx] = None else: torch.cuda.set_stream(self.comp_stream) comp_finished[op_info.idx] = torch.cuda.Event() if comm_finished[op_info.idx] is not None: - assert isinstance(comm_finished[op_info.idx], - torch.cuda.Event) - comm_finished[op_info.idx].wait() + comm_finished_event = comm_finished[op_info.idx] + assert comm_finished_event is not None + comm_finished_event.wait() comm_finished[op_info.idx] = None try: yield finally: if op_info.tag == "all_reduce": - assert isinstance(comm_finished[op_info.idx], - torch.cuda.Event) - comm_finished[op_info.idx].record() + comm_finished_event = comm_finished[op_info.idx] + assert comm_finished_event is not None + comm_finished_event.record() else: - assert isinstance(comp_finished[op_info.idx], - torch.cuda.Event) - comp_finished[op_info.idx].record() + comp_finished_event = comp_finished[op_info.idx] + assert comp_finished_event is not None + comp_finished_event.record() @contextlib.contextmanager def nvtx_mark(op_info: NanoOpInfo): diff --git a/vllm/compilation/nanoflow/split_utils.py b/vllm/compilation/nanoflow/split_utils.py index 2abf91bbece8..1a4bca80ec6b 100644 --- a/vllm/compilation/nanoflow/split_utils.py +++ b/vllm/compilation/nanoflow/split_utils.py @@ -3,7 +3,7 @@ import dataclasses import itertools -from typing import Callable +from typing import Callable, Union import torch from torch.fx.node import Argument as NodeArgument @@ -101,7 +101,7 @@ def tag_graph(gm: torch.fx.GraphModule, op_tags: dict[str, str]): def analyze_graph( graph: torch.fx.Graph, - batch_size: int | torch.SymInt | None = None + batch_size: Union[int, torch.SymInt, None] = None ) -> tuple[list[torch.fx.Node], torch.fx.Graph]: weight_nodes = set() splittable_inputs = [] From 335aab1e5474d29046592d7d0b42d3d35033148e Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Wed, 3 Sep 2025 16:17:47 -0700 Subject: [PATCH 16/22] move to compilation config Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 5 +++-- vllm/compilation/nanoflow/manager.py | 32 ++++++++++++++++++++-------- vllm/config/__init__.py | 26 ++++++++-------------- vllm/config/compilation.py | 11 ++++++++++ 
vllm/engine/arg_utils.py | 12 ----------- vllm/v1/worker/gpu_model_runner.py | 2 +- 6 files changed, 47 insertions(+), 41 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 30f838681eb0..f2a73a8508eb 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -596,9 +596,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \ not self.compilation_config.cudagraph_copy_inputs: - if self.vllm_config.model_config.enable_nano_batch_split: + if self.compilation_config.enable_nano_batch_split: return nano_manager.get_callable(self.split_gm, - self.vllm_config) + self.compilation_config, + local_cache_dir) else: return self.split_gm diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py index c96df8a3ebf7..120e6ebe82f0 100644 --- a/vllm/compilation/nanoflow/manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -3,6 +3,7 @@ import contextlib import copy +import os from typing import Callable, Optional import torch @@ -13,7 +14,7 @@ analyze_graph, get_split_config, split_graph, tag_graph) -from vllm.config import VllmConfig +from vllm.config import CompilationConfig class NanoSplitManager: @@ -21,16 +22,15 @@ class NanoSplitManager: def __init__( self, graph_module: torch.fx.GraphModule, - vllm_config: VllmConfig, + compilation_config: CompilationConfig, + local_cache_dir: Optional[str], ) -> None: self.original_graph_module = graph_module self.original_graph = graph_module.graph # Nano split preparation - self.min_nano_split_tokens = \ - vllm_config.model_config.min_nano_split_tokens - self.max_num_nano_batches = \ - vllm_config.model_config.max_num_nano_batches + self.min_nano_split_tokens = compilation_config.min_nano_split_tokens + self.max_num_nano_batches = compilation_config.max_num_nano_batches # Initialize the base graph tag_graph( self.original_graph_module, @@ -75,6 +75,16 @@ def __init__( torch.fx.graph_module._copy_attr(self.original_graph_module, new_graph_module, name) self.graph_modules[num_splits] = new_graph_module + if local_cache_dir is not None: + graph_path = os.path.join(local_cache_dir, + f"nano_split_{num_splits}.py") + if not os.path.exists(graph_path): + src = ( + "from __future__ import annotations\nimport torch\n" + + new_graph_module.print_readable(print_output=False)) + src = src.replace("", "GraphModule") + with open(graph_path, "w") as f: + f.write(src) @staticmethod def get_batch_size(idx: int, cached_config: NanoSplitConfig): @@ -215,11 +225,15 @@ def set_hooks(self, _split_manager = None -def get_callable(graph_module: torch.fx.GraphModule, - vllm_config: VllmConfig) -> Callable: +def get_callable( + graph_module: torch.fx.GraphModule, + compilation_config: CompilationConfig, + local_cache_dir: Optional[str] = None, +) -> Callable: global _split_manager if _split_manager is None: - _split_manager = NanoSplitManager(graph_module, vllm_config) + _split_manager = NanoSplitManager(graph_module, compilation_config, + local_cache_dir) return _split_manager.get_callable() diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 2d38e417b568..4f05d15ffec8 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -503,13 +503,6 @@ class ModelConfig: definitions""" io_processor_plugin: Optional[str] = None """IOProcessor plugin name to load at model startup""" - enable_nano_batch_split: bool = False - """Enable splitting the input batch into nano-batches for 
intra-device - parallelism""" - max_num_nano_batches: int = 2 - """Maximum number of nano-batches to split the input batch into""" - min_nano_split_tokens: int = 1024 - """Minimum number of tokens to split the input batch""" def compute_hash(self) -> str: """ @@ -538,9 +531,6 @@ def compute_hash(self) -> str: factors.append(self.override_generation_config) factors.append(self.rope_scaling) factors.append(self.rope_theta) - factors.append(self.enable_nano_batch_split) - factors.append(self.max_num_nano_batches) - factors.append(self.min_nano_split_tokens) # hf_config can control how the model looks! factors.append(self.hf_config.to_json_string()) str_factors = str(factors) @@ -3603,25 +3593,27 @@ def __post_init__(self): "To workaround this limitation, vLLM will set 'ieee' input " "precision for chunked prefill triton kernels.") - if self.model_config.enable_nano_batch_split: + if self.compilation_config.enable_nano_batch_split: if self.model_config.enforce_eager: logger.info("nano batch split is not supported with " "enforce_eager. Disabling nano batch split.") - self.model_config.enable_nano_batch_split = False - elif self.compilation_config.use_cudagraph: + self.compilation_config.enable_nano_batch_split = False + elif self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: logger.info("nano batch split is currently not supported with " "cudagraph. Disabling nano batch split.") - self.model_config.enable_nano_batch_split = False + self.compilation_config.enable_nano_batch_split = False elif self.compilation_config.full_cuda_graph: logger.info("full_cuda_graph is not supported with " "nano batch split. Disabling nano batch split.") - self.model_config.enable_nano_batch_split = False + self.compilation_config.enable_nano_batch_split = False elif self.compilation_config.splitting_ops: logger.info("splitting_ops is not supported with " "nano batch split. Disabling nano batch split.") - self.model_config.enable_nano_batch_split = False + self.compilation_config.enable_nano_batch_split = False else: - self.compilation_config.splitting_ops = ["vllm.all_reduce"] + self.compilation_config.splitting_ops = [ + "vllm.all_reduce", + ] # If the user does not explicitly set a compilation level, then # we use the default level. The default level depends on other # settings (see the below code). diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 5c3b22001636..f67530136480 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -299,6 +299,14 @@ class CompilationConfig: minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead. 
""" + enable_nano_batch_split: bool = False + """Enable splitting the input batch into nano-batches for intra-device + parallelism""" + max_num_nano_batches: int = 2 + """Maximum number of nano-batches to split the input batch into""" + min_nano_split_tokens: int = 1024 + """Minimum number of tokens to split the input batch""" + pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" @@ -363,6 +371,9 @@ def compute_hash(self) -> str: factors.append(self.inductor_compile_config) factors.append(self.inductor_passes) factors.append(self.pass_config.uuid()) + factors.append(self.enable_nano_batch_split) + factors.append(self.max_num_nano_batches) + factors.append(self.min_nano_split_tokens) return hashlib.sha256(str(factors).encode()).hexdigest() def __repr__(self) -> str: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8a7e9c584248..dedcf5f000f4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -435,9 +435,6 @@ class EngineArgs: get_field(ModelConfig, "override_generation_config") model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype - enable_nano_batch_split: bool = ModelConfig.enable_nano_batch_split - max_num_nano_batches: int = ModelConfig.max_num_nano_batches - min_nano_split_tokens: int = ModelConfig.min_nano_split_tokens calculate_kv_scales: bool = CacheConfig.calculate_kv_scales mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype @@ -583,12 +580,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **model_kwargs["logits_processors"]) model_group.add_argument("--io-processor-plugin", **model_kwargs["io_processor_plugin"]) - model_group.add_argument("--enable-nano-batch-split", - **model_kwargs["enable_nano_batch_split"]) - model_group.add_argument("--max-num-nano-batches", - **model_kwargs["max_num_nano_batches"]) - model_group.add_argument("--min-nano-split-tokens", - **model_kwargs["min_nano_split_tokens"]) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) load_group = parser.add_argument_group( @@ -1005,9 +996,6 @@ def create_model_config(self) -> ModelConfig: override_attention_dtype=self.override_attention_dtype, logits_processors=self.logits_processors, io_processor_plugin=self.io_processor_plugin, - enable_nano_batch_split=self.enable_nano_batch_split, - max_num_nano_batches=self.max_num_nano_batches, - min_nano_split_tokens=self.min_nano_split_tokens, ) def validate_tensorizer_args(self): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ab1582ce141f..6daf1223a281 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1584,7 +1584,7 @@ def execute_model( batch_descriptor=batch_descriptor, ), self.maybe_get_kv_connector_output( scheduler_output) as kv_connector_output: - if self.vllm_config.model_config.enable_nano_batch_split: + if self.vllm_config.compilation_config.enable_nano_batch_split: self._prepare_nano_split(scheduler_output) model_output = self.model( From 563fe720a5f4579d85e527a7ed98d360c67736b9 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Sun, 21 Sep 2025 15:00:43 -0700 Subject: [PATCH 17/22] minor Signed-off-by: Yi Pan --- vllm/utils/nano_split.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/utils/nano_split.py b/vllm/utils/nano_split.py index 22d518a528d9..7ad1c9b9d4c6 100644 --- a/vllm/utils/nano_split.py +++ b/vllm/utils/nano_split.py @@ 
-142,11 +142,11 @@ def prepare_nano_split_and_set_hooks( forward_contexts = [ ForwardContext( no_compile_layers=prev_forward_context.no_compile_layers, - attn_metadata=attn_metadata, + attn_metadata=attn_metadatas[i], virtual_engine=prev_forward_context.virtual_engine, dp_metadata=prev_forward_context.dp_metadata, cudagraph_runtime_mode=CUDAGraphMode.NONE, - ) for attn_metadata in attn_metadatas + ) for i in range(split_config.num_nano_batches) ] @contextmanager From ba26309403003ae1b1d501937a9d77dcdcdb8474 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Thu, 2 Oct 2025 00:57:18 -0700 Subject: [PATCH 18/22] adapt to dbo design Signed-off-by: Yi Pan --- vllm/compilation/nanoflow/manager.py | 14 +- vllm/compilation/nanoflow/split_utils.py | 1 - vllm/utils/nano_split.py | 161 ----------------------- vllm/v1/worker/gpu_model_runner.py | 37 +++--- vllm/v1/worker/nano_batch_split.py | 67 ++++++++++ 5 files changed, 91 insertions(+), 189 deletions(-) delete mode 100644 vllm/utils/nano_split.py create mode 100644 vllm/v1/worker/nano_batch_split.py diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py index 120e6ebe82f0..7d750210b8c9 100644 --- a/vllm/compilation/nanoflow/manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -205,15 +205,10 @@ def prepare( self, batch_size: int, num_tokens: list[int], - cached_seqlens: list[int], ) -> NanoSplitConfig: - self.cached_config = get_split_config( - batch_size, - num_tokens, - cached_seqlens, - self.max_num_nano_batches, - self.min_nano_split_tokens, - ) + self.cached_config = get_split_config(batch_size, num_tokens, + self.max_num_nano_batches, + self.min_nano_split_tokens) return self.cached_config def set_hooks(self, @@ -240,12 +235,11 @@ def get_callable( def prepare_nano_split( batch_size: int, num_tokens: list[int], - cached_seqlens: list[int], ) -> NanoSplitConfig: global _split_manager if _split_manager is None: raise ValueError("Split manager not initialized") - return _split_manager.prepare(batch_size, num_tokens, cached_seqlens) + return _split_manager.prepare(batch_size, num_tokens) def set_op_hook(op_hook: Callable[[NanoOpInfo], diff --git a/vllm/compilation/nanoflow/split_utils.py b/vllm/compilation/nanoflow/split_utils.py index 1a4bca80ec6b..e2f84e1ca537 100644 --- a/vllm/compilation/nanoflow/split_utils.py +++ b/vllm/compilation/nanoflow/split_utils.py @@ -43,7 +43,6 @@ def forward(self, *args, **kwargs): def get_split_config( batch_size: int, num_tokens: list[int], - cached_seqlens: list[int], max_num_nano_batches: int, min_nano_split_tokens: int, ) -> NanoSplitConfig: diff --git a/vllm/utils/nano_split.py b/vllm/utils/nano_split.py deleted file mode 100644 index 7ad1c9b9d4c6..000000000000 --- a/vllm/utils/nano_split.py +++ /dev/null @@ -1,161 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from contextlib import contextmanager -from typing import Optional - -import numpy as np -import torch - -from vllm.compilation.nanoflow import manager as nano_manager -from vllm.compilation.nanoflow.split_utils import NanoOpInfo -from vllm.config.compilation import CUDAGraphMode -from vllm.forward_context import (ForwardContext, get_forward_context, - override_forward_context) -from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.worker.gpu_input_batch import InputBatch -from vllm.v1.worker.utils import 
AttentionGroup - - -def _get_cumsum_and_arange( - num_tokens: np.ndarray, - cumsum_dtype: Optional[np.dtype] = None, -) -> tuple[np.ndarray, np.ndarray]: - cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) - total_num_tokens = cu_num_tokens[-1] - cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) - arange = np.arange(total_num_tokens) - cumsums_offsets - return cu_num_tokens, arange - - -def prepare_nano_split_and_set_hooks( - scheduler_output: SchedulerOutput, - input_batch: InputBatch, - attn_groups: list[list[AttentionGroup]], - kv_cache_config: KVCacheConfig, -) -> None: - prev_forward_context = get_forward_context() - req_ids = input_batch.req_ids - batch_size = len(req_ids) - tokens = [scheduler_output.num_scheduled_tokens[rid] for rid in req_ids] - num_scheduled_tokens = torch.tensor(tokens, dtype=torch.int32) - cached_seqlens = input_batch.num_computed_tokens_cpu[:batch_size].tolist() - split_config = nano_manager.prepare_nano_split(batch_size, tokens, - cached_seqlens) - if split_config.num_nano_batches == 1: - return - - cu_num_tokens = torch.cumsum(num_scheduled_tokens, - dim=0, - dtype=torch.int32) - query_start_loc_cpu = torch.zeros(batch_size + 1, dtype=torch.int32) - query_start_loc_cpu[0] = 0 - query_start_loc_cpu[1:batch_size + 1] = cu_num_tokens - seq_lens_cpu = torch.zeros(batch_size, dtype=torch.int32) - seq_lens_cpu[:batch_size] = ( - input_batch.num_computed_tokens_cpu_tensor[:batch_size] + - num_scheduled_tokens) - query_start_loc = query_start_loc_cpu.to(input_batch.device) - seq_lens = seq_lens_cpu.to(input_batch.device) - - attn_metadatas = [] - start_req_idx = 0 - end_req_idx = 0 - for nano_batch_idx in range(split_config.num_nano_batches): - start_req_idx = split_config.batch_indices[nano_batch_idx] - end_req_idx = split_config.batch_indices[nano_batch_idx + 1] - nano_batch_req_ids = req_ids[start_req_idx:end_req_idx] - start_token_idx = split_config.split_indices[nano_batch_idx] - end_token_idx = split_config.split_indices[nano_batch_idx + 1] - - # Gather per-request info for this group - nano_batch_num_scheduled_tokens = np.array( - [ - scheduler_output.num_scheduled_tokens[rid] - for rid in nano_batch_req_ids - ], - dtype=np.int32, - ) - nano_batch_cu_num_tokens, nano_batch_arange = _get_cumsum_and_arange( - nano_batch_num_scheduled_tokens) - nano_batch_total_tokens = int(nano_batch_cu_num_tokens[-1]) - nano_batch_req_indices = np.repeat(np.arange(len(nano_batch_req_ids)), - nano_batch_num_scheduled_tokens) - - # Compute positions for this group - nano_batch_positions_np = np.empty(nano_batch_total_tokens, - dtype=np.int64) - np.add( - input_batch.num_computed_tokens_cpu[start_req_idx:end_req_idx] - [nano_batch_req_indices], - nano_batch_arange, - out=nano_batch_positions_np, - ) - - # Prepare attention metadata for each KV cache group - nano_batch_attn_metadata = {} - for kv_cache_group_id, kv_cache_group_spec in enumerate( - kv_cache_config.kv_cache_groups): - blk_table = input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor( - )[start_req_idx:end_req_idx] - slot_mapping = blk_table.slot_mapping[ - start_token_idx:end_token_idx] - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc[start_req_idx:end_req_idx + 1] - - query_start_loc[start_req_idx], - query_start_loc_cpu=query_start_loc_cpu[ - start_req_idx:end_req_idx + 1] - - query_start_loc_cpu[start_req_idx], - seq_lens=seq_lens[start_req_idx:end_req_idx], - seq_lens_cpu=seq_lens_cpu[start_req_idx:end_req_idx], 
- num_computed_tokens_cpu=input_batch. - num_computed_tokens_cpu_tensor[start_req_idx:end_req_idx], - num_reqs=split_config.batch_sizes[nano_batch_idx], - num_actual_tokens=nano_batch_total_tokens, - max_query_len=int(max(nano_batch_num_scheduled_tokens)), - max_seq_len=int(seq_lens_cpu[start_req_idx:end_req_idx].max()), - block_table_tensor=blk_table_tensor, - slot_mapping=slot_mapping, - ) - - for attn_group in attn_groups[kv_cache_group_id]: - # NOTE(yi): does not support chunked local attention or cascade - # attention - common_prefix_len = 0 - builder = attn_group.metadata_builder - attn_metadata_i = builder.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - ) - - for layer_name in kv_cache_group_spec.layer_names: - nano_batch_attn_metadata[layer_name] = attn_metadata_i - - attn_metadatas.append(nano_batch_attn_metadata) - - assert (end_req_idx == batch_size - ), f"invalid nano batch size: {split_config.batch_sizes}" - forward_contexts = [ - ForwardContext( - no_compile_layers=prev_forward_context.no_compile_layers, - attn_metadata=attn_metadatas[i], - virtual_engine=prev_forward_context.virtual_engine, - dp_metadata=prev_forward_context.dp_metadata, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - ) for i in range(split_config.num_nano_batches) - ] - - @contextmanager - def op_hook(op_info: NanoOpInfo): - previous_context = get_forward_context() - override_forward_context(forward_contexts[op_info.idx]) - try: - yield - finally: - override_forward_context(previous_context) - - nano_manager.set_op_hook(op_hook) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9c8f53fc1d5e..8616c9037cca 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -65,7 +65,6 @@ length_from_prompt_token_ids_or_embeds, round_up, supports_dynamo) from vllm.utils.jsontree import json_map_leaves -from vllm.utils.nano_split import prepare_nano_split_and_set_hooks from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( @@ -103,6 +102,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin +from vllm.v1.worker.nano_batch_split import nano_ubatch_split from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds, ubatch_split) from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices @@ -1069,12 +1069,19 @@ def _prepare_inputs( uniform_decode = \ (max_num_scheduled_tokens == self.uniform_decode_query_len) and \ (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - ubatch_slices, num_tokens_after_padding = \ - ubatch_split(num_scheduled_tokens, - num_tokens_unpadded, - num_tokens_padded, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config) + if self.compilation_config.enable_nano_batch_split: + ubatch_slices, num_tokens_after_padding = \ + nano_ubatch_split( + num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded) + else: + ubatch_slices, num_tokens_after_padding = \ + ubatch_split(num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + uniform_decode=uniform_decode, + vllm_config=self.vllm_config) self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + @@ -1291,12 +1298,6 @@ def _prepare_inputs( max_num_scheduled_tokens, ubatch_slices, num_tokens_after_padding, 
use_cascade_attn) - def _prepare_nano_split(self, scheduler_output: "SchedulerOutput"): - prepare_nano_split_and_set_hooks(scheduler_output=scheduler_output, - input_batch=self.input_batch, - attn_groups=self.attn_groups, - kv_cache_config=self.kv_cache_config) - def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, @@ -2000,11 +2001,11 @@ def _preprocess( Optional[IntermediateTensors], dict[str, Any]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if ubatch_slices: + if ubatch_slices and self.parallel_config.enable_dbo: assert num_tokens_after_padding is not None num_input_tokens = int(num_tokens_after_padding[0].item() * 2) self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) - elif ubatch_slices is None: + else: num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) num_pad, num_tokens_after_padding = self.get_dp_padding( num_input_tokens) @@ -2369,7 +2370,7 @@ def execute_model( # This is currently to get around the assert in the DPMetadata # where it wants `num_tokens_across_dp` to align with `num_tokens` - if ubatch_slices is not None: + if ubatch_slices is not None and self.parallel_config.enable_dbo: num_input_tokens = ubatch_slices[0].num_tokens # Run the model. @@ -3674,7 +3675,9 @@ def create_attn_groups( self.vllm_config, self.device, num_metadata_builders=1 - if not self.parallel_config.enable_dbo else 2, + if not self.parallel_config.enable_dbo + and not self.compilation_config.enable_nano_batch_split + else 2, ) attn_groups.append(attn_group) diff --git a/vllm/v1/worker/nano_batch_split.py b/vllm/v1/worker/nano_batch_split.py new file mode 100644 index 000000000000..34756febcc42 --- /dev/null +++ b/vllm/v1/worker/nano_batch_split.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager +from typing import Optional + +import numpy as np +import torch + +from vllm.compilation.nanoflow import manager as nano_manager +from vllm.compilation.nanoflow.split_utils import NanoOpInfo +from vllm.forward_context import get_forward_context +from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices + + +def nano_ubatch_split( + num_scheduled_tokens_per_request: np.ndarray, + num_tokens_unpadded: int, + num_tokens_padded: int, +) -> tuple[Optional[UBatchSlices], Optional[torch.Tensor]]: + """ + Prepare two UBatch-compatible nano-batch slices. + + - Uses nano_manager.prepare_nano_split to decide if splitting is beneficial + (i.e., num_nano_batches > 1). + - Computes a single token split point using custom logic to remain + compatible with UBatch execution. 
+ """ + assert num_tokens_unpadded == num_tokens_padded + batch_size = int(len(num_scheduled_tokens_per_request)) + if batch_size == 0: + return (None, None) + + total_tokens = int(np.sum(num_scheduled_tokens_per_request)) + if total_tokens <= 1: + return (None, None) + + tokens_list = num_scheduled_tokens_per_request.tolist() + split_config = nano_manager.prepare_nano_split(batch_size, tokens_list) + if getattr(split_config, "num_nano_batches", 1) <= 1: + return (None, None) + assert split_config.num_nano_batches == 2 + + first_slice = UBatchSlice(slice(0, split_config.batch_indices[1]), + slice(0, split_config.split_indices[1])) + second_slice = UBatchSlice( + slice(split_config.batch_indices[1], batch_size), + slice(split_config.split_indices[1], split_config.split_indices[2])) + + @contextmanager + def op_hook(op_info: NanoOpInfo): + ctx = get_forward_context() + attn_metadata_list = ctx.attn_metadata + assert isinstance(attn_metadata_list, list) + ctx.attn_metadata = attn_metadata_list[op_info.idx] + try: + yield + finally: + ctx.attn_metadata = attn_metadata_list + pass + + nano_manager.set_op_hook(op_hook) + + return ([first_slice, second_slice], + torch.tensor(split_config.num_tokens, + device="cpu", + dtype=torch.int32)) From f087b99739d1c4c62a6565f06ddf06b831420cae Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Thu, 2 Oct 2025 17:15:07 -0700 Subject: [PATCH 19/22] fix Signed-off-by: Yi Pan --- vllm/compilation/nanoflow/__init__.py | 0 vllm/config/vllm.py | 16 +++++++++++----- vllm/v1/worker/nano_batch_split.py | 5 +---- 3 files changed, 12 insertions(+), 9 deletions(-) create mode 100644 vllm/compilation/nanoflow/__init__.py diff --git a/vllm/compilation/nanoflow/__init__.py b/vllm/compilation/nanoflow/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 2fa5433fb6c1..f07e0b7e4bbe 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -307,14 +307,20 @@ def __post_init__(self): logger.info("full_cuda_graph is not supported with " "nano batch split. Disabling nano batch split.") self.compilation_config.enable_nano_batch_split = False - elif self.compilation_config.splitting_ops: - logger.info("splitting_ops is not supported with " - "nano batch split. Disabling nano batch split.") - self.compilation_config.enable_nano_batch_split = False else: - self.compilation_config.splitting_ops = [ + nano_batch_splitting_ops = [ "vllm.all_reduce", ] + if self.compilation_config.splitting_ops and \ + set(self.compilation_config.splitting_ops) \ + != set(nano_batch_splitting_ops): + logger.info( + "splitting_ops is not supported with " + "nano batch split. Disabling nano batch split.") + self.compilation_config.enable_nano_batch_split = False + else: + self.compilation_config.splitting_ops = \ + nano_batch_splitting_ops # If the user does not explicitly set a compilation level, then # we use the default level. 
The default level depends on other diff --git a/vllm/v1/worker/nano_batch_split.py b/vllm/v1/worker/nano_batch_split.py index 34756febcc42..b2fa65679177 100644 --- a/vllm/v1/worker/nano_batch_split.py +++ b/vllm/v1/worker/nano_batch_split.py @@ -28,11 +28,8 @@ def nano_ubatch_split( """ assert num_tokens_unpadded == num_tokens_padded batch_size = int(len(num_scheduled_tokens_per_request)) - if batch_size == 0: - return (None, None) - total_tokens = int(np.sum(num_scheduled_tokens_per_request)) - if total_tokens <= 1: + if batch_size <= 1 or total_tokens <= 1: return (None, None) tokens_list = num_scheduled_tokens_per_request.tolist() From eed29c1a5ba109e841b725dfd9439b291fac0013 Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Thu, 9 Oct 2025 19:47:22 -0700 Subject: [PATCH 20/22] fix num tokens across dp Signed-off-by: Yi Pan --- vllm/compilation/nanoflow/manager.py | 100 +++++++++++++++------------ vllm/v1/worker/nano_batch_split.py | 16 +++-- 2 files changed, 65 insertions(+), 51 deletions(-) diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py index 7d750210b8c9..016d76b300d9 100644 --- a/vllm/compilation/nanoflow/manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -9,16 +9,19 @@ import torch import torch.fx.graph_module -from vllm.compilation.nanoflow.split_utils import (FakeModule, NanoOpInfo, - NanoSplitConfig, - analyze_graph, - get_split_config, - split_graph, tag_graph) +from vllm.compilation.nanoflow.split_utils import ( + FakeModule, + NanoOpInfo, + NanoSplitConfig, + analyze_graph, + get_split_config, + split_graph, + tag_graph, +) from vllm.config import CompilationConfig class NanoSplitManager: - def __init__( self, graph_module: torch.fx.GraphModule, @@ -46,8 +49,9 @@ def __init__( self.cached_config: Optional[NanoSplitConfig] = None self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() self.comp_stream: torch.cuda.Stream = torch.cuda.Stream() - self.hook: Optional[Callable[ - [NanoOpInfo], contextlib.AbstractContextManager[None]]] = None + self.hook: Optional[ + Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] + ] = None self.get_bs_fn = "get_batch_size" self.split_fn = "split_input" self.wrapper_fn = "op_wrapper" @@ -67,21 +71,25 @@ def __init__( split_fn=self.split_fn, wrapper_fn=self.wrapper_fn, ) - new_graph_module = torch.fx.GraphModule(self.original_graph_module, - new_graph) + new_graph_module = torch.fx.GraphModule( + self.original_graph_module, new_graph + ) for name, _ in self.original_graph_module.named_modules(): if "." 
in name or name == "": continue - torch.fx.graph_module._copy_attr(self.original_graph_module, - new_graph_module, name) + torch.fx.graph_module._copy_attr( + self.original_graph_module, new_graph_module, name + ) self.graph_modules[num_splits] = new_graph_module if local_cache_dir is not None: - graph_path = os.path.join(local_cache_dir, - f"nano_split_{num_splits}.py") + graph_path = os.path.join( + local_cache_dir, f"nano_split_{num_splits}.py" + ) if not os.path.exists(graph_path): src = ( - "from __future__ import annotations\nimport torch\n" + - new_graph_module.print_readable(print_output=False)) + "from __future__ import annotations\nimport torch\n" + + new_graph_module.print_readable(print_output=False) + ) src = src.replace("", "GraphModule") with open(graph_path, "w") as f: f.write(src) @@ -92,8 +100,9 @@ def get_batch_size(idx: int, cached_config: NanoSplitConfig): @staticmethod def split_input(x: torch.Tensor, idx: int, cached_config: NanoSplitConfig): - return x[cached_config.split_indices[idx]:cached_config. - split_indices[idx + 1]] + return x[ + cached_config.split_indices[idx] : cached_config.split_indices[idx + 1] + ] @staticmethod def op_wrapper( @@ -102,23 +111,21 @@ def op_wrapper( args: tuple, kwargs: dict, gm: torch.fx.GraphModule, - hooks: list[Callable[[NanoOpInfo], - contextlib.AbstractContextManager[None]]], + hooks: list[Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]]], ): module = getattr(gm, submod_name) tag = getattr(module, "tag", "") with contextlib.ExitStack() as stack: for hook in hooks: stack.enter_context( - hook(NanoOpInfo(submod_name, tag, idx, args, kwargs))) + hook(NanoOpInfo(submod_name, tag, idx, args, kwargs)) + ) output = module(*args, **kwargs) return output def get_callable(self) -> Callable: - def _forward(*args, **kwargs): - if (self.cached_config is None - or self.cached_config.num_nano_batches == 1): + if self.cached_config is None or self.cached_config.num_nano_batches == 1: return self.original_graph_module(*args, **kwargs) num_nano_batches = self.cached_config.num_nano_batches @@ -150,6 +157,8 @@ def set_stream(op_info: NanoOpInfo): comm_finished[op_info.idx] = None try: yield + except: + raise finally: if op_info.tag == "all_reduce": comm_finished_event = comm_finished[op_info.idx] @@ -164,12 +173,11 @@ def set_stream(op_info: NanoOpInfo): def nvtx_mark(op_info: NanoOpInfo): try: with torch.cuda.nvtx.range( - f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}" + f"op_{op_info.submod_name}_{op_info.tag}_{op_info.idx}" ): yield - except Exception as e: - print(f"Error in nvtx_mark: {e}") - raise e + except: + raise # Register fake modules assert self.hook is not None @@ -190,12 +198,11 @@ def nvtx_mark(op_info: NanoOpInfo): NanoSplitManager.split_input, cached_config=self.cached_config, ) - setattr(self.graph_modules[num_nano_batches], self.wrapper_fn, - op_wrapper) - setattr(self.graph_modules[num_nano_batches], self.get_bs_fn, - get_batch_size) - setattr(self.graph_modules[num_nano_batches], self.split_fn, - split_input) + setattr(self.graph_modules[num_nano_batches], self.wrapper_fn, op_wrapper) + setattr( + self.graph_modules[num_nano_batches], self.get_bs_fn, get_batch_size + ) + setattr(self.graph_modules[num_nano_batches], self.split_fn, split_input) output = self.graph_modules[num_nano_batches](*args, **kwargs) return output @@ -206,14 +213,17 @@ def prepare( batch_size: int, num_tokens: list[int], ) -> NanoSplitConfig: - self.cached_config = get_split_config(batch_size, num_tokens, - 
self.max_num_nano_batches, - self.min_nano_split_tokens) + self.cached_config = get_split_config( + batch_size, + num_tokens, + self.max_num_nano_batches, + self.min_nano_split_tokens, + ) return self.cached_config - def set_hooks(self, - op_hook: Callable[[NanoOpInfo], - contextlib.AbstractContextManager[None]]): + def set_hooks( + self, op_hook: Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] + ): self.hook = op_hook @@ -227,8 +237,9 @@ def get_callable( ) -> Callable: global _split_manager if _split_manager is None: - _split_manager = NanoSplitManager(graph_module, compilation_config, - local_cache_dir) + _split_manager = NanoSplitManager( + graph_module, compilation_config, local_cache_dir + ) return _split_manager.get_callable() @@ -242,8 +253,9 @@ def prepare_nano_split( return _split_manager.prepare(batch_size, num_tokens) -def set_op_hook(op_hook: Callable[[NanoOpInfo], - contextlib.AbstractContextManager[None]]): +def set_op_hook( + op_hook: Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]], +): global _split_manager if _split_manager is None: raise ValueError("Split manager not initialized") diff --git a/vllm/v1/worker/nano_batch_split.py b/vllm/v1/worker/nano_batch_split.py index b2fa65679177..7c042ddb0918 100644 --- a/vllm/v1/worker/nano_batch_split.py +++ b/vllm/v1/worker/nano_batch_split.py @@ -38,11 +38,13 @@ def nano_ubatch_split( return (None, None) assert split_config.num_nano_batches == 2 - first_slice = UBatchSlice(slice(0, split_config.batch_indices[1]), - slice(0, split_config.split_indices[1])) + first_slice = UBatchSlice( + slice(0, split_config.batch_indices[1]), slice(0, split_config.split_indices[1]) + ) second_slice = UBatchSlice( slice(split_config.batch_indices[1], batch_size), - slice(split_config.split_indices[1], split_config.split_indices[2])) + slice(split_config.split_indices[1], split_config.split_indices[2]), + ) @contextmanager def op_hook(op_info: NanoOpInfo): @@ -58,7 +60,7 @@ def op_hook(op_info: NanoOpInfo): nano_manager.set_op_hook(op_hook) - return ([first_slice, second_slice], - torch.tensor(split_config.num_tokens, - device="cpu", - dtype=torch.int32)) + return ( + [first_slice, second_slice], + torch.tensor([num_tokens_padded], device="cpu", dtype=torch.int32), + ) From d8b3573511239a6a09b4c80948842a7ca71e86dd Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Thu, 9 Oct 2025 21:39:36 -0700 Subject: [PATCH 21/22] format fix Signed-off-by: Yi Pan --- vllm/compilation/nanoflow/split_utils.py | 73 +++++++++++++----------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/vllm/compilation/nanoflow/split_utils.py b/vllm/compilation/nanoflow/split_utils.py index e2f84e1ca537..39c1da0b1b71 100644 --- a/vllm/compilation/nanoflow/split_utils.py +++ b/vllm/compilation/nanoflow/split_utils.py @@ -30,7 +30,6 @@ class NanoSplitConfig: class FakeModule(torch.nn.Module): - def __init__(self, fn: Callable, **kwargs): super().__init__() self.fn = fn @@ -53,11 +52,14 @@ def get_split_config( nano_batch_token_sizes = [] prefix_sum = [0] + list(itertools.accumulate(num_tokens)) # Find the mid point of the tokens - mid = min(range(len(prefix_sum)), - key=lambda i: abs(prefix_sum[i] - - (prefix_sum[-1] - prefix_sum[i]))) - if prefix_sum[mid] < min_nano_split_tokens or ( - prefix_sum[-1] - prefix_sum[mid]) < min_nano_split_tokens: + mid = min( + range(len(prefix_sum)), + key=lambda i: abs(prefix_sum[i] - (prefix_sum[-1] - prefix_sum[i])), + ) + if ( + prefix_sum[mid] < min_nano_split_tokens + or (prefix_sum[-1] - prefix_sum[mid]) 
< min_nano_split_tokens + ): num_nano_batches = 1 nano_batch_req_indices.append(batch_size) nano_batch_token_indices.append(prefix_sum[-1]) @@ -69,7 +71,8 @@ def get_split_config( nano_batch_token_indices.extend([prefix_sum[mid], prefix_sum[-1]]) nano_batch_req_sizes.extend([mid, batch_size - mid]) nano_batch_token_sizes.extend( - [prefix_sum[mid], prefix_sum[-1] - prefix_sum[mid]]) + [prefix_sum[mid], prefix_sum[-1] - prefix_sum[mid]] + ) return NanoSplitConfig( num_nano_batches=num_nano_batches, @@ -87,20 +90,25 @@ def display_graph(graph_module: torch.fx.GraphModule, name: str): def tag_graph(gm: torch.fx.GraphModule, op_tags: dict[str, str]): - submodules = [(name, module) for (name, module) in gm.named_modules() - if hasattr(module, "graph")] + submodules = [ + (name, module) + for (name, module) in gm.named_modules() + if hasattr(module, "graph") + ] for _, module in submodules: for node in module.graph.nodes: - if (node.op == "call_function" - and (tag := op_tags.get(str(node.target))) is not None): - assert (getattr(module, "tag", None) is None or module.tag - == tag), f"tag mismatch: {module.tag} != {tag}" + if ( + node.op == "call_function" + and (tag := op_tags.get(str(node.target))) is not None + ): + assert getattr(module, "tag", None) is None or module.tag == tag, ( + f"tag mismatch: {module.tag} != {tag}" + ) module.tag = tag def analyze_graph( - graph: torch.fx.Graph, - batch_size: Union[int, torch.SymInt, None] = None + graph: torch.fx.Graph, batch_size: Union[int, torch.SymInt, None] = None ) -> tuple[list[torch.fx.Node], torch.fx.Graph]: weight_nodes = set() splittable_inputs = [] @@ -116,8 +124,7 @@ def analyze_graph( if not isinstance(arg, torch.SymInt): raise ValueError("Batch size is not set") batch_size = arg - elif isinstance(input_tensor := node.meta["example_value"], - torch.Tensor): + elif isinstance(input_tensor := node.meta["example_value"], torch.Tensor): shape = input_tensor.shape if shape[0] == batch_size: splittable_inputs.append(node) @@ -143,10 +150,12 @@ def split_graph( # Step 1: Get nano batch sizes and split inputs for i in range(num_splits): - nano_batch_sizes.append(out.call_module( - get_bs_fn, - args=(i, ), - )) + nano_batch_sizes.append( + out.call_module( + get_bs_fn, + args=(i,), + ) + ) for node in splittable_inputs: mapping[node] = [] for i in range(num_splits): @@ -160,9 +169,7 @@ def split_graph( def _transform(idx: int, n: NodeArgument) -> NodeArgument: if n in mapping: return mapping[n][idx] - if isinstance( - getattr(n, "meta", {}).get("example_value", None), - torch.SymInt): + if isinstance(getattr(n, "meta", {}).get("example_value", None), torch.SymInt): return nano_batch_sizes[idx] return n @@ -174,8 +181,7 @@ def _transform(idx: int, n: NodeArgument) -> NodeArgument: if node.op == "call_module": new_args = [_transform(split_idx, arg) for arg in node.args] new_kwargs = { - k: _transform(split_idx, v) - for k, v in node.kwargs.items() + k: _transform(split_idx, v) for k, v in node.kwargs.items() } new_node = out.call_module( wrapper_fn, @@ -183,22 +189,21 @@ def _transform(idx: int, n: NodeArgument) -> NodeArgument: ) else: new_node = out.node_copy( - node, - arg_transform=lambda n, idx=split_idx: _transform(idx, n)) + node, arg_transform=lambda n, idx=split_idx: _transform(idx, n) + ) splits.append(new_node) mapping[node] = splits # Step 3: Concatenate outputs output_nodes = [node for node in graph.nodes if node.op == "output"] - assert len(output_nodes - ) == 1, f"Expected 1 output node, found {len(output_nodes)}" + assert 
len(output_nodes) == 1, f"Expected 1 output node, found {len(output_nodes)}" output_node = output_nodes[0] if not output_node.args: raise ValueError("Output node has no arguments") original_outputs = output_node.args[0] is_tuple = isinstance(original_outputs, tuple) if not isinstance(original_outputs, tuple): - original_outputs = (original_outputs, ) + original_outputs = (original_outputs,) new_outputs = [] for original_output in original_outputs: @@ -214,14 +219,14 @@ def _transform(idx: int, n: NodeArgument) -> NodeArgument: # Create concatenation node concat_node = out.call_function( torch.cat, - args=(split_outputs, - 0), # Concatenate along first dimension + args=(split_outputs, 0), # Concatenate along first dimension ) new_outputs.append(concat_node) else: raise ValueError( - f"Original output {original_output} not found in node_splits") + f"Original output {original_output} not found in node_splits" + ) out.output(tuple(new_outputs) if is_tuple else new_outputs[0]) return out From fc3d688b582b06edc29291ddb2a0881fd01ec7ad Mon Sep 17 00:00:00 2001 From: Yi Pan Date: Tue, 11 Nov 2025 00:28:25 -0800 Subject: [PATCH 22/22] fix splitting ops Signed-off-by: Yi Pan --- vllm/compilation/backends.py | 2 +- vllm/compilation/nanoflow/manager.py | 20 +++++++++----------- vllm/config/vllm.py | 22 ++++++++-------------- vllm/v1/worker/gpu_model_runner.py | 5 ++++- 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 41cee0ec1b0a..4f29cbc01a2f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -17,11 +17,11 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import pass_context +from vllm.compilation.nanoflow import manager as nano_manager from vllm.compilation.partition_rules import ( inductor_partition_rule_context, should_split, ) -from vllm.compilation.nanoflow import manager as nano_manager from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform diff --git a/vllm/compilation/nanoflow/manager.py b/vllm/compilation/nanoflow/manager.py index 016d76b300d9..a7b80ac20476 100644 --- a/vllm/compilation/nanoflow/manager.py +++ b/vllm/compilation/nanoflow/manager.py @@ -4,7 +4,7 @@ import contextlib import copy import os -from typing import Callable, Optional +from collections.abc import Callable import torch import torch.fx.graph_module @@ -26,7 +26,7 @@ def __init__( self, graph_module: torch.fx.GraphModule, compilation_config: CompilationConfig, - local_cache_dir: Optional[str], + local_cache_dir: str | None, ) -> None: self.original_graph_module = graph_module self.original_graph = graph_module.graph @@ -38,20 +38,18 @@ def __init__( tag_graph( self.original_graph_module, { - "vllm.unified_attention": "attention", - "vllm.unified_attention_with_output": "attention", "vllm.all_reduce": "all_reduce", }, ) self.graph_modules = {1: self.original_graph_module} # Runtime preparation - self.cached_config: Optional[NanoSplitConfig] = None + self.cached_config: NanoSplitConfig | None = None self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() self.comp_stream: torch.cuda.Stream = torch.cuda.Stream() - self.hook: Optional[ - Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] - ] = None + self.hook: ( + Callable[[NanoOpInfo], contextlib.AbstractContextManager[None]] | None + ) = None self.get_bs_fn = "get_batch_size" self.split_fn = "split_input" self.wrapper_fn = "op_wrapper" @@ 
-129,10 +127,10 @@ def _forward(*args, **kwargs): return self.original_graph_module(*args, **kwargs) num_nano_batches = self.cached_config.num_nano_batches - comm_finished: list[Optional[torch.cuda.Event]] = [ + comm_finished: list[torch.cuda.Event | None] = [ None for _ in range(num_nano_batches) ] - comp_finished: list[Optional[torch.cuda.Event]] = [ + comp_finished: list[torch.cuda.Event | None] = [ None for _ in range(num_nano_batches) ] @@ -233,7 +231,7 @@ def set_hooks( def get_callable( graph_module: torch.fx.GraphModule, compilation_config: CompilationConfig, - local_cache_dir: Optional[str] = None, + local_cache_dir: str | None = None, ) -> Callable: global _split_manager if _split_manager is None: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b09cde3447d9..5a5c1db51dd1 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -433,20 +433,14 @@ def __post_init__(self): "nano batch split. Disabling nano batch split." ) self.compilation_config.enable_nano_batch_split = False - else: - nano_batch_splitting_ops = [ - "vllm.all_reduce", - ] - if self.compilation_config.splitting_ops and set( - self.compilation_config.splitting_ops - ) != set(nano_batch_splitting_ops): - logger.info( - "splitting_ops is not supported with " - "nano batch split. Disabling nano batch split." - ) - self.compilation_config.enable_nano_batch_split = False - else: - self.compilation_config.splitting_ops = nano_batch_splitting_ops + elif ( + self.compilation_config.splitting_ops + and "vllm.all_reduce" not in self.compilation_config.splitting_ops + ): + logger.info( + "adding vllm.all_reduce to splitting_ops for nano batch split." + ) + self.compilation_config.splitting_ops.append("vllm.all_reduce") # If the user does not explicitly set a compilation mode, then # we use the default mode. The default mode depends on other diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5d95ff7efed4..4321cffc370c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1212,7 +1212,9 @@ def _prepare_inputs( # running prefills. This lets us set enforce_eager on the prefiller in # a P/D setup and still use CUDA graphs (enabled by this padding) on the # decoder. - allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + allow_dp_padding = ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ) ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp( num_tokens_unpadded=num_tokens_unpadded, @@ -4195,6 +4197,7 @@ def initialize_metadata_builders( else None, num_metadata_builders=1 if not self.parallel_config.enable_dbo + and not self.compilation_config.enable_nano_batch_split else 2, ) # Calculate reorder batch threshold (if needed)
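
The graph transformation in this series rests on one invariant: for operators that act row-wise on the batch dimension, slicing the inputs, running each slice, and concatenating the results along dim 0 reproduces the unsplit output (attention and all_reduce need the extra metadata and stream handling added in the patches). A minimal numerical check of that invariant, using a plain linear layer and made-up sizes rather than anything from vLLM:

import torch

torch.manual_seed(0)
linear = torch.nn.Linear(8, 4)
x = torch.randn(10, 8)

full = linear(x)
split_indices = [0, 6, 10]  # two nano batches: rows 0-5 and rows 6-9
parts = [linear(x[s:e]) for s, e in zip(split_indices, split_indices[1:])]
stitched = torch.cat(parts, dim=0)

# Row-wise ops commute with batch splitting, so the stitched result matches.
assert torch.allclose(full, stitched)
print("split/concat matches the unsplit forward")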
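
get_split_config picks the request boundary that balances tokens between the two nano batches by scanning prefix sums over the per-request token counts. A standalone sketch of just that selection step; the name pick_midpoint and the default threshold of 4 tokens are illustrative, not the values wired into the config above:

import itertools

def pick_midpoint(num_tokens: list[int], min_split_tokens: int = 4) -> int | None:
    # Prefix sums over per-request token counts, e.g. [3, 5, 2, 6] -> [0, 3, 8, 10, 16].
    prefix = [0] + list(itertools.accumulate(num_tokens))
    total = prefix[-1]
    # Request boundary whose left/right token counts are closest to equal.
    mid = min(range(len(prefix)), key=lambda i: abs(prefix[i] - (total - prefix[i])))
    # Refuse to split if either side would be too small to be worth the overhead.
    if prefix[mid] < min_split_tokens or total - prefix[mid] < min_split_tokens:
        return None
    return mid

print(pick_midpoint([3, 5, 2, 6]))  # 2 -> 8 tokens on each side
print(pick_midpoint([1, 1, 1]))     # None -> too few tokens on either side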
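
The removed prepare_nano_split_and_set_hooks helper leaned on a cumsum/repeat/arange trick to turn per-request token counts into per-token position offsets. A small numpy illustration of that bookkeeping with made-up token counts:

import numpy as np

# Per-request scheduled token counts (illustrative values).
num_tokens = np.array([3, 1, 4], dtype=np.int32)

# Running totals per request: [3, 4, 8].
cu_num_tokens = np.cumsum(num_tokens)

# Start offset of each request, repeated once per token it owns:
# [0, 0, 0, 3, 4, 4, 4, 4].
offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)

# Position of every flat token slot within its own request:
# [0, 1, 2, 0, 0, 1, 2, 3].
arange = np.arange(cu_num_tokens[-1]) - offsets

print(cu_num_tokens.tolist(), arange.tolist())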
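
Both worker-side op_hook variants follow the same pattern: a contextmanager swaps per-nano-batch state in before the wrapped submodule runs and restores it afterwards, even if the body raises. A stripped-down sketch with a stand-in context object; FakeForwardContext here is hypothetical and not vLLM's ForwardContext:

from contextlib import contextmanager
from dataclasses import dataclass

@dataclass
class FakeForwardContext:      # stand-in for the real forward context
    attn_metadata: object      # a list while unsplit, one entry inside the hook

ctx = FakeForwardContext(attn_metadata=["meta_nano_0", "meta_nano_1"])

@contextmanager
def op_hook(split_idx: int):
    saved = ctx.attn_metadata
    ctx.attn_metadata = saved[split_idx]   # expose only this nano batch's metadata
    try:
        yield
    finally:
        ctx.attn_metadata = saved          # always restore the full list

with op_hook(1):
    print(ctx.attn_metadata)   # meta_nano_1
print(ctx.attn_metadata)       # ['meta_nano_0', 'meta_nano_1']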
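
NanoSplitManager.op_wrapper brackets every split submodule with a stack of context-manager hooks (stream selection, NVTX ranges, the worker's metadata hook). The core mechanism is contextlib.ExitStack; a self-contained sketch under assumed names (run_with_hooks and tracing_hook are illustrative, and the hooks here take no NanoOpInfo argument):

import contextlib
from typing import Callable

def run_with_hooks(fn: Callable, hooks: list, *args, **kwargs):
    # Enter every hook's context manager, call fn, then unwind in reverse order.
    with contextlib.ExitStack() as stack:
        for hook in hooks:
            stack.enter_context(hook())
        return fn(*args, **kwargs)

@contextlib.contextmanager
def tracing_hook():
    print("enter")
    try:
        yield
    finally:
        print("exit")

print(run_with_hooks(lambda x: x * 2, [tracing_hook], 21))  # enter / exit / 42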
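
The interleaving in NanoSplitManager.get_callable relies on the usual two-stream pattern: communication runs on its own stream and CUDA events order it against compute. A generic sketch of that pattern, not the manager's exact event bookkeeping, guarded so it only executes where CUDA is available; the matmul and sum are stand-ins for the real per-nano-batch compute and all_reduce:

import torch

if torch.cuda.is_available():
    comp_stream = torch.cuda.Stream()
    comm_stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")

    with torch.cuda.stream(comp_stream):
        y = x @ x                          # compute work for one nano batch
        comp_done = torch.cuda.Event()
        comp_done.record(comp_stream)

    comm_stream.wait_event(comp_done)      # communication waits for the compute
    with torch.cuda.stream(comm_stream):
        z = y.sum()                        # stand-in for the real all_reduce
        comm_done = torch.cuda.Event()
        comm_done.record(comm_stream)

    torch.cuda.current_stream().wait_event(comm_done)
    print(z.item())                        # the copy to host runs after comm_done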