From 567ef62bf79bbdd65696e658cb3b152fcb6e76d3 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 28 Oct 2025 04:37:45 +0000 Subject: [PATCH 01/36] Adding config loading and benchmarking for fused_moe_lora shrink and expand Signed-off-by: Yu Gong --- .pre-commit-config.yaml | 10 +- benchmarks/kernels/benchmark_lora.py | 430 +++++++++++++++++- vllm/lora/layers/fused_moe.py | 100 +++- vllm/lora/ops/triton_ops/README_TUNING.md | 10 +- vllm/lora/ops/triton_ops/__init__.py | 9 +- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 373 ++++++++++++--- vllm/lora/ops/triton_ops/utils.py | 42 +- vllm/lora/punica_wrapper/punica_base.py | 3 +- vllm/lora/punica_wrapper/punica_gpu.py | 22 +- 9 files changed, 859 insertions(+), 140 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bcd40e7f8ab3..1ac1fdfd6bce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,12 +2,12 @@ default_install_hook_types: - pre-commit - commit-msg default_stages: - - pre-commit # Run locally + - commit # Run locally - manual # Run in CI exclude: 'vllm/third_party/.*' repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.0 + rev: v0.14.1 hooks: - id: ruff-check args: [--output-format, github, --fix] @@ -48,9 +48,9 @@ repos: entry: python tools/pre_commit/generate_nightly_torch_test.py files: ^requirements/test\.(in|txt)$ - id: mypy-local - name: Run mypy locally for lowest supported Python version - entry: python tools/pre_commit/mypy.py 0 "3.10" - stages: [pre-commit] # Don't run in CI + name: Run mypy for local Python installation + entry: python tools/pre_commit/mypy.py 0 "local" + stages: [commit] # Don't run in CI <<: &mypy_common language: python types_or: [python, pyi] diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index bf1512268fe0..96da0a5ea79e 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -19,13 +19,24 @@ from utils import ArgPool, Bench, CudaGraphBenchParams from weight_shapes import WEIGHT_SHAPES -from vllm.triton_utils import HAS_TRITON +from vllm.lora.ops.triton_ops.utils import get_lora_op_configs +from vllm.triton_utils import HAS_TRITON, triton if HAS_TRITON: - from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora + LoRAKernelMeta, + fused_moe_lora_expand, + fused_moe_lora_shrink, + lora_expand, + lora_shrink, + ) + from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( + _LORA_PTR_DICT, ## added _LORA_PTR_DICT for fused_moe_lora + ) from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT - +from vllm import _custom_ops as ops from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.math_utils import round_up DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_TP_SIZES = [1] @@ -191,6 +202,11 @@ class OpType(Enum): LORA_SHRINK = auto() LORA_EXPAND = auto() + ## Adding support for fused moe lora + FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink + FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand + FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink + FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand @staticmethod def from_str(s: str) -> "OpType": @@ -198,6 +214,15 @@ def from_str(s: str) -> "OpType": return OpType.LORA_SHRINK if s.lower() == "lora_expand": return OpType.LORA_EXPAND + # Adding 
support for fused moe lora, both in gate_up and down + if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink + return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK + if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand + return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND + if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink + return OpType.FUSED_MOE_LORA_DOWN_SHRINK + if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand + return OpType.FUSED_MOE_LORA_DOWN_EXPAND raise ValueError(f"Unrecognized str {s} to convert to OpType") def is_shrink_fn(self) -> bool: @@ -206,7 +231,45 @@ def is_shrink_fn(self) -> bool: def is_expand_fn(self) -> bool: return self in [OpType.LORA_EXPAND] + def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_gate_up_fn( + self, + ) -> bool: ## adding for fused MoE LoRA Gate/Up + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + ] + + def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down + return self in [ + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_shrink_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ] + + def is_fused_moe_lora_expand_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + def num_slices(self) -> list[int]: + if self.is_fused_moe_lora_gate_up_fn(): + return [2] + elif self.is_fused_moe_lora_down_fn(): + return [1] return [1, 2, 3] def mkn( @@ -217,11 +280,15 @@ def mkn( m = num_tokens k = hidden_size n = lora_rank - else: - assert self.is_expand_fn() + elif self.is_expand_fn(): m = num_tokens k = lora_rank n = hidden_size + else: + assert self.is_fused_moe_lora_fn() + m = num_tokens + n = hidden_size + k = lora_rank return m, k, n def matmul_dtypes( @@ -232,9 +299,37 @@ def matmul_dtypes( """ if self.is_shrink_fn(): return op_dtype, op_dtype, torch.float32 - else: - assert self.is_expand_fn() + elif self.is_expand_fn(): return torch.float32, op_dtype, op_dtype + else: + assert self.is_fused_moe_lora_fn() + return op_dtype, op_dtype, op_dtype + + def matmul_shapes_fused_moe_lora( + self, + m: int, + n: int, + k: int, + num_loras: int, + num_slices: int, + top_k_num: int, + num_experts: int, + ) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]: + if self.is_fused_moe_lora_gate_up_fn(): + if self.is_fused_moe_lora_shrink_fn(): + input_shape = ( + (m * top_k_num, n) + if self in [OpType.FUSED_MOE_LORA_GATE_UP_SHRINK] + else (m, n) + ) + output_shape = (num_slices, m, top_k_num, k) + weight_shape = (num_loras, num_experts, k, n) + else: + assert self.is_fused_moe_lora_expand_fn() + input_shape = (num_slices, m, top_k_num, k) + output_shape = (m, top_k_num, n * num_slices) + weight_shape = (num_loras, num_experts, n, k) + return (input_shape, weight_shape, output_shape) def matmul_shapes( self, @@ -244,6 +339,8 @@ def matmul_shapes( lora_rank: int, num_loras: int, num_slices: int, + top_k_num: int | None = None, + num_experts: int | None = None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Given num_slices, return the shapes of the A, B, and C matrices @@ -258,6 +355,17 @@ 
def matmul_shapes( if self in [OpType.LORA_EXPAND]: # LoRA expand kernels support num_slices inherently in the kernel return ((num_slices, m, k), b_shape, (m, n * num_slices)) + if self.is_fused_moe_lora_fn(): + return self.matmul_shapes_fused_moe_lora( + self, + m, + k, + n, + num_loras, + num_slices, + top_k_num, + num_experts, + ) raise ValueError(f"Unrecognized op_type {self}") def bench_fn(self) -> Callable: @@ -265,6 +373,16 @@ def bench_fn(self) -> Callable: return lora_shrink if self == OpType.LORA_EXPAND: return lora_expand + if ( + self == OpType.FUSED_MOE_LORA_GATE_UP_SHRINK + or self == OpType.FUSED_MOE_LORA_DOWN_SHRINK + ): + return fused_moe_lora_shrink + if ( + self == OpType.FUSED_MOE_LORA_GATE_UP_EXPAND + or self == OpType.FUSED_MOE_LORA_DOWN_EXPAND + ): + return fused_moe_lora_expand raise ValueError(f"Unrecognized optype {self}") @@ -318,6 +436,8 @@ class BenchmarkContext: sort_by_lora_id: bool dtype: torch.dtype seq_length: int | None = None + num_experts: int | None = None # num_experts for MoE based ops + top_k_num: int | None = None # top_k for MoE based ops num_slices: int | None = None # num_slices for slice based ops def with_seq_length(self, seq_length: int) -> "BenchmarkContext": @@ -385,6 +505,8 @@ def make( ctx.lora_rank, ctx.num_loras, ctx.num_slices, + ctx.top_k_num, + ctx.num_experts, ) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) input_tensor, lora_weights, output_tensor = make_rand_tensors( @@ -432,17 +554,33 @@ def make( prompt_lora_indices_tensor, ) - def sanity_check(self) -> None: + def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None: """ Fails asserts when non-conformality is detected. """ - num_tokens = self.input.shape[-2] + ##TODO test if this works + num_tokens = ( + self.input.shape[1] + if op_type.is_fused_moe_lora_expand_fn() + else self.input.shape[-2] + ) # check metadata tensors - assert torch.sum(self.seq_lens) == num_tokens + ## In down shrink case, each token is repeated top_k_num times + assert ( + torch.sum(self.seq_lens) * ctx.top_k_num == num_tokens + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else torch.sum(self.seq_lens) == num_tokens + ) num_seqs = self.seq_lens.shape[0] # assert self.seq_start_loc.shape[0] == num_seqs + ## In down shrink case, each prompt corresponds to top_k_num sequences assert self.prompt_lora_mapping.shape[0] == num_seqs - assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens + assert ( + self.lora_kernel_meta.token_lora_mapping.shape[0] * ctx.top_k_num + == num_tokens + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens + ) def to_device(self, device: str): """ @@ -471,21 +609,114 @@ def to_device(tensor: torch.Tensor): to_device(field) if field_name != "no_lora_flag_cpu" else field, ) - def metadata(self) -> tuple[int, int, int]: + def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, int]: """ Return num_seqs, num_tokens and max_seq_len """ num_seqs = self.seq_lens.shape[0] - num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0] + ## TODO: test if this works + num_tokens = ( + self.lora_kernel_meta.token_lora_mapping.shape[0] * ctx.top_k_num + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else self.lora_kernel_meta.token_lora_mapping.shape[0] + ) max_seq_len = torch.max(self.seq_lens).item() num_slices = len(self.lora_weights_lst) return num_seqs, num_tokens, max_seq_len, num_slices - def as_lora_shrink_kwargs(self) -> dict[str, Any]: - 
self.sanity_check() + def fused_moe_lora_data_prepare( + self, + block_size: int, + token_lora_mapping: torch.Tensor, + ctx: BenchmarkContext, + ): + def moe_lora_align_block_size( + topk_ids: torch.Tensor, + token_lora_mapping: torch.Tensor, + block_size: int, + num_experts: int, + max_loras: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns tokens and experts into block-sized chunks for LoRA-based + mixture-of-experts (MoE) execution. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) + + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad + + num_tokens = ctx.batch_size + curr_topk_ids = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + topk_weights = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + + (sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = ( + moe_lora_align_block_size( + curr_topk_ids=curr_topk_ids, + token_lora_mapping=token_lora_mapping, + block_size=block_size, + num_experts=ctx.num_experts, + max_loras=ctx.num_loras, + ) + ) + + sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1) + expert_ids = expert_ids_lora.view(ctx.num_loras, -1) + num_tokens_post_padded = num_tokens_post_padded_lora + return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) + + def as_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata() + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. i_shape, lw_shape, o_shape = ( @@ -520,11 +751,13 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]: "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } - def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: - self.sanity_check() + def as_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata() + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. 
i_shape, lw_shape, o_shape = ( @@ -561,8 +794,154 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } + def as_fused_moe_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + # Expected input shape : [num_tokens, hidden_size] for gate_up + # Expected input shape : [top_k_num * num_tokens, hidden_size] for down + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size] + assert len(lw_shape) == 4 + assert lw_shape[-1] == hidden_size + lora_rank = lw_shape[-2] + # Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(o_shape) == 4 + assert o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank) + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "qcurr_hidden_states": self.input, + "lora_a_stacked": self.lora_weights_lst, + "a_intermediate_cache1": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "shrink_block_size_m": kernel_config["BLOCK_SIZE_M"], + "shrink_block_size_n": kernel_config["BLOCK_SIZE_N"], + "shrink_block_size_k": kernel_config["BLOCK_SIZE_K"], + "shrink_group_size_m": kernel_config["GROUP_SIZE_M"], + "shrink_num_warps": kernel_config["num_warps"], + "shrink_num_stages": kernel_config["num_stages"], + "shrink_splitK": kernel_config.get("SPLIT_K", 1), + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + } + + def as_fused_moe_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. 
+ i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + + # Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(i_shape) == 4 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[-1] + # Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank] + assert len(lw_shape) == 4 + assert lw_shape[-1] == lora_rank + hidden_size = lw_shape[-2] + # Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices] + assert len(o_shape) == 3 + assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices) + + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "a_intermediate_cache1": self.input, + "lora_b_stacked": self.lora_weights_lst, + "output": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "max_lora_rank": lora_rank, + "w1_output_dim_size": lw_shape[2], + "expand_block_size_m": kernel_config["BLOCK_SIZE_M"], + "expand_block_size_n": kernel_config["BLOCK_SIZE_N"], + "expand_block_size_k": kernel_config["BLOCK_SIZE_K"], + "expand_group_size_m": kernel_config["GROUP_SIZE_M"], + "expand_num_warps": kernel_config["num_warps"], + "expand_num_stages": kernel_config["num_stages"], + "expand_splitK": kernel_config.get("SPLIT_K", 1), + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + } + def bench_fn_kwargs( - self, op_type: OpType, add_inputs: bool | None = None + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None ) -> dict[str, Any]: if op_type.is_shrink_fn(): assert add_inputs is None @@ -570,9 +949,13 @@ def bench_fn_kwargs( assert add_inputs is not None if op_type == OpType.LORA_SHRINK: - return self.as_lora_shrink_kwargs() + return self.as_lora_shrink_kwargs(ctx, op_type) if op_type == OpType.LORA_EXPAND: - return self.as_lora_expand_kwargs(add_inputs) + return self.as_lora_expand_kwargs(ctx, op_type, add_inputs) + if op_type.is_fused_moe_lora_shrink_fn(): + return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type) + if op_type.is_fused_moe_lora_expand_fn(): + return self.as_fused_moe_lora_expand_kwargs(ctx, op_type) raise ValueError(f"Unrecognized optype {self}") def test_correctness( @@ -627,7 +1010,7 @@ def bench_optype( BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) ] for bt in bench_tensors: - bt.sanity_check() + bt.sanity_check(ctx, op_type) # Test correctness of our implementation. if test_correctness: @@ -644,6 +1027,7 @@ def bench_optype( # Clear LoRA optimization hash-maps. 
_LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() + _LORA_PTR_DICT.clear() # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 7711f5c3208b..1712c17af89b 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -13,6 +13,7 @@ get_tensor_model_parallel_world_size, ) from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.ops.triton_ops.utils import get_lora_op_configs from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import ( _get_config_dtype_str, @@ -90,17 +91,46 @@ def wrapper(*args, **kwargs): num_tokens = hidden_states.size(0) M = min(num_tokens, CHUNK_SIZE) - get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - layer.w2_weight.size(), - top_k, - config_dtype, - block_shape=layer.quant_method.moe_quant_config.block_shape, - ) - + ## if the env var is set, loading the config + if envs.VLLM_TUNED_CONFIG_FOLDER: + # get the gate/up shrink config + shrink_config = get_lora_op_configs( + op_type="fused_moe_lora_gate_up_shrink", + max_loras=self.w1_lora_a_stacked.shape[0], + batch=M, + hidden_size=self.w1_lora_a_stacked.shape[-1], + rank=self.w1_lora_a_stacked.shape[-2], + num_slices=2, + hidden_size_2=self.w1_lora_b_stacked.shape[-2], + ) + # get the gate/up expand config + expand_config = get_lora_op_configs( + op_type="fused_moe_lora_gate_up_expand", + max_loras=self.w1_lora_a_stacked.shape[0], + batch=M, + hidden_size=self.w1_lora_a_stacked.shape[-1], + rank=self.w1_lora_a_stacked.shape[-2], + num_slices=2, + hidden_size_2=self.w1_lora_b_stacked.shape[-2], + ) + else: # fall back to the default config + get_config_func = functools.partial( + try_get_optimal_moe_config, + self.w13_weight.size(), + self.w2_weight.size(), + top_k, + config_dtype, + block_shape=self.quant_method.moe_quant_config.block_shape, + ) + + shrink_config = get_config_func(M) + expand_config = get_config_func(M) ## same as the shrink config + # get the block size of m from customized config or default config max_loras = self.w1_lora_a_stacked.shape[0] - config = get_config_func(M) + block_size = ( + shrink_config.get("BLOCK_SIZE_M", shrink_config.get("block_m", 64)) + or 64 + ) ( sorted_token_ids_lora, expert_ids_lora, @@ -108,7 +138,7 @@ def wrapper(*args, **kwargs): ) = self.punica_wrapper.moe_lora_align_block_size( curr_topk_ids, num_tokens, - config["BLOCK_SIZE_M"], + block_size, self.base_layer.local_num_experts, max_loras, self.adapter_enabled, @@ -138,7 +168,8 @@ def wrapper(*args, **kwargs): num_tokens_post_padded_lora, max_lora_rank, top_k, - config, + shrink_config, ## pass the shrink config + expand_config, ## pass the expand config self.adapter_enabled, ) @@ -164,16 +195,38 @@ def wrapper(*args, **kwargs): num_tokens = hidden_states.size(0) M = min(num_tokens, CHUNK_SIZE) - get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - layer.w2_weight.size(), - top_k, - config_dtype, - block_shape=layer.quant_method.moe_quant_config.block_shape, - ) - - config = get_config_func(M) + if envs.VLLM_TUNED_CONFIG_FOLDER: + # get the down shrink config + shrink_config = get_lora_op_configs( + op_type="fused_moe_lora_down_shrink", + max_loras=self.w2_lora_a_stacked.shape[0], + batch=M, + hidden_size=self.w2_lora_a_stacked.shape[-1], + rank=self.w2_lora_a_stacked.shape[-2], + 
num_slices=1, + hidden_size_2=self.w2_lora_b_stacked.shape[-2], + ) + # get the down expand config + expand_config = get_lora_op_configs( + op_type="fused_moe_lora_down_expand", + max_loras=self.w1_lora_a_stacked.shape[0], + batch=M, + hidden_size=self.w1_lora_a_stacked.shape[-1], + rank=self.w1_lora_a_stacked.shape[-2], + num_slices=1, + hidden_size_2=self.w1_lora_b_stacked.shape[-2], + ) + else: + get_config_func = functools.partial( + try_get_optimal_moe_config, + self.w13_weight.size(), + self.w2_weight.size(), + top_k, + config_dtype, + block_shape=self.quant_method.moe_quant_config.block_shape, + ) + shrink_config = get_config_func(M) + expand_config = get_config_func(M) sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"] expert_ids_lora = moe_state_dict["expert_ids_lora"] @@ -197,7 +250,8 @@ def wrapper(*args, **kwargs): num_tokens_post_padded_lora, max_lora_rank, top_k, - config, + shrink_config, ## pass the shrink config + expand_config, ## pass the expand config self.adapter_enabled, True, ) diff --git a/vllm/lora/ops/triton_ops/README_TUNING.md b/vllm/lora/ops/triton_ops/README_TUNING.md index fda95ea71891..56c8c27051cc 100644 --- a/vllm/lora/ops/triton_ops/README_TUNING.md +++ b/vllm/lora/ops/triton_ops/README_TUNING.md @@ -44,8 +44,16 @@ For `shrink`, the config file is named as `{gpu_name}_SHRINK.json`, e.g. `NVIDIA For `expand`, the config fileis named as `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`. +For `fused_moe_lora_gate_up_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_GATE_UP_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_GATE_UP_SHRINK.json`. + +For `fused_moe_lora_gate_up_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_GATE_UP_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_GATE_UP_EXPAND.json`. + +For `fused_moe_lora_down_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_DOWN_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_DOWN_SHRINK.json`. + +For `fused_moe_lora_down_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_DOWN_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_DOWN_EXPAND.json`. 
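
As a rough illustration only: a minimal sketch of what one entry in such a tuned file might contain, following the `config_data[max_loras][num_slices][m][k][n][n2]` nesting documented under "Json Structure" below. Every dimension key and tile value in this sketch is a hypothetical placeholder, not a tuned result; the innermost keys mirror the `BLOCK_SIZE_*`, `GROUP_SIZE_M`, `num_warps`, `num_stages`, and `SPLIT_K` fields that this patch reads from the loaded config (the punica wrapper also accepts the lower-case `block_m`-style names as a fallback).

```python
# Hypothetical example of a tuned fused_moe_lora config entry, e.g. in a file
# like NVIDIA_H200_FUSED_MOE_LORA_GATE_UP_SHRINK.json. All values illustrative.
config_data = {
    "1": {                          # max_loras
        "2": {                      # num_slices (2 for gate_up, 1 for down)
            "256": {                # m
                "4096": {           # k
                    "16": {         # n
                        "14336": {  # n2 (the new hidden_size_2 axis)
                            "BLOCK_SIZE_M": 64,
                            "BLOCK_SIZE_N": 64,
                            "BLOCK_SIZE_K": 32,
                            "GROUP_SIZE_M": 8,
                            "num_warps": 4,
                            "num_stages": 3,
                            "SPLIT_K": 1,
                        }
                    }
                }
            }
        }
    }
}
```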
+ The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()` ### Json Structure -Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]` +Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][n2]` diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index 436ea4ed00c8..7e8b9a79add3 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -1,7 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora + +from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( + fused_moe_lora, + fused_moe_lora_expand, + fused_moe_lora_shrink, +) from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink @@ -11,4 +16,6 @@ "lora_shrink", "LoRAKernelMeta", "fused_moe_lora", + "fused_moe_lora_shrink", + "fused_moe_lora_expand", ] diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 539605c7c534..ae59d73487b9 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -176,92 +176,54 @@ def _fused_moe_lora_kernel( @torch.inference_mode() -def _fused_moe_lora( - output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) +def _fused_moe_lora_shrink( + a_intermediate_cache1: + torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) lora_a_stacked: list[ torch.Tensor ], # [(max_loras, num_experts, max_lora_rank, K,),...] - lora_b_stacked: list[ - torch.Tensor - ], # [(max_loras, num_experts, N, max_lora_rank,),...] 
topk_weights: torch.Tensor, # (num_tokens, top_k_num) sorted_token_ids: torch.Tensor, # (max_loras, _) expert_ids: torch.Tensor, # (max_loras, _ ,) num_tokens_post_padded: torch.Tensor, # (max_loras, ) - max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor, - block_size_m: int, - block_size_n: int, - block_size_k: int, - group_size_m: int, - split_k: int, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, mul_routed_weight: bool = False, ) -> None: - assert len(lora_a_stacked) == len(lora_b_stacked) > 0 - assert ( - sorted_token_ids.dim() - == expert_ids.dim() - == topk_weights.dim() - == qcurr_hidden_states.dim() - == 2 - ) - assert ( - sorted_token_ids.shape[0] - == expert_ids.shape[0] - == num_tokens_post_padded.shape[0] - ) - assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] - assert output.shape[0] == topk_weights.shape[0] - assert top_k_num == topk_weights.shape[1] - - for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked): - assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype - assert lora_a.dtype in [torch.float16, torch.bfloat16] - - device = qcurr_hidden_states.device - num_slices = len(lora_a_stacked) - - config = { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "SPLIT_K": split_k, - } - w1_lora_a_stacked = lora_a_stacked[0] - w1_lora_b_stacked = lora_b_stacked[0] - num_experts = lora_a_stacked[0].shape[1] - N = max_lora_rank - M = topk_weights.shape[0] - EM = sorted_token_ids.shape[1] - K = qcurr_hidden_states.shape[1] - num_tokens = M * top_k_num - w1_output_dim_size = w1_lora_b_stacked.shape[2] - - lora_intermediate_cache1 = torch.zeros( - (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)), - dtype=output.dtype, - device=device, - ) - - # slices - a_intermediate_size = num_slices * M * top_k_num * max_lora_rank - a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view( - num_slices, M, top_k_num, max_lora_rank - ) - b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view( - num_slices, M, top_k_num, w1_output_dim_size - ) + shrink_config = { + "BLOCK_SIZE_M": shrink_block_size_m, + "BLOCK_SIZE_N": shrink_block_size_n, + "BLOCK_SIZE_K": shrink_block_size_k, + "GROUP_SIZE_M": shrink_group_size_m, + "num_warps": shrink_num_warps, + "num_stages": shrink_num_stages, + "SPLIT_K": shrink_split_k, + } b_ptr = _get_ptr(lora_a_stacked, device) grid = lambda META: ( - split_k + shrink_split_k * triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_a_stacked), @@ -299,19 +261,68 @@ def _fused_moe_lora( num_slice_c=num_slices, top_k=1 if mul_routed_weight else top_k_num, MUL_ROUTED_WEIGHT=False, - **config, + **shrink_config, ) + +@torch.inference_mode() +def _fused_moe_lora_expand( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank) + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] 
+ topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + top_k_num: int, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, +) -> None: b_ptr = _get_ptr(lora_b_stacked, device) K = max_lora_rank N = w1_output_dim_size + w1_lora_b_stacked = lora_b_stacked[0] + a_intermediate_cache1 = a_intermediate_cache1.view( -1, a_intermediate_cache1.shape[3] ) - # Set split_k = 1 for expand calls - config["SPLIT_K"] = 1 + b_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, w1_output_dim_size), + dtype=torch.bfloat16, + device=device, + ) + + expand_config = { + "BLOCK_SIZE_M": expand_block_size_m, + "BLOCK_SIZE_N": expand_block_size_n, + "BLOCK_SIZE_K": expand_block_size_k, + "GROUP_SIZE_M": expand_group_size_m, + "num_warps": expand_num_warps, + "num_stages": expand_num_stages, + "SPLIT_K": expand_split_k, # Set split_k = 1 for expand calls + } + grid = lambda META: ( triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_b_stacked), @@ -348,12 +359,135 @@ def _fused_moe_lora( num_slice_c=num_slices, top_k=1, MUL_ROUTED_WEIGHT=mul_routed_weight, - **config, + **expand_config, ) for i in range(num_slices): output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i] +@torch.inference_mode() +def _fused_moe_lora( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, N, max_lora_rank,),...] 
+ topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + max_lora_rank: int, + top_k_num: int, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, +) -> None: + assert len(lora_a_stacked) == len(lora_b_stacked) > 0 + assert ( + sorted_token_ids.dim() + == expert_ids.dim() + == topk_weights.dim() + == qcurr_hidden_states.dim() + == 2 + ) + assert ( + sorted_token_ids.shape[0] + == expert_ids.shape[0] + == num_tokens_post_padded.shape[0] + ) + assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] + assert output.shape[0] == topk_weights.shape[0] + assert top_k_num == topk_weights.shape[1] + device = qcurr_hidden_states.device + num_slices = len(lora_a_stacked) + w1_lora_b_stacked = lora_b_stacked[0] + num_experts = lora_a_stacked[0].shape[1] + N = max_lora_rank + M = topk_weights.shape[0] + EM = sorted_token_ids.shape[1] + K = qcurr_hidden_states.shape[1] + num_tokens = M * top_k_num + w1_output_dim_size = w1_lora_b_stacked.shape[2] + + a_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, max_lora_rank), + dtype=torch.bfloat16, + device=device, + ) + + _fused_moe_lora_shrink( + a_intermediate_cache1, + qcurr_hidden_states, + lora_a_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + shrink_block_size_m, + shrink_block_size_n, + shrink_block_size_k, + shrink_group_size_m, + shrink_num_warps, + shrink_num_stages, + shrink_split_k, + mul_routed_weight, + ) + + _fused_moe_lora_expand( + output, + a_intermediate_cache1, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + max_lora_rank, + w1_output_dim_size, + expand_block_size_m, + expand_block_size_n, + expand_block_size_k, + expand_group_size_m, + expand_num_warps, + expand_num_stages, + expand_split_k, + mul_routed_weight, + ) + + def _fused_moe_lora_fake( output: torch.Tensor, qcurr_hidden_states: torch.Tensor, @@ -367,10 +501,80 @@ def _fused_moe_lora_fake( top_k_num: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor, - block_size_m: int, - block_size_n: int, - block_size_k: int, - group_size_m: int, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def _fused_moe_lora_shrink_fake( + a_intermediate_cache1: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, 
+ device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def _fused_moe_lora_expand_fake( + output: torch.Tensor, + a_intermediate_cache1: torch.Tensor, + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, mul_routed_weight: bool = False, ) -> None: return @@ -383,7 +587,26 @@ def _fused_moe_lora_fake( mutates_args=["output"], fake_impl=_fused_moe_lora_fake, ) + + direct_register_custom_op( + op_name="fused_moe_lora_shrink", + op_func=_fused_moe_lora_shrink, + mutates_args=["a_intermediate_cache1"], + fake_impl=_fused_moe_lora_shrink_fake, + ) + + direct_register_custom_op( + op_name="fused_moe_lora_expand", + op_func=_fused_moe_lora_expand, + mutates_args=["output"], + fake_impl=_fused_moe_lora_expand_fake, + ) + fused_moe_lora = torch.ops.vllm.fused_moe_lora + fused_moe_lora_shrink = torch.ops.vllm.fused_moe_lora_shrink + fused_moe_lora_expand = torch.ops.vllm.fused_moe_lora_expand except AttributeError: fused_moe_lora = _fused_moe_lora + fused_moe_lora_shrink = _fused_moe_lora_shrink + fused_moe_lora_expand = _fused_moe_lora_expand diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 9ffb6dc3d85e..42a50a1da0bd 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -154,13 +154,13 @@ def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None: gpu_name = gpu_name.replace("-", "_") config_fname = None - if op_type == "shrink": - config_fname = f"{gpu_name}_{op_type.upper()}.json" - else: - assert op_type == "expand" + # only expand op needs to consider add_inputs + if op_type == "expand": config_fname = ( f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json" ) + else: + config_fname = f"{gpu_name}_{op_type.upper()}.json" config_path = Path(f"{user_defined_config_folder}/{config_fname}") if not config_path.exists(): @@ -186,8 +186,17 @@ def get_lora_op_configs( rank: int, num_slices: int, add_inputs: bool | None = None, + hidden_size_2: int | None = None, ) -> dict[str, int | None]: - assert op_type in ["shrink", "expand"] + # Add support for fused_moe_lora ops + assert op_type in [ + "shrink", + "expand", + "fused_moe_lora_gate_up_shrink", + "fused_moe_lora_gate_up_expand", + "fused_moe_lora_down_shrink", + "fused_moe_lora_down_expand", + ] # default config default = {} @@ -202,6 +211,21 @@ def get_lora_op_configs( "num_stages": 2, "max_nreg": None, } + # The default config for fused_moe_lora ops + elif op_type in [ + "fused_moe_lora_gate_up_shrink", + "fused_moe_lora_gate_up_expand", + "fused_moe_lora_down_shrink", + "fused_moe_lora_down_expand", + ]: + default = { + "block_m": 64, + "block_n": 64, + "block_k": 32, + "num_warps": 4, + "num_stages": 3, + "group_size_m": 8, + } else: default = { 
"block_m": 64, @@ -246,5 +270,13 @@ def get_lora_op_configs( or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))] ) + # slice by hidden_size_2 + if hidden_size_2 is not None: + n2 = hidden_size_2 + config_data = ( + config_data.get(str(n2)) + or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n2))] + ) + assert config_data is not None return config_data diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index c552412cfd62..b6186e856152 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -479,7 +479,8 @@ def add_lora_fused_moe( num_tokens_post_padded: torch.Tensor, max_lora_rank: int, top_k_num: int, - config, + shrink_config, + expand_config, adapter_enabled: torch.Tensor, mul_routed_weight=False, ): diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 30def90380db..164d481189ea 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -367,7 +367,8 @@ def add_lora_fused_moe( num_tokens_post_padded: torch.Tensor, max_lora_rank: int, top_k_num: int, - config, + shrink_config, + expand_config, adapter_enabled: torch.Tensor, mul_routed_weight=False, ): @@ -388,10 +389,19 @@ def add_lora_fused_moe( top_k_num, lora_ids, adapter_enabled, - config["BLOCK_SIZE_M"], - config["BLOCK_SIZE_N"], - config["BLOCK_SIZE_K"], - config["GROUP_SIZE_M"], - config.get("SPLIT_K", 1), + shrink_config.get("BLOCK_SIZE_M", shrink_config.get("block_m")), + shrink_config.get("BLOCK_SIZE_N", shrink_config.get("block_n")), + shrink_config.get("BLOCK_SIZE_K", shrink_config.get("block_k")), + shrink_config.get("GROUP_SIZE_M", shrink_config.get("group_m")), + shrink_config.get("num_warps", 4), + shrink_config.get("num_stages", 1), + shrink_config.get("SPLIT_K", 1), + expand_config.get("BLOCK_SIZE_M", expand_config.get("block_m")), + expand_config.get("BLOCK_SIZE_N", expand_config.get("block_n")), + expand_config.get("BLOCK_SIZE_K", expand_config.get("block_k")), + expand_config.get("GROUP_SIZE_M", expand_config.get("group_m")), + expand_config.get("num_warps", 4), + expand_config.get("num_stages", 1), + expand_config.get("SPLIT_K", 1), mul_routed_weight, ) From d3364e98429e72cd15ef32f2e36c7fe0a4597171 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 21 Oct 2025 16:43:25 +0000 Subject: [PATCH 02/36] fix some bugs --- vllm/lora/layers/fused_moe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 1712c17af89b..4ed03dbbeb9a 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -209,12 +209,12 @@ def wrapper(*args, **kwargs): # get the down expand config expand_config = get_lora_op_configs( op_type="fused_moe_lora_down_expand", - max_loras=self.w1_lora_a_stacked.shape[0], + max_loras=self.w2_lora_a_stacked.shape[0], batch=M, - hidden_size=self.w1_lora_a_stacked.shape[-1], - rank=self.w1_lora_a_stacked.shape[-2], + hidden_size=self.w2_lora_a_stacked.shape[-1], + rank=self.w2_lora_a_stacked.shape[-2], num_slices=1, - hidden_size_2=self.w1_lora_b_stacked.shape[-2], + hidden_size_2=self.w2_lora_b_stacked.shape[-2] ) else: get_config_func = functools.partial( From 9a5f9e0f6bfa859f749bab550b36b274721b8fac Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 21 Oct 2025 18:51:12 +0000 Subject: [PATCH 03/36] your message --- .pre-commit-config.yaml | 7 ------- vllm/lora/layers/fused_moe.py | 12 ++++++------ 
2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1ac1fdfd6bce..2ae2a117a3c2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,3 @@ -default_install_hook_types: - - pre-commit - - commit-msg -default_stages: - - commit # Run locally - - manual # Run in CI -exclude: 'vllm/third_party/.*' repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.14.1 diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 4ed03dbbeb9a..dc771629d664 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -116,11 +116,11 @@ def wrapper(*args, **kwargs): else: # fall back to the default config get_config_func = functools.partial( try_get_optimal_moe_config, - self.w13_weight.size(), - self.w2_weight.size(), + layer.w13_weight.size(), + layer.w2_weight.size(), top_k, config_dtype, - block_shape=self.quant_method.moe_quant_config.block_shape, + block_shape=layer.quant_method.moe_quant_config.block_shape, ) shrink_config = get_config_func(M) @@ -219,11 +219,11 @@ def wrapper(*args, **kwargs): else: get_config_func = functools.partial( try_get_optimal_moe_config, - self.w13_weight.size(), - self.w2_weight.size(), + layer.w13_weight.size(), + layer.w2_weight.size(), top_k, config_dtype, - block_shape=self.quant_method.moe_quant_config.block_shape, + block_shape=layer.quant_method.moe_quant_config.block_shape, ) shrink_config = get_config_func(M) expand_config = get_config_func(M) From 7d3071d7c78c48bf46aa6f3ef898bf9b14807b95 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 21 Oct 2025 18:52:46 +0000 Subject: [PATCH 04/36] fix bugs --- vllm/lora/layers/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index dc771629d664..309fedc75820 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -219,7 +219,7 @@ def wrapper(*args, **kwargs): else: get_config_func = functools.partial( try_get_optimal_moe_config, - layer.w13_weight.size(), + layer.w13_weight.size(), layer.w2_weight.size(), top_k, config_dtype, From 1724bdb05cc8c85059cf6a3288d57bb3b4bdc26e Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 21 Oct 2025 18:56:41 +0000 Subject: [PATCH 05/36] fix bugs --- .pre-commit-config.yaml | 150 ---------------------------------------- 1 file changed, 150 deletions(-) delete mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 2ae2a117a3c2..000000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,150 +0,0 @@ -repos: -- repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.1 - hooks: - - id: ruff-check - args: [--output-format, github, --fix] - - id: ruff-format -- repo: https://github.com/crate-ci/typos - rev: v1.38.1 - hooks: - - id: typos - args: [--force-exclude] -- repo: https://github.com/pre-commit/mirrors-clang-format - rev: v21.1.2 - hooks: - - id: clang-format - exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' - types_or: [c++, cuda] - args: [--style=file, --verbose] -- repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.45.0 - hooks: - - id: markdownlint - exclude: '.*\.inc\.md' - stages: [manual] # Only run in CI -- repo: 
https://github.com/rhysd/actionlint - rev: v1.7.7 - hooks: - - id: actionlint -- repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.9.1 - hooks: - - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28] - files: ^requirements/test\.(in|txt)$ -- repo: local - hooks: - - id: format-torch-nightly-test - name: reformat nightly_torch_test.txt to be in sync with test.in - language: python - entry: python tools/pre_commit/generate_nightly_torch_test.py - files: ^requirements/test\.(in|txt)$ - - id: mypy-local - name: Run mypy for local Python installation - entry: python tools/pre_commit/mypy.py 0 "local" - stages: [commit] # Don't run in CI - <<: &mypy_common - language: python - types_or: [python, pyi] - require_serial: true - additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.10 - entry: python tools/pre_commit/mypy.py 1 "3.10" - <<: *mypy_common - stages: [manual] # Only run in CI - - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.11 - entry: python tools/pre_commit/mypy.py 1 "3.11" - <<: *mypy_common - stages: [manual] # Only run in CI - - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.12 - entry: python tools/pre_commit/mypy.py 1 "3.12" - <<: *mypy_common - stages: [manual] # Only run in CI - - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.13 - entry: python tools/pre_commit/mypy.py 1 "3.13" - <<: *mypy_common - stages: [manual] # Only run in CI - - id: shellcheck - name: Lint shell scripts - entry: tools/pre_commit/shellcheck.sh - language: script - types: [shell] - - id: png-lint - name: Lint PNG exports from excalidraw - entry: tools/pre_commit/png-lint.sh - language: script - types: [png] - - id: signoff-commit - name: Sign-off Commit - entry: bash - args: - - -c - - | - if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then - printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)" - fi - language: system - verbose: true - stages: [commit-msg] - - id: check-spdx-header - name: Check SPDX headers - entry: python tools/pre_commit/check_spdx_header.py - language: python - types: [python] - - id: check-root-lazy-imports - name: Check root lazy imports - entry: python tools/pre_commit/check_init_lazy_imports.py - language: python - types: [python] - - id: check-filenames - name: Check for spaces in all filenames - entry: bash - args: - - -c - - 'git ls-files | grep " " && echo "Filenames should not contain spaces!" 
&& exit 1 || exit 0' - language: system - always_run: true - pass_filenames: false - - id: update-dockerfile-graph - name: Update Dockerfile dependency graph - entry: tools/pre_commit/update-dockerfile-graph.sh - language: script - - id: enforce-import-regex-instead-of-re - name: Enforce import regex as re - entry: python tools/pre_commit/enforce_regex_import.py - language: python - types: [python] - pass_filenames: false - additional_dependencies: [regex] - # forbid directly import triton - - id: forbid-direct-triton-import - name: "Forbid direct 'import triton'" - entry: python tools/pre_commit/check_triton_import.py - language: python - types: [python] - pass_filenames: false - additional_dependencies: [regex] - - id: check-pickle-imports - name: Prevent new pickle/cloudpickle imports - entry: python tools/pre_commit/check_pickle_imports.py - language: python - types: [python] - additional_dependencies: [regex] - - id: validate-config - name: Validate configuration has default values and that each field has a docstring - entry: python tools/pre_commit/validate_config.py - language: python - additional_dependencies: [regex] - # Keep `suggestion` last - - id: suggestion - name: Suggestion - entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. To skip a specific hook, prefix the commit command with SKIP=."' - language: system - verbose: true - pass_filenames: false - # Insert new entries above the `suggestion` entry From e8d144f5c113ed4a1dc33dc555c3dc4a53b4bd95 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 21 Oct 2025 19:23:33 +0000 Subject: [PATCH 06/36] Fixed the bugs --- .pre-commit-config.yaml | 157 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000000..fdff02f2f8d3 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,157 @@ +default_install_hook_types: + - pre-commit + - commit-msg +default_stages: + - pre-commit # Run locally + - manual # Run in CI +exclude: 'vllm/third_party/.*' +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.0 + hooks: + - id: ruff-check + args: [--output-format, github, --fix] + - id: ruff-format +- repo: https://github.com/crate-ci/typos + rev: v1.38.1 + hooks: + - id: typos + args: [--force-exclude] +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v21.1.2 + hooks: + - id: clang-format + exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' + types_or: [c++, cuda] + args: [--style=file, --verbose] +- repo: https://github.com/igorshubovych/markdownlint-cli + rev: v0.45.0 + hooks: + - id: markdownlint + exclude: '.*\.inc\.md' + stages: [manual] # Only run in CI +- repo: https://github.com/rhysd/actionlint + rev: v1.7.7 + hooks: + - id: actionlint +- repo: https://github.com/astral-sh/uv-pre-commit + rev: 0.9.1 + hooks: + - id: pip-compile + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28] + files: ^requirements/test\.(in|txt)$ +- repo: local + hooks: + - id: format-torch-nightly-test + name: reformat nightly_torch_test.txt to be in sync with test.in + language: python + entry: python tools/generate_nightly_torch_test.py 
+ files: ^requirements/test\.(in|txt)$ + - id: mypy-local + name: Run mypy for local Python installation + entry: python tools/pre_commit/mypy.py 0 "local" + stages: [commit] # Don't run in CI + <<: &mypy_common + language: python + types_or: [python, pyi] + require_serial: true + additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] + - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.10 + entry: python tools/pre_commit/mypy.py 1 "3.10" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.11 + entry: python tools/pre_commit/mypy.py 1 "3.11" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.12 + entry: python tools/pre_commit/mypy.py 1 "3.12" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.13 + entry: python tools/pre_commit/mypy.py 1 "3.13" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: shellcheck + name: Lint shell scripts + entry: tools/shellcheck.sh + language: script + types: [shell] + - id: png-lint + name: Lint PNG exports from excalidraw + entry: tools/png-lint.sh + language: script + types: [png] + - id: signoff-commit + name: Sign-off Commit + entry: bash + args: + - -c + - | + if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then + printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)" + fi + language: system + verbose: true + stages: [commit-msg] + - id: check-spdx-header + name: Check SPDX headers + entry: python tools/check_spdx_header.py + language: python + types: [python] + - id: check-root-lazy-imports + name: Check root lazy imports + entry: python tools/check_init_lazy_imports.py + language: python + types: [python] + - id: check-filenames + name: Check for spaces in all filenames + entry: bash + args: + - -c + - 'git ls-files | grep " " && echo "Filenames should not contain spaces!" 
&& exit 1 || exit 0' + language: system + always_run: true + pass_filenames: false + - id: update-dockerfile-graph + name: Update Dockerfile dependency graph + entry: tools/update-dockerfile-graph.sh + language: script + - id: enforce-import-regex-instead-of-re + name: Enforce import regex as re + entry: python tools/enforce_regex_import.py + language: python + types: [python] + pass_filenames: false + additional_dependencies: [regex] + # forbid directly import triton + - id: forbid-direct-triton-import + name: "Forbid direct 'import triton'" + entry: python tools/check_triton_import.py + language: python + types: [python] + pass_filenames: false + additional_dependencies: [regex] + - id: check-pickle-imports + name: Prevent new pickle/cloudpickle imports + entry: python tools/pre_commit/check_pickle_imports.py + language: python + types: [python] + additional_dependencies: [regex] + - id: validate-config + name: Validate configuration has default values and that each field has a docstring + entry: python tools/validate_config.py + language: python + additional_dependencies: [regex] + # Keep `suggestion` last + - id: suggestion + name: Suggestion + entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. To skip a specific hook, prefix the commit command with SKIP=."' + language: system + verbose: true + pass_filenames: false + # Insert new entries above the `suggestion` entry From 64537e28866090592124c639b35471de1248452d Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 21 Oct 2025 19:30:31 +0000 Subject: [PATCH 07/36] Adding pre-commit-config.yaml --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fdff02f2f8d3..121bdb750de5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: - id: mypy-local name: Run mypy for local Python installation entry: python tools/pre_commit/mypy.py 0 "local" - stages: [commit] # Don't run in CI + stages: [pre-commit] # Don't run in CI <<: &mypy_common language: python types_or: [python, pyi] From b09466c9568f1dcbca19e4488f7d9509648e8a07 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Thu, 23 Oct 2025 04:59:44 +0000 Subject: [PATCH 08/36] clean the code --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 121bdb750de5..92045ee9a856 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: rev: 0.9.1 hooks: - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28] + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28] files: ^requirements/test\.(in|txt)$ - repo: local hooks: From d9cb74189ea3a23e0b14853a261b8055073f9f84 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Thu, 23 Oct 2025 06:45:22 +0000 Subject: [PATCH 09/36] fix bugs --- vllm/lora/layers/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 309fedc75820..8efbb1cacf37 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -214,12 +214,12 @@ def wrapper(*args, **kwargs): hidden_size=self.w2_lora_a_stacked.shape[-1], rank=self.w2_lora_a_stacked.shape[-2], num_slices=1, 
- hidden_size_2=self.w2_lora_b_stacked.shape[-2] + hidden_size_2=self.w2_lora_b_stacked.shape[-2], ) else: get_config_func = functools.partial( try_get_optimal_moe_config, - layer.w13_weight.size(), + layer.w13_weight.size(), layer.w2_weight.size(), top_k, config_dtype, From 0bf6a533f0a329ca790f5d98e544e14be25ecc7b Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Sun, 26 Oct 2025 22:30:52 +0000 Subject: [PATCH 10/36] Adding support in benchmark_lora for fused_moe_lora expand and shrink kernel benchmarking Signed-off-by: Yu Gong --- benchmarks/kernels/benchmark_lora.py | 17 ++++++++++++++++- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 1 + 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 96da0a5ea79e..890116ffb0f2 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -620,6 +620,12 @@ def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, in if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else self.lora_kernel_meta.token_lora_mapping.shape[0] ) + ## TODO: test if this works + num_tokens = ( + self.lora_kernel_meta.token_lora_mapping.shape[0] * ctx.top_k_num + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else self.lora_kernel_meta.token_lora_mapping.shape[0] + ) max_seq_len = torch.max(self.seq_lens).item() num_slices = len(self.lora_weights_lst) return num_seqs, num_tokens, max_seq_len, num_slices @@ -716,6 +722,7 @@ def as_lora_shrink_kwargs( self.sanity_check(ctx, op_type) self.to_device(self.input.device) + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. @@ -943,17 +950,23 @@ def as_fused_moe_lora_expand_kwargs( def bench_fn_kwargs( self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None ) -> dict[str, Any]: - if op_type.is_shrink_fn(): + if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn(): assert add_inputs is None else: assert add_inputs is not None if op_type == OpType.LORA_SHRINK: return self.as_lora_shrink_kwargs(ctx, op_type) + return self.as_lora_shrink_kwargs(ctx, op_type) if op_type == OpType.LORA_EXPAND: return self.as_lora_expand_kwargs(ctx, op_type, add_inputs) if op_type.is_fused_moe_lora_shrink_fn(): return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type) + if op_type.is_fused_moe_lora_expand_fn(): + return self.as_fused_moe_lora_expand_kwargs(ctx, op_type) + return self.as_lora_expand_kwargs(ctx, op_type, add_inputs) + if op_type.is_fused_moe_lora_shrink_fn(): + return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type) if op_type.is_fused_moe_lora_expand_fn(): return self.as_fused_moe_lora_expand_kwargs(ctx, op_type) raise ValueError(f"Unrecognized optype {self}") @@ -1011,6 +1024,7 @@ def bench_optype( ] for bt in bench_tensors: bt.sanity_check(ctx, op_type) + bt.sanity_check(ctx, op_type) # Test correctness of our implementation. 
if test_correctness: @@ -1028,6 +1042,7 @@ def bench_optype( _LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() _LORA_PTR_DICT.clear() + _LORA_PTR_DICT.clear() # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index ae59d73487b9..629ab605a5b4 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -209,6 +209,7 @@ def _fused_moe_lora_shrink( mul_routed_weight: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] + w1_lora_a_stacked = lora_a_stacked[0] shrink_config = { "BLOCK_SIZE_M": shrink_block_size_m, From 94508e4ee38493f1efaab6da37e1bb990353dc52 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Sun, 26 Oct 2025 23:12:08 +0000 Subject: [PATCH 11/36] Adding data generation for fused_moe_lora Signed-off-by: Yu Gong --- benchmarks/kernels/benchmark_lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 890116ffb0f2..d5fdd5ccfbbf 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -21,6 +21,7 @@ from vllm.lora.ops.triton_ops.utils import get_lora_op_configs from vllm.triton_utils import HAS_TRITON, triton +from vllm.triton_utils import HAS_TRITON, triton if HAS_TRITON: from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora From 6c8c97b321758b8745751f507bc72da2272ad8ec Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 28 Oct 2025 05:16:06 +0000 Subject: [PATCH 12/36] Fix bugs --- benchmarks/kernels/benchmark_lora.py | 16 ---------------- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 1 - 2 files changed, 17 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index d5fdd5ccfbbf..fd541ebc8dee 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -21,7 +21,6 @@ from vllm.lora.ops.triton_ops.utils import get_lora_op_configs from vllm.triton_utils import HAS_TRITON, triton -from vllm.triton_utils import HAS_TRITON, triton if HAS_TRITON: from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora @@ -621,12 +620,6 @@ def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, in if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else self.lora_kernel_meta.token_lora_mapping.shape[0] ) - ## TODO: test if this works - num_tokens = ( - self.lora_kernel_meta.token_lora_mapping.shape[0] * ctx.top_k_num - if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] - else self.lora_kernel_meta.token_lora_mapping.shape[0] - ) max_seq_len = torch.max(self.seq_lens).item() num_slices = len(self.lora_weights_lst) return num_seqs, num_tokens, max_seq_len, num_slices @@ -723,7 +716,6 @@ def as_lora_shrink_kwargs( self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata(ctx, op_type) _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. 
@@ -958,16 +950,10 @@ def bench_fn_kwargs( if op_type == OpType.LORA_SHRINK: return self.as_lora_shrink_kwargs(ctx, op_type) - return self.as_lora_shrink_kwargs(ctx, op_type) if op_type == OpType.LORA_EXPAND: return self.as_lora_expand_kwargs(ctx, op_type, add_inputs) if op_type.is_fused_moe_lora_shrink_fn(): return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type) - if op_type.is_fused_moe_lora_expand_fn(): - return self.as_fused_moe_lora_expand_kwargs(ctx, op_type) - return self.as_lora_expand_kwargs(ctx, op_type, add_inputs) - if op_type.is_fused_moe_lora_shrink_fn(): - return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type) if op_type.is_fused_moe_lora_expand_fn(): return self.as_fused_moe_lora_expand_kwargs(ctx, op_type) raise ValueError(f"Unrecognized optype {self}") @@ -1025,7 +1011,6 @@ def bench_optype( ] for bt in bench_tensors: bt.sanity_check(ctx, op_type) - bt.sanity_check(ctx, op_type) # Test correctness of our implementation. if test_correctness: @@ -1043,7 +1028,6 @@ def bench_optype( _LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() _LORA_PTR_DICT.clear() - _LORA_PTR_DICT.clear() # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 629ab605a5b4..ae59d73487b9 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -209,7 +209,6 @@ def _fused_moe_lora_shrink( mul_routed_weight: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] - w1_lora_a_stacked = lora_a_stacked[0] shrink_config = { "BLOCK_SIZE_M": shrink_block_size_m, From 8e85f95b648ee3ceee8ba11d3f5f59aab9403d99 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 28 Oct 2025 06:08:33 +0000 Subject: [PATCH 13/36] Adding accuracy test --- benchmarks/kernels/benchmark_lora.py | 189 +++++++++++++++++++++++++-- 1 file changed, 181 insertions(+), 8 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index fd541ebc8dee..bc267a0ab96b 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -386,6 +386,137 @@ def bench_fn(self) -> Callable: raise ValueError(f"Unrecognized optype {self}") + def _run_fused_moe_lora_ref( + self, + output: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + is_shrink: bool, + **kwargs, + ) -> None: + """ + Unified reference implementation for fused MoE LoRA operations. 
+ + Processes tokens exactly as the kernel does: + - For each LoRA and block: get expert_id from expert_ids tensor + - For each token in block: get token_id from sorted_token_ids + - Perform the gemm with the corresponding expert's weights + """ + top_k_num = kwargs.get("top_k_num", 1) + seq_lens_cpu = kwargs.get("seq_lens_cpu") + prompt_lora_mapping_cpu = kwargs.get("prompt_lora_mapping_cpu") + scaling = kwargs.get("scaling", 1.0) + topk_weights = kwargs.get("topk_weights") + sorted_token_ids = kwargs.get("sorted_token_ids") # (num_loras, padded_size) + expert_ids = kwargs.get("expert_ids") # (num_loras, num_blocks) + mul_routed_weight = kwargs.get("mul_routed_weight", False) + w_dtype = lora_weights[0].dtype + num_slices = len(lora_weights) + + # Get block size from kernel config (needed to map tokens to blocks) + block_size_m = kwargs.get( + "shrink_block_size_m" if is_shrink else "expand_block_size_m", 64 + ) + + # Move to CPU for easier processing + sorted_token_ids_cpu = sorted_token_ids.cpu() + expert_ids_cpu = expert_ids.cpu() + num_loras = lora_weights[0].shape[0] + + # Process each LoRA + for lora_idx in range(num_loras): + # Find which batch uses this LoRA + batch_mask = prompt_lora_mapping_cpu == lora_idx + if not batch_mask.any(): + continue # No sequences use this LoRA + + # Process each slice + for slice_idx in range(num_slices): + weights_slice = lora_weights[ + slice_idx + ] # (num_loras, num_experts, out_dim, in_dim) + + # Process each block for this LoRA + num_blocks = expert_ids_cpu.shape[1] + for block_idx in range(num_blocks): + # Get the expert_id for this block + expert_id = expert_ids_cpu[lora_idx, block_idx].item() + if expert_id == -1: + continue # Empty block + + # Get weight for this expert and LoRA + weight = weights_slice[lora_idx, expert_id, :, :] + + # Process tokens in this block + block_start = block_idx * block_size_m + block_end = min( + block_start + block_size_m, sorted_token_ids_cpu.shape[1] + ) + + for token_pos in range(block_start, block_end): + sorted_token_id = sorted_token_ids_cpu[ + lora_idx, token_pos + ].item() + + # Check if this is a valid token (not padding) + num_tokens = seq_lens_cpu.sum().item() + if self == OpType.FUSED_MOE_LORA_DOWN_SHRINK and is_shrink: + max_valid = num_tokens * top_k_num + else: + max_valid = num_tokens + + if sorted_token_id >= max_valid: + continue # Padding token + + # Decode: original_token_idx and k_idx from sorted_token_id + if is_shrink: + # For shrink: sorted_token_id encodes (token_idx * top_k + k_idx) + original_token_idx = sorted_token_id // top_k_num + k_idx = sorted_token_id % top_k_num + else: + # For expand: sorted_token_id is just the token index + # k_idx comes from the expert routing (encoded in block structure) + original_token_idx = sorted_token_id + # Need to infer k_idx - in expand, tokens are organized differently + # Each expert processes tokens with that expert's k_idx + k_idx = expert_id % top_k_num # Simplified + + if is_shrink: + # Input index + if self == OpType.FUSED_MOE_LORA_DOWN_SHRINK: + input_idx = sorted_token_id + else: + input_idx = original_token_idx + + x = input[input_idx : input_idx + 1, :] + result = torch.nn.functional.linear(x, weight) + result = result * scaling + output[slice_idx, original_token_idx, k_idx, :] = ( + result.squeeze(0) + ) + + else: + # Expand + hidden_size = weights_slice.shape[2] + x = ( + input[slice_idx, original_token_idx, k_idx, :] + .unsqueeze(0) + .to(dtype=w_dtype) + ) + result = torch.nn.functional.linear(x, weight) + result = result * scaling 
+ + if mul_routed_weight and topk_weights is not None: + route_weight = topk_weights[original_token_idx, k_idx] + result = result * route_weight + + slice_offset = slice_idx * hidden_size + output[ + original_token_idx, + k_idx, + slice_offset : slice_offset + hidden_size, + ] += result.squeeze(0) + def run_ref_group_gemm( self, output: torch.Tensor, @@ -418,6 +549,24 @@ def run_ref_group_gemm( lora_weights=lora_weights[slice_idx], **kwargs, ) + elif self.is_fused_moe_lora_shrink_fn(): + # For fused MoE LoRA shrink: input @ lora_a.T -> intermediate + # Input shape: (num_tokens, hidden_size) for gate_up or + # (num_tokens*top_k, hidden_size) for down + # Weight shape: (num_loras, num_experts, lora_rank, hidden_size) + # Output shape: (num_slices, num_tokens, top_k, lora_rank) + self._run_fused_moe_lora_ref( + output, input, lora_weights, is_shrink=True, **kwargs + ) + + elif self.is_fused_moe_lora_expand_fn(): + # For fused MoE LoRA expand: intermediate @ lora_b.T -> output + # Input shape: (num_slices, num_tokens, top_k, lora_rank) + # Weight shape: (num_loras, num_experts, hidden_size, lora_rank) + # Output shape: (num_tokens, top_k, hidden_size * num_slices) + self._run_fused_moe_lora_ref( + output, input, lora_weights, is_shrink=False, **kwargs + ) else: raise ValueError(f"Unrecognized optype {self}") @@ -959,7 +1108,7 @@ def bench_fn_kwargs( raise ValueError(f"Unrecognized optype {self}") def test_correctness( - self, op_type: OpType, expand_fn_add_inputs: bool | None + self, ctx: BenchmarkContext, op_type: OpType, expand_fn_add_inputs: bool | None ) -> bool: """ Test correctness of op_type implementation against a grouped gemm @@ -970,16 +1119,37 @@ def test_correctness( ref_output = self.output.clone() self.output.zero_() - op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) + kernel_kwargs = self.bench_fn_kwargs(ctx, op_type, expand_fn_add_inputs) + op_type.bench_fn()(**kernel_kwargs) + + # Build reference kwargs + ref_kwargs = { + "seq_lens_cpu": seq_lens_cpu, + "prompt_lora_mapping_cpu": prompt_lora_mapping_cpu, + "scaling": 1.0, + "add_inputs": expand_fn_add_inputs, + } + + # Add fused_moe_lora specific kwargs if needed + if op_type.is_fused_moe_lora_fn(): + ref_kwargs.update( + { + "topk_weights": kernel_kwargs.get("topk_weights"), + "sorted_token_ids": kernel_kwargs.get("sorted_token_ids"), + "expert_ids": kernel_kwargs.get("expert_ids"), + "top_k_num": ctx.top_k_num, + "num_experts": ctx.num_experts, + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + "shrink_block_size_m": kernel_kwargs.get("shrink_block_size_m"), + "expand_block_size_m": kernel_kwargs.get("expand_block_size_m"), + } + ) op_type.run_ref_group_gemm( ref_output, self.input, self.lora_weights_lst, - seq_lens_cpu=seq_lens_cpu, - prompt_lora_mapping_cpu=prompt_lora_mapping_cpu, - scaling=1.0, - add_inputs=expand_fn_add_inputs, + **ref_kwargs, ) rtol, atol = { @@ -1015,12 +1185,15 @@ def bench_optype( # Test correctness of our implementation. 
if test_correctness: assert all( - [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] + [ + bt.test_correctness(ctx, op_type, expand_fn_add_inputs) + for bt in bench_tensors + ] ) # BenchmarkTensors -> dict (kwargs) kwargs_list = [ - bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) + bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors ] From 0ee933b7ed585228b67e111376c43c3817f38746 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 28 Oct 2025 17:35:45 +0000 Subject: [PATCH 14/36] fix bugs --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index ae59d73487b9..222282dbf785 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -177,8 +177,7 @@ def _fused_moe_lora_kernel( @torch.inference_mode() def _fused_moe_lora_shrink( - a_intermediate_cache1: - torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + a_intermediate_cache1: torch.Tensor, # (num_slices, num_tokens, top_k_num, max_lora_rank) qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) lora_a_stacked: list[ torch.Tensor @@ -466,6 +465,7 @@ def _fused_moe_lora( sorted_token_ids, expert_ids, num_tokens_post_padded, + top_k_num, ## adding for kernel device, N, From f9f0f8e803b30525610d99c17983ca75ce955ffe Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 28 Oct 2025 20:47:38 +0000 Subject: [PATCH 15/36] fix bugs --- benchmarks/kernels/benchmark_lora.py | 87 ++++++++++++++++++-------- vllm/lora/layers/fused_moe.py | 3 +- vllm/lora/ops/triton_ops/utils.py | 13 ++-- vllm/lora/punica_wrapper/punica_gpu.py | 24 +++---- 4 files changed, 81 insertions(+), 46 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index bc267a0ab96b..e9cf3bb48206 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -70,6 +70,8 @@ DEFAULT_SORT_BY_LORA_IDS = [False, True] DEFAULT_SEQ_LENGTHS = [1] DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] +DEFAULT_TOP_K_NUMS = [1] # Added for MoE LoRA top_k +DEFAULT_NUM_EXPERTS = [8] # Added for MoE LoRA num_experts # Utilities @@ -315,20 +317,19 @@ def matmul_shapes_fused_moe_lora( top_k_num: int, num_experts: int, ) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]: - if self.is_fused_moe_lora_gate_up_fn(): - if self.is_fused_moe_lora_shrink_fn(): - input_shape = ( - (m * top_k_num, n) - if self in [OpType.FUSED_MOE_LORA_GATE_UP_SHRINK] - else (m, n) - ) - output_shape = (num_slices, m, top_k_num, k) - weight_shape = (num_loras, num_experts, k, n) - else: - assert self.is_fused_moe_lora_expand_fn() - input_shape = (num_slices, m, top_k_num, k) - output_shape = (m, top_k_num, n * num_slices) - weight_shape = (num_loras, num_experts, n, k) + if self.is_fused_moe_lora_shrink_fn(): + input_shape = ( + (m * top_k_num, n) + if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else (m, n) + ) + output_shape = (num_slices, m, top_k_num, k) + weight_shape = (num_loras, num_experts, k, n) + else: + assert self.is_fused_moe_lora_expand_fn() + input_shape = (num_slices, m, top_k_num, k) + output_shape = (m, top_k_num, n * num_slices) + weight_shape = (num_loras, num_experts, n, k) return (input_shape, weight_shape, output_shape) def matmul_shapes( @@ -357,7 +358,6 @@ def matmul_shapes( return ((num_slices, m, k), b_shape, (m, n * num_slices)) 
if self.is_fused_moe_lora_fn(): return self.matmul_shapes_fused_moe_lora( - self, m, k, n, @@ -470,14 +470,14 @@ def _run_fused_moe_lora_ref( # Decode: original_token_idx and k_idx from sorted_token_id if is_shrink: - # For shrink: sorted_token_id encodes (token_idx * top_k + k_idx) + # shrink: sorted_token_id (token_idx * top_k + k_idx) original_token_idx = sorted_token_id // top_k_num k_idx = sorted_token_id % top_k_num else: # For expand: sorted_token_id is just the token index - # k_idx comes from the expert routing (encoded in block structure) + # k_idx from the expert routing (encoded in block structure) original_token_idx = sorted_token_id - # Need to infer k_idx - in expand, tokens are organized differently + # Infer k_idx, in expand, tokens are organized differently # Each expert processes tokens with that expert's k_idx k_idx = expert_id % top_k_num # Simplified @@ -719,7 +719,7 @@ def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None: torch.sum(self.seq_lens) * ctx.top_k_num == num_tokens if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else torch.sum(self.seq_lens) == num_tokens - ) + ), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}" num_seqs = self.seq_lens.shape[0] # assert self.seq_start_loc.shape[0] == num_seqs ## In down shrink case, each prompt corresponds to top_k_num sequences @@ -846,7 +846,7 @@ def moe_lora_align_block_size( (sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = ( moe_lora_align_block_size( - curr_topk_ids=curr_topk_ids, + topk_ids=curr_topk_ids, token_lora_mapping=token_lora_mapping, block_size=block_size, num_experts=ctx.num_experts, @@ -968,7 +968,12 @@ def as_fused_moe_lora_shrink_kwargs( lora_rank = lw_shape[-2] # Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank] assert len(o_shape) == 4 - assert o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank) + assert ( + o_shape + == (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank) + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank) + ) kernel_config = get_lora_op_configs( op_type.name.lower(), max_loras=lw_shape[0], @@ -1010,7 +1015,7 @@ def as_fused_moe_lora_shrink_kwargs( "shrink_group_size_m": kernel_config["GROUP_SIZE_M"], "shrink_num_warps": kernel_config["num_warps"], "shrink_num_stages": kernel_config["num_stages"], - "shrink_splitK": kernel_config.get("SPLIT_K", 1), + "shrink_split_k": kernel_config.get("SPLIT_K", 1), "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), } @@ -1085,7 +1090,7 @@ def as_fused_moe_lora_expand_kwargs( "expand_group_size_m": kernel_config["GROUP_SIZE_M"], "expand_num_warps": kernel_config["num_warps"], "expand_num_stages": kernel_config["num_stages"], - "expand_splitK": kernel_config.get("SPLIT_K", 1), + "expand_split_k": kernel_config.get("SPLIT_K", 1), "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), } @@ -1170,7 +1175,7 @@ def bench_optype( test_correctness: bool = False, ) -> TMeasurement: assert arg_pool_size >= 1 - if op_type.is_shrink_fn(): + if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn(): assert expand_fn_add_inputs is None else: assert expand_fn_add_inputs is not None @@ -1350,7 +1355,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): # Benchmark bench_op expand_fn_add_inputs = ( - [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + [None] + if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn() + else 
args.expand_fn_add_inputs ) for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( @@ -1388,12 +1395,22 @@ def as_benchmark_contexts( hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace ) -> list[BenchmarkContext]: ctxs: list[BenchmarkContext] = [] - for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa + for ( + batch_size, + hidden_size, + lora_rank, + num_loras, + sort_by_lora_id, + top_k_num, + num_experts, + ) in product( # noqa args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, args.sort_by_lora_id, + args.top_k_nums, + args.num_experts, ): ctxs.append( BenchmarkContext( @@ -1409,6 +1426,8 @@ def as_benchmark_contexts( sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, # To be filled based on the OpType to benchmark + top_k_num=top_k_num, + num_experts=num_experts, num_slices=None, ) ) @@ -1569,6 +1588,22 @@ def add_common_command_args(p: argparse.ArgumentParser): ), ) + p.add_argument( + "--top-k-nums", + nargs="+", + type=int, + default=DEFAULT_TOP_K_NUMS, + help="Top-K values for MoE LoRA operations", + ) + + p.add_argument( + "--num-experts", + nargs="+", + type=int, + default=DEFAULT_NUM_EXPERTS, + help="Number of experts for MoE LoRA operations", + ) + parser = FlexibleArgumentParser( description=f""" Benchmark LoRA kernels: diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 8efbb1cacf37..b59a30f2217b 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -128,8 +128,7 @@ def wrapper(*args, **kwargs): # get the block size of m from customized config or default config max_loras = self.w1_lora_a_stacked.shape[0] block_size = ( - shrink_config.get("BLOCK_SIZE_M", shrink_config.get("block_m", 64)) - or 64 + shrink_config.get("BLOCK_SIZE_M", 64) ) ( sorted_token_ids_lora, diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 42a50a1da0bd..a7c45d79d81e 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -219,12 +219,13 @@ def get_lora_op_configs( "fused_moe_lora_down_expand", ]: default = { - "block_m": 64, - "block_n": 64, - "block_k": 32, - "num_warps": 4, - "num_stages": 3, - "group_size_m": 8, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "NUM_WARPS": 4, + "NUM_STAGES": 3, + "GROUP_SIZE_M": 8, + "SPLIT_K": 1, } else: default = { diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 164d481189ea..295a10c7083d 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -389,19 +389,19 @@ def add_lora_fused_moe( top_k_num, lora_ids, adapter_enabled, - shrink_config.get("BLOCK_SIZE_M", shrink_config.get("block_m")), - shrink_config.get("BLOCK_SIZE_N", shrink_config.get("block_n")), - shrink_config.get("BLOCK_SIZE_K", shrink_config.get("block_k")), - shrink_config.get("GROUP_SIZE_M", shrink_config.get("group_m")), - shrink_config.get("num_warps", 4), - shrink_config.get("num_stages", 1), + shrink_config.get("BLOCK_SIZE_M", 64), + shrink_config.get("BLOCK_SIZE_N", 64), + shrink_config.get("BLOCK_SIZE_K", 32), + shrink_config.get("GROUP_SIZE_M", 8), + shrink_config.get("NUM_WARPS", 4), + shrink_config.get("NUM_STAGES", 3), shrink_config.get("SPLIT_K", 1), - expand_config.get("BLOCK_SIZE_M", expand_config.get("block_m")), - expand_config.get("BLOCK_SIZE_N", expand_config.get("block_n")), - expand_config.get("BLOCK_SIZE_K", expand_config.get("block_k")), - 
expand_config.get("GROUP_SIZE_M", expand_config.get("group_m")), - expand_config.get("num_warps", 4), - expand_config.get("num_stages", 1), + expand_config.get("BLOCK_SIZE_M", 64), + expand_config.get("BLOCK_SIZE_N", 64), + expand_config.get("BLOCK_SIZE_K", 64), + expand_config.get("GROUP_SIZE_M", 64), + expand_config.get("NUM_WARPS", 4), + expand_config.get("NUM_STAGES", 3), expand_config.get("SPLIT_K", 1), mul_routed_weight, ) From e11030e7076bed14329942ab73d80168ed44d1c9 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 28 Oct 2025 21:02:30 +0000 Subject: [PATCH 16/36] clean code --- .pre-commit-config.yaml | 4 +- vllm/lora/layers/fused_moe.py | 135 +++++++++++++++++----------------- 2 files changed, 68 insertions(+), 71 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 92045ee9a856..fbfd8016cb76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,8 +48,8 @@ repos: entry: python tools/generate_nightly_torch_test.py files: ^requirements/test\.(in|txt)$ - id: mypy-local - name: Run mypy for local Python installation - entry: python tools/pre_commit/mypy.py 0 "local" + name: Run mypy locally for lowest supported Python version + entry: python tools/pre_commit/mypy.py 0 "3.10" stages: [pre-commit] # Don't run in CI <<: &mypy_common language: python diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index b59a30f2217b..bcc7c76d7927 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -40,6 +40,50 @@ def __init__(self, base_layer: FusedMoE) -> None: self.device = base_layer.w2_weight.device self._inject_lora_into_fused_moe() + def _get_lora_moe_configs( + self, + op_prefix: str, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + num_slices: int, + M: int, + layer: FusedMoE, + top_k: int, + config_dtype: str, + ): + if envs.VLLM_TUNED_CONFIG_FOLDER: + shrink_config = get_lora_op_configs( + op_type=f"fused_moe_lora_{op_prefix}_shrink", + max_loras=lora_a_stacked.shape[0], + batch=M, + hidden_size=lora_a_stacked.shape[-1], + rank=lora_a_stacked.shape[-2], + num_slices=num_slices, + hidden_size_2=lora_b_stacked.shape[-2], + ) + expand_config = get_lora_op_configs( + op_type=f"fused_moe_lora_{op_prefix}_expand", + max_loras=lora_a_stacked.shape[0], + batch=M, + hidden_size=lora_a_stacked.shape[-1], + rank=lora_a_stacked.shape[-2], + num_slices=num_slices, + hidden_size_2=lora_b_stacked.shape[-2], + ) + else: # fall back to the default config + get_config_func = functools.partial( + try_get_optimal_moe_config, + layer.w13_weight.size(), + layer.w2_weight.size(), + top_k, + config_dtype, + block_shape=layer.quant_method.moe_quant_config.block_shape, + ) + shrink_config = get_config_func(M) + expand_config = get_config_func(M) + + return shrink_config, expand_config + def _inject_lora_into_fused_moe(self): moe_state_dict = {} top_k = self.base_layer.top_k @@ -91,45 +135,20 @@ def wrapper(*args, **kwargs): num_tokens = hidden_states.size(0) M = min(num_tokens, CHUNK_SIZE) - ## if the env var is set, loading the config - if envs.VLLM_TUNED_CONFIG_FOLDER: - # get the gate/up shrink config - shrink_config = get_lora_op_configs( - op_type="fused_moe_lora_gate_up_shrink", - max_loras=self.w1_lora_a_stacked.shape[0], - batch=M, - hidden_size=self.w1_lora_a_stacked.shape[-1], - rank=self.w1_lora_a_stacked.shape[-2], - num_slices=2, - hidden_size_2=self.w1_lora_b_stacked.shape[-2], - ) - # get the gate/up expand config - expand_config = get_lora_op_configs( - 
op_type="fused_moe_lora_gate_up_expand", - max_loras=self.w1_lora_a_stacked.shape[0], - batch=M, - hidden_size=self.w1_lora_a_stacked.shape[-1], - rank=self.w1_lora_a_stacked.shape[-2], - num_slices=2, - hidden_size_2=self.w1_lora_b_stacked.shape[-2], - ) - else: # fall back to the default config - get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - layer.w2_weight.size(), - top_k, - config_dtype, - block_shape=layer.quant_method.moe_quant_config.block_shape, - ) - - shrink_config = get_config_func(M) - expand_config = get_config_func(M) ## same as the shrink config + shrink_config, expand_config = self._get_lora_moe_configs( + op_prefix="gate_up", + lora_a_stacked=self.w1_lora_a_stacked, + lora_b_stacked=self.w1_lora_b_stacked, + num_slices=2, + M=M, + layer=layer, + top_k=top_k, + config_dtype=config_dtype, + ) + # get the block size of m from customized config or default config max_loras = self.w1_lora_a_stacked.shape[0] - block_size = ( - shrink_config.get("BLOCK_SIZE_M", 64) - ) + block_size = shrink_config.get("BLOCK_SIZE_M", 64) ( sorted_token_ids_lora, expert_ids_lora, @@ -194,38 +213,16 @@ def wrapper(*args, **kwargs): num_tokens = hidden_states.size(0) M = min(num_tokens, CHUNK_SIZE) - if envs.VLLM_TUNED_CONFIG_FOLDER: - # get the down shrink config - shrink_config = get_lora_op_configs( - op_type="fused_moe_lora_down_shrink", - max_loras=self.w2_lora_a_stacked.shape[0], - batch=M, - hidden_size=self.w2_lora_a_stacked.shape[-1], - rank=self.w2_lora_a_stacked.shape[-2], - num_slices=1, - hidden_size_2=self.w2_lora_b_stacked.shape[-2], - ) - # get the down expand config - expand_config = get_lora_op_configs( - op_type="fused_moe_lora_down_expand", - max_loras=self.w2_lora_a_stacked.shape[0], - batch=M, - hidden_size=self.w2_lora_a_stacked.shape[-1], - rank=self.w2_lora_a_stacked.shape[-2], - num_slices=1, - hidden_size_2=self.w2_lora_b_stacked.shape[-2], - ) - else: - get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - layer.w2_weight.size(), - top_k, - config_dtype, - block_shape=layer.quant_method.moe_quant_config.block_shape, - ) - shrink_config = get_config_func(M) - expand_config = get_config_func(M) + shrink_config, expand_config = self._get_lora_moe_configs( + op_prefix="down", + lora_a_stacked=self.w2_lora_a_stacked, + lora_b_stacked=self.w2_lora_b_stacked, + num_slices=1, + M=M, + layer=layer, + top_k=top_k, + config_dtype=config_dtype, + ) sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"] expert_ids_lora = moe_state_dict["expert_ids_lora"] From 82635d507be29cb80b922730a462d7aea0f583b1 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 28 Oct 2025 21:06:57 +0000 Subject: [PATCH 17/36] fix pre-commit Signed-off-by: Yu Gong --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 222282dbf785..783227473ca1 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -177,7 +177,8 @@ def _fused_moe_lora_kernel( @torch.inference_mode() def _fused_moe_lora_shrink( - a_intermediate_cache1: torch.Tensor, # (num_slices, num_tokens, top_k_num, max_lora_rank) + a_intermediate_cache1: torch.Tensor, + # (num_slices, num_tokens, top_k_num, max_lora_rank) qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) lora_a_stacked: list[ torch.Tensor From 
5553bda56a908f4fe23d8db946f9d1777c0b7a82 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 28 Oct 2025 23:41:32 +0000 Subject: [PATCH 18/36] fix bugs Signed-off-by: Yu Gong --- tests/lora/test_fused_moe_lora_kernel.py | 11 +++++++++++ vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index 318a0e58805d..91ab4a87c65f 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -158,6 +158,8 @@ def use_fused_moe_lora_kernel( "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, + "NUM_WARPS": 4, + "NUM_STAGES": 3, "SPLIT_K": 1, } @@ -182,6 +184,15 @@ def use_fused_moe_lora_kernel( config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"], config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], + config["SPLIT_K"], + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], config["SPLIT_K"], mul_routed_weight, ) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 783227473ca1..21daa85357f1 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -309,7 +309,7 @@ def _fused_moe_lora_expand( b_intermediate_cache1 = torch.zeros( (num_slices, M, top_k_num, w1_output_dim_size), - dtype=torch.bfloat16, + dtype=output.dtype, device=device, ) @@ -426,7 +426,7 @@ def _fused_moe_lora( a_intermediate_cache1 = torch.zeros( (num_slices, M, top_k_num, max_lora_rank), - dtype=torch.bfloat16, + dtype=output.dtype, device=device, ) From 0d8fa61d76556eeb711a7992d8cce1e17a607c10 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Wed, 29 Oct 2025 02:50:50 +0000 Subject: [PATCH 19/36] clean code Signed-off-by: Yu Gong --- benchmarks/kernels/benchmark_lora.py | 193 ++------------------------- 1 file changed, 13 insertions(+), 180 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index e9cf3bb48206..2f3ade14be20 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -386,137 +386,6 @@ def bench_fn(self) -> Callable: raise ValueError(f"Unrecognized optype {self}") - def _run_fused_moe_lora_ref( - self, - output: torch.Tensor, - input: torch.Tensor, - lora_weights: list[torch.Tensor], - is_shrink: bool, - **kwargs, - ) -> None: - """ - Unified reference implementation for fused MoE LoRA operations. 
- - Processes tokens exactly as the kernel does: - - For each LoRA and block: get expert_id from expert_ids tensor - - For each token in block: get token_id from sorted_token_ids - - Perform the gemm with the corresponding expert's weights - """ - top_k_num = kwargs.get("top_k_num", 1) - seq_lens_cpu = kwargs.get("seq_lens_cpu") - prompt_lora_mapping_cpu = kwargs.get("prompt_lora_mapping_cpu") - scaling = kwargs.get("scaling", 1.0) - topk_weights = kwargs.get("topk_weights") - sorted_token_ids = kwargs.get("sorted_token_ids") # (num_loras, padded_size) - expert_ids = kwargs.get("expert_ids") # (num_loras, num_blocks) - mul_routed_weight = kwargs.get("mul_routed_weight", False) - w_dtype = lora_weights[0].dtype - num_slices = len(lora_weights) - - # Get block size from kernel config (needed to map tokens to blocks) - block_size_m = kwargs.get( - "shrink_block_size_m" if is_shrink else "expand_block_size_m", 64 - ) - - # Move to CPU for easier processing - sorted_token_ids_cpu = sorted_token_ids.cpu() - expert_ids_cpu = expert_ids.cpu() - num_loras = lora_weights[0].shape[0] - - # Process each LoRA - for lora_idx in range(num_loras): - # Find which batch uses this LoRA - batch_mask = prompt_lora_mapping_cpu == lora_idx - if not batch_mask.any(): - continue # No sequences use this LoRA - - # Process each slice - for slice_idx in range(num_slices): - weights_slice = lora_weights[ - slice_idx - ] # (num_loras, num_experts, out_dim, in_dim) - - # Process each block for this LoRA - num_blocks = expert_ids_cpu.shape[1] - for block_idx in range(num_blocks): - # Get the expert_id for this block - expert_id = expert_ids_cpu[lora_idx, block_idx].item() - if expert_id == -1: - continue # Empty block - - # Get weight for this expert and LoRA - weight = weights_slice[lora_idx, expert_id, :, :] - - # Process tokens in this block - block_start = block_idx * block_size_m - block_end = min( - block_start + block_size_m, sorted_token_ids_cpu.shape[1] - ) - - for token_pos in range(block_start, block_end): - sorted_token_id = sorted_token_ids_cpu[ - lora_idx, token_pos - ].item() - - # Check if this is a valid token (not padding) - num_tokens = seq_lens_cpu.sum().item() - if self == OpType.FUSED_MOE_LORA_DOWN_SHRINK and is_shrink: - max_valid = num_tokens * top_k_num - else: - max_valid = num_tokens - - if sorted_token_id >= max_valid: - continue # Padding token - - # Decode: original_token_idx and k_idx from sorted_token_id - if is_shrink: - # shrink: sorted_token_id (token_idx * top_k + k_idx) - original_token_idx = sorted_token_id // top_k_num - k_idx = sorted_token_id % top_k_num - else: - # For expand: sorted_token_id is just the token index - # k_idx from the expert routing (encoded in block structure) - original_token_idx = sorted_token_id - # Infer k_idx, in expand, tokens are organized differently - # Each expert processes tokens with that expert's k_idx - k_idx = expert_id % top_k_num # Simplified - - if is_shrink: - # Input index - if self == OpType.FUSED_MOE_LORA_DOWN_SHRINK: - input_idx = sorted_token_id - else: - input_idx = original_token_idx - - x = input[input_idx : input_idx + 1, :] - result = torch.nn.functional.linear(x, weight) - result = result * scaling - output[slice_idx, original_token_idx, k_idx, :] = ( - result.squeeze(0) - ) - - else: - # Expand - hidden_size = weights_slice.shape[2] - x = ( - input[slice_idx, original_token_idx, k_idx, :] - .unsqueeze(0) - .to(dtype=w_dtype) - ) - result = torch.nn.functional.linear(x, weight) - result = result * scaling - - if mul_routed_weight 
and topk_weights is not None: - route_weight = topk_weights[original_token_idx, k_idx] - result = result * route_weight - - slice_offset = slice_idx * hidden_size - output[ - original_token_idx, - k_idx, - slice_offset : slice_offset + hidden_size, - ] += result.squeeze(0) - def run_ref_group_gemm( self, output: torch.Tensor, @@ -549,24 +418,6 @@ def run_ref_group_gemm( lora_weights=lora_weights[slice_idx], **kwargs, ) - elif self.is_fused_moe_lora_shrink_fn(): - # For fused MoE LoRA shrink: input @ lora_a.T -> intermediate - # Input shape: (num_tokens, hidden_size) for gate_up or - # (num_tokens*top_k, hidden_size) for down - # Weight shape: (num_loras, num_experts, lora_rank, hidden_size) - # Output shape: (num_slices, num_tokens, top_k, lora_rank) - self._run_fused_moe_lora_ref( - output, input, lora_weights, is_shrink=True, **kwargs - ) - - elif self.is_fused_moe_lora_expand_fn(): - # For fused MoE LoRA expand: intermediate @ lora_b.T -> output - # Input shape: (num_slices, num_tokens, top_k, lora_rank) - # Weight shape: (num_loras, num_experts, hidden_size, lora_rank) - # Output shape: (num_tokens, top_k, hidden_size * num_slices) - self._run_fused_moe_lora_ref( - output, input, lora_weights, is_shrink=False, **kwargs - ) else: raise ValueError(f"Unrecognized optype {self}") @@ -1013,8 +864,8 @@ def as_fused_moe_lora_shrink_kwargs( "shrink_block_size_n": kernel_config["BLOCK_SIZE_N"], "shrink_block_size_k": kernel_config["BLOCK_SIZE_K"], "shrink_group_size_m": kernel_config["GROUP_SIZE_M"], - "shrink_num_warps": kernel_config["num_warps"], - "shrink_num_stages": kernel_config["num_stages"], + "shrink_num_warps": kernel_config["NUM_WARPS"], + "shrink_num_stages": kernel_config["NUM_STAGES"], "shrink_split_k": kernel_config.get("SPLIT_K", 1), "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), } @@ -1088,8 +939,8 @@ def as_fused_moe_lora_expand_kwargs( "expand_block_size_n": kernel_config["BLOCK_SIZE_N"], "expand_block_size_k": kernel_config["BLOCK_SIZE_K"], "expand_group_size_m": kernel_config["GROUP_SIZE_M"], - "expand_num_warps": kernel_config["num_warps"], - "expand_num_stages": kernel_config["num_stages"], + "expand_num_warps": kernel_config["NUM_WARPS"], + "expand_num_stages": kernel_config["NUM_STAGES"], "expand_split_k": kernel_config.get("SPLIT_K", 1), "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), } @@ -1113,7 +964,7 @@ def bench_fn_kwargs( raise ValueError(f"Unrecognized optype {self}") def test_correctness( - self, ctx: BenchmarkContext, op_type: OpType, expand_fn_add_inputs: bool | None + self, op_type: OpType, expand_fn_add_inputs: bool | None ) -> bool: """ Test correctness of op_type implementation against a grouped gemm @@ -1124,37 +975,16 @@ def test_correctness( ref_output = self.output.clone() self.output.zero_() - kernel_kwargs = self.bench_fn_kwargs(ctx, op_type, expand_fn_add_inputs) - op_type.bench_fn()(**kernel_kwargs) - - # Build reference kwargs - ref_kwargs = { - "seq_lens_cpu": seq_lens_cpu, - "prompt_lora_mapping_cpu": prompt_lora_mapping_cpu, - "scaling": 1.0, - "add_inputs": expand_fn_add_inputs, - } - - # Add fused_moe_lora specific kwargs if needed - if op_type.is_fused_moe_lora_fn(): - ref_kwargs.update( - { - "topk_weights": kernel_kwargs.get("topk_weights"), - "sorted_token_ids": kernel_kwargs.get("sorted_token_ids"), - "expert_ids": kernel_kwargs.get("expert_ids"), - "top_k_num": ctx.top_k_num, - "num_experts": ctx.num_experts, - "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), - "shrink_block_size_m": 
kernel_kwargs.get("shrink_block_size_m"), - "expand_block_size_m": kernel_kwargs.get("expand_block_size_m"), - } - ) + op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) op_type.run_ref_group_gemm( ref_output, self.input, self.lora_weights_lst, - **ref_kwargs, + seq_lens_cpu=seq_lens_cpu, + prompt_lora_mapping_cpu=prompt_lora_mapping_cpu, + scaling=1.0, + add_inputs=expand_fn_add_inputs, ) rtol, atol = { @@ -1189,6 +1019,9 @@ def bench_optype( # Test correctness of our implementation. if test_correctness: + assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], ( + f"Correctness testing is not supported for {op_type.name}." + ) assert all( [ bt.test_correctness(ctx, op_type, expand_fn_add_inputs) From 3b1f04a72cd7736b089749efd31dfe39df87a802 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Fri, 31 Oct 2025 20:12:29 +0000 Subject: [PATCH 20/36] clean code Signed-off-by: Yu Gong --- benchmarks/kernels/benchmark_lora.py | 33 ++++++++++------------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 2f3ade14be20..56df1f57b1c0 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -278,7 +278,7 @@ def mkn( self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int ) -> tuple[int, int, int]: num_tokens = batch_size * seq_length - if self.is_shrink_fn(): + if self.is_shrink_fn() or self.is_fused_moe_lora_fn(): m = num_tokens k = hidden_size n = lora_rank @@ -286,11 +286,6 @@ def mkn( m = num_tokens k = lora_rank n = hidden_size - else: - assert self.is_fused_moe_lora_fn() - m = num_tokens - n = hidden_size - k = lora_rank return m, k, n def matmul_dtypes( @@ -493,6 +488,11 @@ def io_types(self) -> str: f"{dtype_to_str(self.output.dtype)}" ) + def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType): + return ( + size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else size + ) + @staticmethod def make( ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" @@ -558,7 +558,6 @@ def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None: """ Fails asserts when non-conformality is detected. 
""" - ##TODO test if this works num_tokens = ( self.input.shape[1] if op_type.is_fused_moe_lora_expand_fn() @@ -566,20 +565,15 @@ def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None: ) # check metadata tensors ## In down shrink case, each token is repeated top_k_num times - assert ( - torch.sum(self.seq_lens) * ctx.top_k_num == num_tokens - if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] - else torch.sum(self.seq_lens) == num_tokens + assert num_tokens == self.get_num_tokens( + torch.sum(self.seq_lens), ctx.top_k_num, op_type ), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}" num_seqs = self.seq_lens.shape[0] # assert self.seq_start_loc.shape[0] == num_seqs ## In down shrink case, each prompt corresponds to top_k_num sequences assert self.prompt_lora_mapping.shape[0] == num_seqs - assert ( - self.lora_kernel_meta.token_lora_mapping.shape[0] * ctx.top_k_num - == num_tokens - if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] - else self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens + assert self.get_num_tokens( + self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type ) def to_device(self, device: str): @@ -614,11 +608,8 @@ def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, in Return num_seqs, num_tokens and max_seq_len """ num_seqs = self.seq_lens.shape[0] - ## TODO: test if this works - num_tokens = ( - self.lora_kernel_meta.token_lora_mapping.shape[0] * ctx.top_k_num - if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] - else self.lora_kernel_meta.token_lora_mapping.shape[0] + num_tokens = self.get_num_tokens( + self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type ) max_seq_len = torch.max(self.seq_lens).item() num_slices = len(self.lora_weights_lst) From 3ad93dd8c22c619fbd224996f7a582b20125219a Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Sat, 1 Nov 2025 18:31:17 +0000 Subject: [PATCH 21/36] clean code Signed-off-by: Yu Gong --- benchmarks/kernels/benchmark_lora.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 56df1f57b1c0..e91db8fba840 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -368,15 +368,15 @@ def bench_fn(self) -> Callable: return lora_shrink if self == OpType.LORA_EXPAND: return lora_expand - if ( - self == OpType.FUSED_MOE_LORA_GATE_UP_SHRINK - or self == OpType.FUSED_MOE_LORA_DOWN_SHRINK - ): + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ]: return fused_moe_lora_shrink - if ( - self == OpType.FUSED_MOE_LORA_GATE_UP_EXPAND - or self == OpType.FUSED_MOE_LORA_DOWN_EXPAND - ): + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ]: return fused_moe_lora_expand raise ValueError(f"Unrecognized optype {self}") @@ -1245,13 +1245,12 @@ def as_benchmark_contexts( num_active_loras=args.num_active_loras if args.num_active_loras else num_loras, - # To be filled based on the OpType to benchmark seq_length=None, sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, - # To be filled based on the OpType to benchmark top_k_num=top_k_num, num_experts=num_experts, + # To be filled based on the OpType to benchmark num_slices=None, ) ) From 3f6357f0fc3fa688f2a625d6fe7c4745edb5e6b3 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Sat, 1 Nov 2025 18:33:40 +0000 Subject: [PATCH 22/36] clean code Signed-off-by: Yu Gong --- 
benchmarks/kernels/benchmark_lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index e91db8fba840..6715c9b548aa 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -1245,6 +1245,7 @@ def as_benchmark_contexts( num_active_loras=args.num_active_loras if args.num_active_loras else num_loras, + # To be filled based on the OpType to benchmark seq_length=None, sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, From 3acf93b81394e5d4614ea8e8e42971b2f613db2b Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Sat, 1 Nov 2025 19:11:04 +0000 Subject: [PATCH 23/36] restore pre-commit-config.yaml Signed-off-by: Yu Gong --- .pre-commit-config.yaml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fbfd8016cb76..93e524475617 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,7 @@ repos: - id: format-torch-nightly-test name: reformat nightly_torch_test.txt to be in sync with test.in language: python - entry: python tools/generate_nightly_torch_test.py + entry: python tools/pre_commit/generate_nightly_torch_test.py files: ^requirements/test\.(in|txt)$ - id: mypy-local name: Run mypy locally for lowest supported Python version @@ -78,12 +78,12 @@ repos: stages: [manual] # Only run in CI - id: shellcheck name: Lint shell scripts - entry: tools/shellcheck.sh + entry: tools/pre_commit/shellcheck.sh language: script types: [shell] - id: png-lint name: Lint PNG exports from excalidraw - entry: tools/png-lint.sh + entry: tools/pre_commit/png-lint.sh language: script types: [png] - id: signoff-commit @@ -100,12 +100,12 @@ repos: stages: [commit-msg] - id: check-spdx-header name: Check SPDX headers - entry: python tools/check_spdx_header.py + entry: python tools/pre_commit/check_spdx_header.py language: python types: [python] - id: check-root-lazy-imports name: Check root lazy imports - entry: python tools/check_init_lazy_imports.py + entry: python tools/pre_commit/check_init_lazy_imports.py language: python types: [python] - id: check-filenames @@ -119,11 +119,11 @@ repos: pass_filenames: false - id: update-dockerfile-graph name: Update Dockerfile dependency graph - entry: tools/update-dockerfile-graph.sh + entry: tools/pre_commit/update-dockerfile-graph.sh language: script - id: enforce-import-regex-instead-of-re name: Enforce import regex as re - entry: python tools/enforce_regex_import.py + entry: python tools/pre_commit/enforce_regex_import.py language: python types: [python] pass_filenames: false @@ -131,7 +131,7 @@ repos: # forbid directly import triton - id: forbid-direct-triton-import name: "Forbid direct 'import triton'" - entry: python tools/check_triton_import.py + entry: python tools/pre_commit/check_triton_import.py language: python types: [python] pass_filenames: false @@ -144,7 +144,7 @@ repos: additional_dependencies: [regex] - id: validate-config name: Validate configuration has default values and that each field has a docstring - entry: python tools/validate_config.py + entry: python tools/pre_commit/validate_config.py language: python additional_dependencies: [regex] # Keep `suggestion` last @@ -154,4 +154,4 @@ repos: language: system verbose: true pass_filenames: false - # Insert new entries above the `suggestion` entry + # Insert new entries above the `suggestion` entry \ No newline at end of file From ff518b38caefe851aa1bc7b0eb024f1d986f53db Mon Sep 17 
00:00:00 2001 From: Yu Gong Date: Sat, 1 Nov 2025 19:14:05 +0000 Subject: [PATCH 24/36] restore .pre-commit-config.yaml Signed-off-by: Yu Gong --- pre-commit-config.yaml | 157 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 pre-commit-config.yaml diff --git a/pre-commit-config.yaml b/pre-commit-config.yaml new file mode 100644 index 000000000000..bcd40e7f8ab3 --- /dev/null +++ b/pre-commit-config.yaml @@ -0,0 +1,157 @@ +default_install_hook_types: + - pre-commit + - commit-msg +default_stages: + - pre-commit # Run locally + - manual # Run in CI +exclude: 'vllm/third_party/.*' +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.0 + hooks: + - id: ruff-check + args: [--output-format, github, --fix] + - id: ruff-format +- repo: https://github.com/crate-ci/typos + rev: v1.38.1 + hooks: + - id: typos + args: [--force-exclude] +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v21.1.2 + hooks: + - id: clang-format + exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' + types_or: [c++, cuda] + args: [--style=file, --verbose] +- repo: https://github.com/igorshubovych/markdownlint-cli + rev: v0.45.0 + hooks: + - id: markdownlint + exclude: '.*\.inc\.md' + stages: [manual] # Only run in CI +- repo: https://github.com/rhysd/actionlint + rev: v1.7.7 + hooks: + - id: actionlint +- repo: https://github.com/astral-sh/uv-pre-commit + rev: 0.9.1 + hooks: + - id: pip-compile + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28] + files: ^requirements/test\.(in|txt)$ +- repo: local + hooks: + - id: format-torch-nightly-test + name: reformat nightly_torch_test.txt to be in sync with test.in + language: python + entry: python tools/pre_commit/generate_nightly_torch_test.py + files: ^requirements/test\.(in|txt)$ + - id: mypy-local + name: Run mypy locally for lowest supported Python version + entry: python tools/pre_commit/mypy.py 0 "3.10" + stages: [pre-commit] # Don't run in CI + <<: &mypy_common + language: python + types_or: [python, pyi] + require_serial: true + additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] + - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.10 + entry: python tools/pre_commit/mypy.py 1 "3.10" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.11 + entry: python tools/pre_commit/mypy.py 1 "3.11" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.12 + entry: python tools/pre_commit/mypy.py 1 "3.12" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward + name: Run mypy for Python 3.13 + entry: python tools/pre_commit/mypy.py 1 "3.13" + <<: *mypy_common + stages: [manual] # Only run in CI + - id: 
shellcheck + name: Lint shell scripts + entry: tools/pre_commit/shellcheck.sh + language: script + types: [shell] + - id: png-lint + name: Lint PNG exports from excalidraw + entry: tools/pre_commit/png-lint.sh + language: script + types: [png] + - id: signoff-commit + name: Sign-off Commit + entry: bash + args: + - -c + - | + if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then + printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)" + fi + language: system + verbose: true + stages: [commit-msg] + - id: check-spdx-header + name: Check SPDX headers + entry: python tools/pre_commit/check_spdx_header.py + language: python + types: [python] + - id: check-root-lazy-imports + name: Check root lazy imports + entry: python tools/pre_commit/check_init_lazy_imports.py + language: python + types: [python] + - id: check-filenames + name: Check for spaces in all filenames + entry: bash + args: + - -c + - 'git ls-files | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0' + language: system + always_run: true + pass_filenames: false + - id: update-dockerfile-graph + name: Update Dockerfile dependency graph + entry: tools/pre_commit/update-dockerfile-graph.sh + language: script + - id: enforce-import-regex-instead-of-re + name: Enforce import regex as re + entry: python tools/pre_commit/enforce_regex_import.py + language: python + types: [python] + pass_filenames: false + additional_dependencies: [regex] + # forbid directly import triton + - id: forbid-direct-triton-import + name: "Forbid direct 'import triton'" + entry: python tools/pre_commit/check_triton_import.py + language: python + types: [python] + pass_filenames: false + additional_dependencies: [regex] + - id: check-pickle-imports + name: Prevent new pickle/cloudpickle imports + entry: python tools/pre_commit/check_pickle_imports.py + language: python + types: [python] + additional_dependencies: [regex] + - id: validate-config + name: Validate configuration has default values and that each field has a docstring + entry: python tools/pre_commit/validate_config.py + language: python + additional_dependencies: [regex] + # Keep `suggestion` last + - id: suggestion + name: Suggestion + entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. 
To skip a specific hook, prefix the commit command with SKIP=."' + language: system + verbose: true + pass_filenames: false + # Insert new entries above the `suggestion` entry From dfb9dd175087d5605f2fff2aacfcf5990742c068 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Sat, 1 Nov 2025 19:15:31 +0000 Subject: [PATCH 25/36] clean code Signed-off-by: Yu Gong --- .pre-commit-config.yaml | 3 +- pre-commit-config.yaml | 157 ---------------------------------------- 2 files changed, 2 insertions(+), 158 deletions(-) delete mode 100644 pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 93e524475617..7c06411bd003 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -154,4 +154,5 @@ repos: language: system verbose: true pass_filenames: false - # Insert new entries above the `suggestion` entry \ No newline at end of file + # Insert new entries above the `suggestion` entry + \ No newline at end of file diff --git a/pre-commit-config.yaml b/pre-commit-config.yaml deleted file mode 100644 index bcd40e7f8ab3..000000000000 --- a/pre-commit-config.yaml +++ /dev/null @@ -1,157 +0,0 @@ -default_install_hook_types: - - pre-commit - - commit-msg -default_stages: - - pre-commit # Run locally - - manual # Run in CI -exclude: 'vllm/third_party/.*' -repos: -- repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.0 - hooks: - - id: ruff-check - args: [--output-format, github, --fix] - - id: ruff-format -- repo: https://github.com/crate-ci/typos - rev: v1.38.1 - hooks: - - id: typos - args: [--force-exclude] -- repo: https://github.com/pre-commit/mirrors-clang-format - rev: v21.1.2 - hooks: - - id: clang-format - exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' - types_or: [c++, cuda] - args: [--style=file, --verbose] -- repo: https://github.com/igorshubovych/markdownlint-cli - rev: v0.45.0 - hooks: - - id: markdownlint - exclude: '.*\.inc\.md' - stages: [manual] # Only run in CI -- repo: https://github.com/rhysd/actionlint - rev: v1.7.7 - hooks: - - id: actionlint -- repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.9.1 - hooks: - - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu129, --python-platform, x86_64-manylinux_2_28] - files: ^requirements/test\.(in|txt)$ -- repo: local - hooks: - - id: format-torch-nightly-test - name: reformat nightly_torch_test.txt to be in sync with test.in - language: python - entry: python tools/pre_commit/generate_nightly_torch_test.py - files: ^requirements/test\.(in|txt)$ - - id: mypy-local - name: Run mypy locally for lowest supported Python version - entry: python tools/pre_commit/mypy.py 0 "3.10" - stages: [pre-commit] # Don't run in CI - <<: &mypy_common - language: python - types_or: [python, pyi] - require_serial: true - additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.10 - entry: python tools/pre_commit/mypy.py 1 "3.10" - <<: *mypy_common - stages: [manual] # Only run in CI - - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - 
name: Run mypy for Python 3.11 - entry: python tools/pre_commit/mypy.py 1 "3.11" - <<: *mypy_common - stages: [manual] # Only run in CI - - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.12 - entry: python tools/pre_commit/mypy.py 1 "3.12" - <<: *mypy_common - stages: [manual] # Only run in CI - - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward - name: Run mypy for Python 3.13 - entry: python tools/pre_commit/mypy.py 1 "3.13" - <<: *mypy_common - stages: [manual] # Only run in CI - - id: shellcheck - name: Lint shell scripts - entry: tools/pre_commit/shellcheck.sh - language: script - types: [shell] - - id: png-lint - name: Lint PNG exports from excalidraw - entry: tools/pre_commit/png-lint.sh - language: script - types: [png] - - id: signoff-commit - name: Sign-off Commit - entry: bash - args: - - -c - - | - if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then - printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)" - fi - language: system - verbose: true - stages: [commit-msg] - - id: check-spdx-header - name: Check SPDX headers - entry: python tools/pre_commit/check_spdx_header.py - language: python - types: [python] - - id: check-root-lazy-imports - name: Check root lazy imports - entry: python tools/pre_commit/check_init_lazy_imports.py - language: python - types: [python] - - id: check-filenames - name: Check for spaces in all filenames - entry: bash - args: - - -c - - 'git ls-files | grep " " && echo "Filenames should not contain spaces!" && exit 1 || exit 0' - language: system - always_run: true - pass_filenames: false - - id: update-dockerfile-graph - name: Update Dockerfile dependency graph - entry: tools/pre_commit/update-dockerfile-graph.sh - language: script - - id: enforce-import-regex-instead-of-re - name: Enforce import regex as re - entry: python tools/pre_commit/enforce_regex_import.py - language: python - types: [python] - pass_filenames: false - additional_dependencies: [regex] - # forbid directly import triton - - id: forbid-direct-triton-import - name: "Forbid direct 'import triton'" - entry: python tools/pre_commit/check_triton_import.py - language: python - types: [python] - pass_filenames: false - additional_dependencies: [regex] - - id: check-pickle-imports - name: Prevent new pickle/cloudpickle imports - entry: python tools/pre_commit/check_pickle_imports.py - language: python - types: [python] - additional_dependencies: [regex] - - id: validate-config - name: Validate configuration has default values and that each field has a docstring - entry: python tools/pre_commit/validate_config.py - language: python - additional_dependencies: [regex] - # Keep `suggestion` last - - id: suggestion - name: Suggestion - entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. 
To skip a specific hook, prefix the commit command with SKIP=."' - language: system - verbose: true - pass_filenames: false - # Insert new entries above the `suggestion` entry From d950b3dc79735e1ad1115afdea47819442783c84 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Sat, 1 Nov 2025 19:17:35 +0000 Subject: [PATCH 26/36] clean code Signed-off-by: Yu Gong --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7c06411bd003..bcd40e7f8ab3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -155,4 +155,3 @@ repos: verbose: true pass_filenames: false # Insert new entries above the `suggestion` entry - \ No newline at end of file From 65c11e90742fd9458a34f972a46154f54fae2edc Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Sun, 2 Nov 2025 18:22:57 +0000 Subject: [PATCH 27/36] clean code Signed-off-by: Yu Gong --- vllm/lora/ops/triton_ops/README_TUNING.md | 3 +- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 86 +++++++++---------- vllm/lora/ops/triton_ops/utils.py | 10 +-- 3 files changed, 50 insertions(+), 49 deletions(-) diff --git a/vllm/lora/ops/triton_ops/README_TUNING.md b/vllm/lora/ops/triton_ops/README_TUNING.md index 56c8c27051cc..c04e0d89bedf 100644 --- a/vllm/lora/ops/triton_ops/README_TUNING.md +++ b/vllm/lora/ops/triton_ops/README_TUNING.md @@ -56,4 +56,5 @@ The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_n ### Json Structure -Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][n2]` +Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]` +where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer. 
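[Editor's illustration — not part of the patch] The README change above documents the nested JSON layout `config_data[max_loras][num_slices][m][k][n][i]`. As a minimal sketch only, the Python snippet below shows how such a file could be laid out and queried with the exact-match-or-nearest-key fallback that `get_lora_op_configs` uses; every numeric key and config value here is a hypothetical placeholder, not a tuned result.

```python
import json

# Hypothetical nested layout: config_data[max_loras][num_slices][m][k][n][i]
# (the trailing `i` level only exists for fused_moe_lora configs and holds the
# MoE intermediate size).
config_data = {
    "1": {                          # max_loras
        "2": {                      # num_slices
            "16": {                 # m
                "2048": {           # k
                    "16": {         # n
                        "1408": {   # i: MoE intermediate size
                            "block_m": 64,
                            "block_n": 64,
                            "block_k": 32,
                            "group_size_m": 8,
                            "num_warps": 4,
                            "num_stages": 3,
                            "split_k": 1,
                        }
                    }
                }
            }
        }
    }
}

def nearest(level: dict, key: int) -> dict:
    """Exact key if present, otherwise the numerically closest key."""
    return level.get(str(key)) or level[
        min(level.keys(), key=lambda x: abs(int(x) - key))
    ]

# Walk the levels in order: max_loras, num_slices, m, k, n, i.
cfg = config_data
for lookup in (1, 2, 16, 2048, 16, 1408):
    cfg = nearest(cfg, lookup)

print(json.dumps(cfg, indent=2))
```

This mirrors the fallback pattern in `vllm/lora/ops/triton_ops/utils.py` (`config_data.get(str(i)) or config_data[min(...)]`); only the file contents above are invented for illustration.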
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 21daa85357f1..41b5200a67d4 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -199,31 +199,31 @@ def _fused_moe_lora_shrink( num_tokens: int, num_experts: int, num_slices: int, - shrink_block_size_m: int, - shrink_block_size_n: int, - shrink_block_size_k: int, - shrink_group_size_m: int, - shrink_num_warps: int, - shrink_num_stages: int, - shrink_split_k: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, mul_routed_weight: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] shrink_config = { - "BLOCK_SIZE_M": shrink_block_size_m, - "BLOCK_SIZE_N": shrink_block_size_n, - "BLOCK_SIZE_K": shrink_block_size_k, - "GROUP_SIZE_M": shrink_group_size_m, - "num_warps": shrink_num_warps, - "num_stages": shrink_num_stages, - "SPLIT_K": shrink_split_k, + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "SPLIT_K": split_k, } b_ptr = _get_ptr(lora_a_stacked, device) grid = lambda META: ( - shrink_split_k + split_k * triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_a_stacked), @@ -288,13 +288,13 @@ def _fused_moe_lora_expand( num_slices: int, max_lora_rank: int, w1_output_dim_size: int, - expand_block_size_m: int, - expand_block_size_n: int, - expand_block_size_k: int, - expand_group_size_m: int, - expand_num_warps: int, - expand_num_stages: int, - expand_split_k: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, mul_routed_weight: bool = False, ) -> None: b_ptr = _get_ptr(lora_b_stacked, device) @@ -314,13 +314,13 @@ def _fused_moe_lora_expand( ) expand_config = { - "BLOCK_SIZE_M": expand_block_size_m, - "BLOCK_SIZE_N": expand_block_size_n, - "BLOCK_SIZE_K": expand_block_size_k, - "GROUP_SIZE_M": expand_group_size_m, - "num_warps": expand_num_warps, - "num_stages": expand_num_stages, - "SPLIT_K": expand_split_k, # Set split_k = 1 for expand calls + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "SPLIT_K": split_k, # Set split_k = 1 for expand calls } grid = lambda META: ( @@ -538,13 +538,13 @@ def _fused_moe_lora_shrink_fake( num_tokens: int, num_experts: int, num_slices: int, - shrink_block_size_m: int, - shrink_block_size_n: int, - shrink_block_size_k: int, - shrink_group_size_m: int, - shrink_num_warps: int, - shrink_num_stages: int, - shrink_split_k: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, mul_routed_weight: bool = False, ) -> None: return @@ -569,13 +569,13 @@ def _fused_moe_lora_expand_fake( num_slices: int, max_lora_rank: int, w1_output_dim_size: int, - expand_block_size_m: int, - expand_block_size_n: int, - expand_block_size_k: int, - expand_group_size_m: int, - expand_num_warps: int, - expand_num_stages: int, - expand_split_k: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, mul_routed_weight: bool = False, ) -> None: return diff --git 
a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index a7c45d79d81e..258848d35bbf 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -186,7 +186,7 @@ def get_lora_op_configs( rank: int, num_slices: int, add_inputs: bool | None = None, - hidden_size_2: int | None = None, + moe_intermediate_size: int | None = None, ) -> dict[str, int | None]: # Add support for fused_moe_lora ops assert op_type in [ @@ -272,11 +272,11 @@ def get_lora_op_configs( ) # slice by hidden_size_2 - if hidden_size_2 is not None: - n2 = hidden_size_2 + if moe_intermediate_size is not None: + i = moe_intermediate_size config_data = ( - config_data.get(str(n2)) - or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n2))] + config_data.get(str(i)) + or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - i))] ) assert config_data is not None From f451ca784fc537701c9b2397f7cec36fd231ab46 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Mon, 3 Nov 2025 21:36:20 +0000 Subject: [PATCH 28/36] rename the config Signed-off-by: Yu Gong --- vllm/lora/layers/fused_moe.py | 4 +++- vllm/lora/ops/triton_ops/README_TUNING.md | 8 ++++---- vllm/lora/ops/triton_ops/utils.py | 24 +++++++++++------------ 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index bcc7c76d7927..af27eef6e25d 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -148,7 +148,9 @@ def wrapper(*args, **kwargs): # get the block size of m from customized config or default config max_loras = self.w1_lora_a_stacked.shape[0] - block_size = shrink_config.get("BLOCK_SIZE_M", 64) + block_size = shrink_config.get("BLOCK_SIZE_M") or shrink_config.get( + "block_m", 64 + ) ( sorted_token_ids_lora, expert_ids_lora, diff --git a/vllm/lora/ops/triton_ops/README_TUNING.md b/vllm/lora/ops/triton_ops/README_TUNING.md index c04e0d89bedf..d576e261557a 100644 --- a/vllm/lora/ops/triton_ops/README_TUNING.md +++ b/vllm/lora/ops/triton_ops/README_TUNING.md @@ -44,13 +44,13 @@ For `shrink`, the config file is named as `{gpu_name}_SHRINK.json`, e.g. `NVIDIA For `expand`, the config fileis named as `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`. -For `fused_moe_lora_gate_up_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_GATE_UP_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_GATE_UP_SHRINK.json`. +For `fused_moe_lora_w13_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json`. -For `fused_moe_lora_gate_up_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_GATE_UP_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_GATE_UP_EXPAND.json`. +For `fused_moe_lora_w13_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json`. -For `fused_moe_lora_down_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_DOWN_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_DOWN_SHRINK.json`. +For `fused_moe_lora_w2_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json`. -For `fused_moe_lora_down_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_DOWN_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_DOWN_EXPAND.json`. +For `fused_moe_lora_w2_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json`, e.g. 
`NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json`. The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()` diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 258848d35bbf..e114fa736784 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -192,10 +192,10 @@ def get_lora_op_configs( assert op_type in [ "shrink", "expand", - "fused_moe_lora_gate_up_shrink", - "fused_moe_lora_gate_up_expand", - "fused_moe_lora_down_shrink", - "fused_moe_lora_down_expand", + "fused_moe_lora_w13_shrink", + "fused_moe_lora_w13_expand", + "fused_moe_lora_w2_shrink", + "fused_moe_lora_w2_expand", ] # default config @@ -219,13 +219,13 @@ def get_lora_op_configs( "fused_moe_lora_down_expand", ]: default = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "NUM_WARPS": 4, - "NUM_STAGES": 3, - "GROUP_SIZE_M": 8, - "SPLIT_K": 1, + "block_m": 64, + "block_n": 64, + "block_k": 32, + "num_warps": 4, + "num_stages": 3, + "group_size_m": 8, + "split_k": 1, } else: default = { @@ -271,7 +271,7 @@ def get_lora_op_configs( or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))] ) - # slice by hidden_size_2 + # slice by moe-intermediate-size if applicable if moe_intermediate_size is not None: i = moe_intermediate_size config_data = ( From 22faf7e8630c66fd45ce3e4c6e0fd2a77f7bcd60 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Mon, 3 Nov 2025 21:40:01 +0000 Subject: [PATCH 29/36] clean code Signed-off-by: Yu Gong --- vllm/lora/layers/fused_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index af27eef6e25d..7a6bcd391273 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -136,7 +136,7 @@ def wrapper(*args, **kwargs): M = min(num_tokens, CHUNK_SIZE) shrink_config, expand_config = self._get_lora_moe_configs( - op_prefix="gate_up", + op_prefix="w13", lora_a_stacked=self.w1_lora_a_stacked, lora_b_stacked=self.w1_lora_b_stacked, num_slices=2, @@ -216,7 +216,7 @@ def wrapper(*args, **kwargs): M = min(num_tokens, CHUNK_SIZE) shrink_config, expand_config = self._get_lora_moe_configs( - op_prefix="down", + op_prefix="w2", lora_a_stacked=self.w2_lora_a_stacked, lora_b_stacked=self.w2_lora_b_stacked, num_slices=1, From 0b439f74a38fdce2998fd20967064ddd9620f5c7 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Mon, 3 Nov 2025 21:42:18 +0000 Subject: [PATCH 30/36] fix format issue Signed-off-by: Yu Gong --- vllm/lora/punica_wrapper/punica_gpu.py | 28 +++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 295a10c7083d..0907f81238c7 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -389,19 +389,19 @@ def add_lora_fused_moe( top_k_num, lora_ids, adapter_enabled, - shrink_config.get("BLOCK_SIZE_M", 64), - shrink_config.get("BLOCK_SIZE_N", 64), - shrink_config.get("BLOCK_SIZE_K", 32), - shrink_config.get("GROUP_SIZE_M", 8), - shrink_config.get("NUM_WARPS", 4), - shrink_config.get("NUM_STAGES", 3), - shrink_config.get("SPLIT_K", 1), - expand_config.get("BLOCK_SIZE_M", 64), - expand_config.get("BLOCK_SIZE_N", 64), - expand_config.get("BLOCK_SIZE_K", 64), - expand_config.get("GROUP_SIZE_M", 64), - expand_config.get("NUM_WARPS", 4), - expand_config.get("NUM_STAGES", 3), - expand_config.get("SPLIT_K", 1), + shrink_config.get("BLOCK_SIZE_M") or 
shrink_config.get("block_m") or 64, + shrink_config.get("BLOCK_SIZE_N") or shrink_config.get("block_n") or 64, + shrink_config.get("BLOCK_SIZE_K") or shrink_config.get("block_k") or 32, + shrink_config.get("GROUP_SIZE_M") or shrink_config.get("group_m") or 8, + shrink_config.get("NUM_WARPS") or shrink_config.get("num_warps") or 4, + shrink_config.get("NUM_STAGES") or shrink_config.get("num_stages") or 3, + shrink_config.get("SPLIT_K") or shrink_config.get("split_k") or 1, + expand_config.get("BLOCK_SIZE_M") or expand_config.get("block_m") or 64, + expand_config.get("BLOCK_SIZE_N") or expand_config.get("block_n") or 64, + expand_config.get("BLOCK_SIZE_K") or expand_config.get("block_k") or 64, + expand_config.get("GROUP_SIZE_M") or expand_config.get("group_m") or 64, + expand_config.get("NUM_WARPS") or expand_config.get("num_warps") or 4, + expand_config.get("NUM_STAGES") or expand_config.get("num_stages") or 3, + expand_config.get("SPLIT_K") or expand_config.get("split_k") or 1, mul_routed_weight, ) From a1ec116d70d6841cb7fca6e4e119dd448fdf5235 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Mon, 3 Nov 2025 21:58:38 +0000 Subject: [PATCH 31/36] Rabase PR Signed-off-by: Yu Gong --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 41b5200a67d4..8f85f926aa4f 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -277,6 +277,8 @@ def _fused_moe_lora_expand( expert_ids: torch.Tensor, # (max_loras, _ ,) num_tokens_post_padded: torch.Tensor, # (max_loras, ) top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, ## adding for kernel device: torch.device, N: int, @@ -381,6 +383,8 @@ def _fused_moe_lora( num_tokens_post_padded: torch.Tensor, # (max_loras, ) max_lora_rank: int, top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, shrink_block_size_k: int, @@ -439,6 +443,8 @@ def _fused_moe_lora( expert_ids, num_tokens_post_padded, top_k_num, + lora_ids, + adapter_enabled, ## adding for kernel device, N, @@ -467,6 +473,8 @@ def _fused_moe_lora( expert_ids, num_tokens_post_padded, top_k_num, + lora_ids, + adapter_enabled, ## adding for kernel device, N, @@ -530,6 +538,8 @@ def _fused_moe_lora_shrink_fake( expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, device: torch.device, N: int, M: int, @@ -559,6 +569,8 @@ def _fused_moe_lora_expand_fake( expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, device: torch.device, N: int, M: int, From d73f410fce9e6815f928f84e66d68aa841d1b6bc Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Mon, 3 Nov 2025 22:04:21 +0000 Subject: [PATCH 32/36] Renaming kernel Signed-off-by: Yu Gong --- vllm/lora/layers/fused_moe.py | 4 ++-- vllm/lora/ops/triton_ops/utils.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 7a6bcd391273..3983af172b5b 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -59,7 +59,7 @@ def _get_lora_moe_configs( hidden_size=lora_a_stacked.shape[-1], rank=lora_a_stacked.shape[-2], num_slices=num_slices, - hidden_size_2=lora_b_stacked.shape[-2], + 
moe_intermediate_size=lora_b_stacked.shape[-2], ) expand_config = get_lora_op_configs( op_type=f"fused_moe_lora_{op_prefix}_expand", @@ -68,7 +68,7 @@ def _get_lora_moe_configs( hidden_size=lora_a_stacked.shape[-1], rank=lora_a_stacked.shape[-2], num_slices=num_slices, - hidden_size_2=lora_b_stacked.shape[-2], + moe_intermediate_size=lora_b_stacked.shape[-2], ) else: # fall back to the default config get_config_func = functools.partial( diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index e114fa736784..120c508b3f60 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -213,10 +213,10 @@ def get_lora_op_configs( } # The default config for fused_moe_lora ops elif op_type in [ - "fused_moe_lora_gate_up_shrink", - "fused_moe_lora_gate_up_expand", - "fused_moe_lora_down_shrink", - "fused_moe_lora_down_expand", + "fused_moe_lora_w13_shrink", + "fused_moe_lora_w13_expand", + "fused_moe_lora_w2_shrink", + "fused_moe_lora_w2_expand", ]: default = { "block_m": 64, From 51f00b214cbd3abd219d016781d8cc6bf6cea24c Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Mon, 3 Nov 2025 22:05:00 +0000 Subject: [PATCH 33/36] renaming Signed-off-by: Yu Gong --- vllm/lora/punica_wrapper/punica_gpu.py | 28 +++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 0907f81238c7..587e0b24a912 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -389,19 +389,19 @@ def add_lora_fused_moe( top_k_num, lora_ids, adapter_enabled, - shrink_config.get("BLOCK_SIZE_M") or shrink_config.get("block_m") or 64, - shrink_config.get("BLOCK_SIZE_N") or shrink_config.get("block_n") or 64, - shrink_config.get("BLOCK_SIZE_K") or shrink_config.get("block_k") or 32, - shrink_config.get("GROUP_SIZE_M") or shrink_config.get("group_m") or 8, - shrink_config.get("NUM_WARPS") or shrink_config.get("num_warps") or 4, - shrink_config.get("NUM_STAGES") or shrink_config.get("num_stages") or 3, - shrink_config.get("SPLIT_K") or shrink_config.get("split_k") or 1, - expand_config.get("BLOCK_SIZE_M") or expand_config.get("block_m") or 64, - expand_config.get("BLOCK_SIZE_N") or expand_config.get("block_n") or 64, - expand_config.get("BLOCK_SIZE_K") or expand_config.get("block_k") or 64, - expand_config.get("GROUP_SIZE_M") or expand_config.get("group_m") or 64, - expand_config.get("NUM_WARPS") or expand_config.get("num_warps") or 4, - expand_config.get("NUM_STAGES") or expand_config.get("num_stages") or 3, - expand_config.get("SPLIT_K") or expand_config.get("split_k") or 1, + shrink_config.get("BLOCK_SIZE_M") or shrink_config.get("block_m", 64), + shrink_config.get("BLOCK_SIZE_N") or shrink_config.get("block_n", 64), + shrink_config.get("BLOCK_SIZE_K") or shrink_config.get("block_k", 32), + shrink_config.get("GROUP_SIZE_M") or shrink_config.get("group_m", 8), + shrink_config.get("NUM_WARPS") or shrink_config.get("num_warps", 4), + shrink_config.get("NUM_STAGES") or shrink_config.get("num_stages", 3), + shrink_config.get("SPLIT_K") or shrink_config.get("split_k", 1), + expand_config.get("BLOCK_SIZE_M") or expand_config.get("block_m", 64), + expand_config.get("BLOCK_SIZE_N") or expand_config.get("block_n", 64), + expand_config.get("BLOCK_SIZE_K") or expand_config.get("block_k", 32), + expand_config.get("GROUP_SIZE_M") or expand_config.get("group_m", 8), + expand_config.get("NUM_WARPS") or expand_config.get("num_warps", 4), + 
expand_config.get("NUM_STAGES") or expand_config.get("num_stages", 3), + expand_config.get("SPLIT_K") or expand_config.get("split_k", 1), mul_routed_weight, ) From 50afb568822100fe78b7da417451003f513af52f Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 4 Nov 2025 04:14:19 +0000 Subject: [PATCH 34/36] Normalize key name as uppercase Signed-off-by: Yu Gong --- vllm/lora/layers/fused_moe.py | 19 ++++++++++++----- vllm/lora/punica_wrapper/punica_gpu.py | 28 +++++++++++++------------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 3983af172b5b..467e7dcc9c1a 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -40,6 +40,17 @@ def __init__(self, base_layer: FusedMoE) -> None: self.device = base_layer.w2_weight.device self._inject_lora_into_fused_moe() + def _normalize_keys(self, config: dict[str, int]) -> dict[str, int]: + normalized_config = {} + for key, value in config.items(): + if key.islower(): + if key.startswith("block_"): + normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper() + else: + normalized_key = key.upper() + normalized_config[normalized_key] = value + return normalized_config + def _get_lora_moe_configs( self, op_prefix: str, @@ -81,7 +92,8 @@ def _get_lora_moe_configs( ) shrink_config = get_config_func(M) expand_config = get_config_func(M) - + shrink_config = self._normalize_keys(shrink_config) + expand_config = self._normalize_keys(expand_config) return shrink_config, expand_config def _inject_lora_into_fused_moe(self): @@ -148,9 +160,6 @@ def wrapper(*args, **kwargs): # get the block size of m from customized config or default config max_loras = self.w1_lora_a_stacked.shape[0] - block_size = shrink_config.get("BLOCK_SIZE_M") or shrink_config.get( - "block_m", 64 - ) ( sorted_token_ids_lora, expert_ids_lora, @@ -158,7 +167,7 @@ def wrapper(*args, **kwargs): ) = self.punica_wrapper.moe_lora_align_block_size( curr_topk_ids, num_tokens, - block_size, + shrink_config["BLOCK_SIZE_M"], self.base_layer.local_num_experts, max_loras, self.adapter_enabled, diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 587e0b24a912..1bb80e516d3f 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -389,19 +389,19 @@ def add_lora_fused_moe( top_k_num, lora_ids, adapter_enabled, - shrink_config.get("BLOCK_SIZE_M") or shrink_config.get("block_m", 64), - shrink_config.get("BLOCK_SIZE_N") or shrink_config.get("block_n", 64), - shrink_config.get("BLOCK_SIZE_K") or shrink_config.get("block_k", 32), - shrink_config.get("GROUP_SIZE_M") or shrink_config.get("group_m", 8), - shrink_config.get("NUM_WARPS") or shrink_config.get("num_warps", 4), - shrink_config.get("NUM_STAGES") or shrink_config.get("num_stages", 3), - shrink_config.get("SPLIT_K") or shrink_config.get("split_k", 1), - expand_config.get("BLOCK_SIZE_M") or expand_config.get("block_m", 64), - expand_config.get("BLOCK_SIZE_N") or expand_config.get("block_n", 64), - expand_config.get("BLOCK_SIZE_K") or expand_config.get("block_k", 32), - expand_config.get("GROUP_SIZE_M") or expand_config.get("group_m", 8), - expand_config.get("NUM_WARPS") or expand_config.get("num_warps", 4), - expand_config.get("NUM_STAGES") or expand_config.get("num_stages", 3), - expand_config.get("SPLIT_K") or expand_config.get("split_k", 1), + shrink_config.get("BLOCK_SIZE_M", 64), + shrink_config.get("BLOCK_SIZE_N", 64), + shrink_config.get("BLOCK_SIZE_K", 32), + 
shrink_config.get("GROUP_SIZE_M", 8), + shrink_config.get("NUM_WARPS", 4), + shrink_config.get("NUM_STAGES", 3), + shrink_config.get("SPLIT_K", 1), + expand_config.get("BLOCK_SIZE_M", 64), + expand_config.get("BLOCK_SIZE_N", 64), + expand_config.get("BLOCK_SIZE_K", 32), + expand_config.get("GROUP_SIZE_M", 8), + expand_config.get("NUM_WARPS", 4), + expand_config.get("NUM_STAGES", 3), + expand_config.get("SPLIT_K", 1), mul_routed_weight, ) From 221b287237f83e96cd67f0c8cb1cc3c317cc1370 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 4 Nov 2025 04:33:32 +0000 Subject: [PATCH 35/36] fix bugs Signed-off-by: Yu Gong --- .pre-commit-config.yaml | 4 ++-- vllm/lora/layers/fused_moe.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bcd40e7f8ab3..f3049e9a94d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_install_hook_types: - pre-commit - commit-msg default_stages: - - pre-commit # Run locally + - commit # Run locally - manual # Run in CI exclude: 'vllm/third_party/.*' repos: @@ -50,7 +50,7 @@ repos: - id: mypy-local name: Run mypy locally for lowest supported Python version entry: python tools/pre_commit/mypy.py 0 "3.10" - stages: [pre-commit] # Don't run in CI + stages: [commit] # Don't run in CI <<: &mypy_common language: python types_or: [python, pyi] diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 467e7dcc9c1a..f5a766dd5e45 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -40,7 +40,7 @@ def __init__(self, base_layer: FusedMoE) -> None: self.device = base_layer.w2_weight.device self._inject_lora_into_fused_moe() - def _normalize_keys(self, config: dict[str, int]) -> dict[str, int]: + def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]: normalized_config = {} for key, value in config.items(): if key.islower(): @@ -48,6 +48,8 @@ def _normalize_keys(self, config: dict[str, int]) -> dict[str, int]: normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper() else: normalized_key = key.upper() + else: + normalized_key = key normalized_config[normalized_key] = value return normalized_config From 1542c93db63ed4ae3b7ef1a2b52af29191715613 Mon Sep 17 00:00:00 2001 From: Yu Gong Date: Tue, 4 Nov 2025 04:35:31 +0000 Subject: [PATCH 36/36] fix bugs Signed-off-by: Yu Gong --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f3049e9a94d4..bcd40e7f8ab3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ default_install_hook_types: - pre-commit - commit-msg default_stages: - - commit # Run locally + - pre-commit # Run locally - manual # Run in CI exclude: 'vllm/third_party/.*' repos: @@ -50,7 +50,7 @@ repos: - id: mypy-local name: Run mypy locally for lowest supported Python version entry: python tools/pre_commit/mypy.py 0 "3.10" - stages: [commit] # Don't run in CI + stages: [pre-commit] # Don't run in CI <<: &mypy_common language: python types_or: [python, pyi]