diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index bf1512268fe0..6715c9b548aa 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -19,13 +19,24 @@ from utils import ArgPool, Bench, CudaGraphBenchParams from weight_shapes import WEIGHT_SHAPES -from vllm.triton_utils import HAS_TRITON +from vllm.lora.ops.triton_ops.utils import get_lora_op_configs +from vllm.triton_utils import HAS_TRITON, triton if HAS_TRITON: - from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops import ( ## added fused_moe_lora + LoRAKernelMeta, + fused_moe_lora_expand, + fused_moe_lora_shrink, + lora_expand, + lora_shrink, + ) + from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( + _LORA_PTR_DICT, ## added _LORA_PTR_DICT for fused_moe_lora + ) from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT - +from vllm import _custom_ops as ops from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.math_utils import round_up DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_TP_SIZES = [1] @@ -59,6 +70,8 @@ DEFAULT_SORT_BY_LORA_IDS = [False, True] DEFAULT_SEQ_LENGTHS = [1] DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False] +DEFAULT_TOP_K_NUMS = [1] # Added for MoE LoRA top_k +DEFAULT_NUM_EXPERTS = [8] # Added for MoE LoRA num_experts # Utilities @@ -191,6 +204,11 @@ class OpType(Enum): LORA_SHRINK = auto() LORA_EXPAND = auto() + ## Adding support for fused moe lora + FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink + FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand + FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink + FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand @staticmethod def from_str(s: str) -> "OpType": @@ -198,6 +216,15 @@ def from_str(s: str) -> "OpType": return OpType.LORA_SHRINK if s.lower() == "lora_expand": return OpType.LORA_EXPAND + # Adding support for fused moe lora, both in gate_up and down + if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink + return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK + if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand + return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND + if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink + return OpType.FUSED_MOE_LORA_DOWN_SHRINK + if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand + return OpType.FUSED_MOE_LORA_DOWN_EXPAND raise ValueError(f"Unrecognized str {s} to convert to OpType") def is_shrink_fn(self) -> bool: @@ -206,19 +233,56 @@ def is_shrink_fn(self) -> bool: def is_expand_fn(self) -> bool: return self in [OpType.LORA_EXPAND] + def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_gate_up_fn( + self, + ) -> bool: ## adding for fused MoE LoRA Gate/Up + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + ] + + def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down + return self in [ + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + + def is_fused_moe_lora_shrink_fn(self) -> bool: + return self in [ + 
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ] + + def is_fused_moe_lora_expand_fn(self) -> bool: + return self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ] + def num_slices(self) -> list[int]: + if self.is_fused_moe_lora_gate_up_fn(): + return [2] + elif self.is_fused_moe_lora_down_fn(): + return [1] return [1, 2, 3] def mkn( self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int ) -> tuple[int, int, int]: num_tokens = batch_size * seq_length - if self.is_shrink_fn(): + if self.is_shrink_fn() or self.is_fused_moe_lora_fn(): m = num_tokens k = hidden_size n = lora_rank - else: - assert self.is_expand_fn() + elif self.is_expand_fn(): m = num_tokens k = lora_rank n = hidden_size @@ -232,9 +296,36 @@ def matmul_dtypes( """ if self.is_shrink_fn(): return op_dtype, op_dtype, torch.float32 - else: - assert self.is_expand_fn() + elif self.is_expand_fn(): return torch.float32, op_dtype, op_dtype + else: + assert self.is_fused_moe_lora_fn() + return op_dtype, op_dtype, op_dtype + + def matmul_shapes_fused_moe_lora( + self, + m: int, + n: int, + k: int, + num_loras: int, + num_slices: int, + top_k_num: int, + num_experts: int, + ) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]: + if self.is_fused_moe_lora_shrink_fn(): + input_shape = ( + (m * top_k_num, n) + if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else (m, n) + ) + output_shape = (num_slices, m, top_k_num, k) + weight_shape = (num_loras, num_experts, k, n) + else: + assert self.is_fused_moe_lora_expand_fn() + input_shape = (num_slices, m, top_k_num, k) + output_shape = (m, top_k_num, n * num_slices) + weight_shape = (num_loras, num_experts, n, k) + return (input_shape, weight_shape, output_shape) def matmul_shapes( self, @@ -244,6 +335,8 @@ def matmul_shapes( lora_rank: int, num_loras: int, num_slices: int, + top_k_num: int | None = None, + num_experts: int | None = None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Given num_slices, return the shapes of the A, B, and C matrices @@ -258,6 +351,16 @@ def matmul_shapes( if self in [OpType.LORA_EXPAND]: # LoRA expand kernels support num_slices inherently in the kernel return ((num_slices, m, k), b_shape, (m, n * num_slices)) + if self.is_fused_moe_lora_fn(): + return self.matmul_shapes_fused_moe_lora( + m, + k, + n, + num_loras, + num_slices, + top_k_num, + num_experts, + ) raise ValueError(f"Unrecognized op_type {self}") def bench_fn(self) -> Callable: @@ -265,6 +368,16 @@ def bench_fn(self) -> Callable: return lora_shrink if self == OpType.LORA_EXPAND: return lora_expand + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_SHRINK, + OpType.FUSED_MOE_LORA_DOWN_SHRINK, + ]: + return fused_moe_lora_shrink + if self in [ + OpType.FUSED_MOE_LORA_GATE_UP_EXPAND, + OpType.FUSED_MOE_LORA_DOWN_EXPAND, + ]: + return fused_moe_lora_expand raise ValueError(f"Unrecognized optype {self}") @@ -318,6 +431,8 @@ class BenchmarkContext: sort_by_lora_id: bool dtype: torch.dtype seq_length: int | None = None + num_experts: int | None = None # num_experts for MoE based ops + top_k_num: int | None = None # top_k for MoE based ops num_slices: int | None = None # num_slices for slice based ops def with_seq_length(self, seq_length: int) -> "BenchmarkContext": @@ -373,6 +488,11 @@ def io_types(self) -> str: f"{dtype_to_str(self.output.dtype)}" ) + def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType): + return ( + size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else 
size + ) + @staticmethod def make( ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" @@ -385,6 +505,8 @@ def make( ctx.lora_rank, ctx.num_loras, ctx.num_slices, + ctx.top_k_num, + ctx.num_experts, ) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) input_tensor, lora_weights, output_tensor = make_rand_tensors( @@ -432,17 +554,27 @@ def make( prompt_lora_indices_tensor, ) - def sanity_check(self) -> None: + def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None: """ Fails asserts when non-conformality is detected. """ - num_tokens = self.input.shape[-2] + num_tokens = ( + self.input.shape[1] + if op_type.is_fused_moe_lora_expand_fn() + else self.input.shape[-2] + ) # check metadata tensors - assert torch.sum(self.seq_lens) == num_tokens + ## In down shrink case, each token is repeated top_k_num times + assert num_tokens == self.get_num_tokens( + torch.sum(self.seq_lens), ctx.top_k_num, op_type + ), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}" num_seqs = self.seq_lens.shape[0] # assert self.seq_start_loc.shape[0] == num_seqs + ## In down shrink case, each prompt corresponds to top_k_num sequences assert self.prompt_lora_mapping.shape[0] == num_seqs - assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens + assert self.get_num_tokens( + self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type + ) def to_device(self, device: str): """ @@ -471,21 +603,111 @@ def to_device(tensor: torch.Tensor): to_device(field) if field_name != "no_lora_flag_cpu" else field, ) - def metadata(self) -> tuple[int, int, int]: + def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, int]: """ Return num_seqs, num_tokens and max_seq_len """ num_seqs = self.seq_lens.shape[0] - num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0] + num_tokens = self.get_num_tokens( + self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type + ) max_seq_len = torch.max(self.seq_lens).item() num_slices = len(self.lora_weights_lst) return num_seqs, num_tokens, max_seq_len, num_slices - def as_lora_shrink_kwargs(self) -> dict[str, Any]: - self.sanity_check() + def fused_moe_lora_data_prepare( + self, + block_size: int, + token_lora_mapping: torch.Tensor, + ctx: BenchmarkContext, + ): + def moe_lora_align_block_size( + topk_ids: torch.Tensor, + token_lora_mapping: torch.Tensor, + block_size: int, + num_experts: int, + max_loras: int, + expert_map: torch.Tensor | None = None, + pad_sorted_ids: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns tokens and experts into block-sized chunks for LoRA-based + mixture-of-experts (MoE) execution. 
+ """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) + + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return sorted_ids, expert_ids, num_tokens_post_pad + + num_tokens = ctx.batch_size + curr_topk_ids = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + topk_weights = torch.randint( + 0, + ctx.num_experts, + (num_tokens, ctx.top_k_num), + device="cuda", + dtype=torch.int32, + ) + + (sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = ( + moe_lora_align_block_size( + topk_ids=curr_topk_ids, + token_lora_mapping=token_lora_mapping, + block_size=block_size, + num_experts=ctx.num_experts, + max_loras=ctx.num_loras, + ) + ) + + sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1) + expert_ids = expert_ids_lora.view(ctx.num_loras, -1) + num_tokens_post_padded = num_tokens_post_padded_lora + return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) + + def as_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata() + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. i_shape, lw_shape, o_shape = ( @@ -520,11 +742,13 @@ def as_lora_shrink_kwargs(self) -> dict[str, Any]: "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } - def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: - self.sanity_check() + def as_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) self.to_device(self.input.device) - _, num_tokens, _, num_slices = self.metadata() + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) # Sanity check matrix shapes. i_shape, lw_shape, o_shape = ( @@ -561,18 +785,173 @@ def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: "no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu, } + def as_fused_moe_lora_shrink_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. 
+ i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + # Expected input shape : [num_tokens, hidden_size] for gate_up + # Expected input shape : [top_k_num * num_tokens, hidden_size] for down + assert len(i_shape) == 2 + assert i_shape[0] == num_tokens + hidden_size = i_shape[1] + # Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size] + assert len(lw_shape) == 4 + assert lw_shape[-1] == hidden_size + lora_rank = lw_shape[-2] + # Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(o_shape) == 4 + assert ( + o_shape + == (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank) + if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] + else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank) + ) + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "qcurr_hidden_states": self.input, + "lora_a_stacked": self.lora_weights_lst, + "a_intermediate_cache1": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "shrink_block_size_m": kernel_config["BLOCK_SIZE_M"], + "shrink_block_size_n": kernel_config["BLOCK_SIZE_N"], + "shrink_block_size_k": kernel_config["BLOCK_SIZE_K"], + "shrink_group_size_m": kernel_config["GROUP_SIZE_M"], + "shrink_num_warps": kernel_config["NUM_WARPS"], + "shrink_num_stages": kernel_config["NUM_STAGES"], + "shrink_split_k": kernel_config.get("SPLIT_K", 1), + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + } + + def as_fused_moe_lora_expand_kwargs( + self, ctx: BenchmarkContext, op_type: OpType + ) -> dict[str, Any]: + self.sanity_check(ctx, op_type) + self.to_device(self.input.device) + + _, num_tokens, _, num_slices = self.metadata(ctx, op_type) + + # Sanity check matrix shapes. 
+ i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) + + # Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank] + assert len(i_shape) == 4 + assert i_shape[0] == num_slices + assert i_shape[1] == num_tokens + lora_rank = i_shape[-1] + # Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank] + assert len(lw_shape) == 4 + assert lw_shape[-1] == lora_rank + hidden_size = lw_shape[-2] + # Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices] + assert len(o_shape) == 3 + assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices) + + kernel_config = get_lora_op_configs( + op_type.name.lower(), + max_loras=lw_shape[0], + batch=num_tokens, + hidden_size=hidden_size, + rank=lora_rank, + num_slices=num_slices, + add_inputs=False, + ) + + (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = ( + self.fused_moe_lora_data_prepare( + block_size=kernel_config["BLOCK_SIZE_M"], + token_lora_mapping=self.lora_kernel_meta.token_lora_mapping, + ctx=ctx, + ) + ) + + return { + "a_intermediate_cache1": self.input, + "lora_b_stacked": self.lora_weights_lst, + "output": self.output, + "topk_weights": topk_weights, + "sorted_token_ids": sorted_token_ids, + "expert_ids": expert_ids, + "num_tokens_post_padded": num_tokens_post_padded, + "top_k_num": ctx.top_k_num, + "device": self.input.device, + "N": lora_rank, + "M": topk_weights.shape[0], + "EM": sorted_token_ids.shape[1], + "K": self.input.shape[1], + "num_tokens": num_tokens, + "num_experts": ctx.num_experts, + "num_slices": num_slices, + "max_lora_rank": lora_rank, + "w1_output_dim_size": lw_shape[2], + "expand_block_size_m": kernel_config["BLOCK_SIZE_M"], + "expand_block_size_n": kernel_config["BLOCK_SIZE_N"], + "expand_block_size_k": kernel_config["BLOCK_SIZE_K"], + "expand_group_size_m": kernel_config["GROUP_SIZE_M"], + "expand_num_warps": kernel_config["NUM_WARPS"], + "expand_num_stages": kernel_config["NUM_STAGES"], + "expand_split_k": kernel_config.get("SPLIT_K", 1), + "mul_routed_weight": op_type.is_fused_moe_lora_down_fn(), + } + def bench_fn_kwargs( - self, op_type: OpType, add_inputs: bool | None = None + self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None ) -> dict[str, Any]: - if op_type.is_shrink_fn(): + if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn(): assert add_inputs is None else: assert add_inputs is not None if op_type == OpType.LORA_SHRINK: - return self.as_lora_shrink_kwargs() + return self.as_lora_shrink_kwargs(ctx, op_type) if op_type == OpType.LORA_EXPAND: - return self.as_lora_expand_kwargs(add_inputs) + return self.as_lora_expand_kwargs(ctx, op_type, add_inputs) + if op_type.is_fused_moe_lora_shrink_fn(): + return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type) + if op_type.is_fused_moe_lora_expand_fn(): + return self.as_fused_moe_lora_expand_kwargs(ctx, op_type) raise ValueError(f"Unrecognized optype {self}") def test_correctness( @@ -617,7 +996,7 @@ def bench_optype( test_correctness: bool = False, ) -> TMeasurement: assert arg_pool_size >= 1 - if op_type.is_shrink_fn(): + if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn(): assert expand_fn_add_inputs is None else: assert expand_fn_add_inputs is not None @@ -627,23 +1006,30 @@ def bench_optype( BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) ] for bt in bench_tensors: - bt.sanity_check() + bt.sanity_check(ctx, op_type) # Test correctness of our implementation. 
if test_correctness: + assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], ( + f"Correctness testing is not supported for {op_type.name}." + ) assert all( - [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] + [ + bt.test_correctness(ctx, op_type, expand_fn_add_inputs) + for bt in bench_tensors + ] ) # BenchmarkTensors -> dict (kwargs) kwargs_list = [ - bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs) + bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs) for bt in bench_tensors ] # Clear LoRA optimization hash-maps. _LORA_A_PTR_DICT.clear() _LORA_B_PTR_DICT.clear() + _LORA_PTR_DICT.clear() # Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up for kwargs in kwargs_list: op_type.bench_fn()(**kwargs) @@ -793,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): # Benchmark bench_op expand_fn_add_inputs = ( - [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + [None] + if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn() + else args.expand_fn_add_inputs ) for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( @@ -831,12 +1219,22 @@ def as_benchmark_contexts( hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace ) -> list[BenchmarkContext]: ctxs: list[BenchmarkContext] = [] - for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa + for ( + batch_size, + hidden_size, + lora_rank, + num_loras, + sort_by_lora_id, + top_k_num, + num_experts, + ) in product( # noqa args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, args.sort_by_lora_id, + args.top_k_nums, + args.num_experts, ): ctxs.append( BenchmarkContext( @@ -851,6 +1249,8 @@ def as_benchmark_contexts( seq_length=None, sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, + top_k_num=top_k_num, + num_experts=num_experts, # To be filled based on the OpType to benchmark num_slices=None, ) @@ -1012,6 +1412,22 @@ def add_common_command_args(p: argparse.ArgumentParser): ), ) + p.add_argument( + "--top-k-nums", + nargs="+", + type=int, + default=DEFAULT_TOP_K_NUMS, + help="Top-K values for MoE LoRA operations", + ) + + p.add_argument( + "--num-experts", + nargs="+", + type=int, + default=DEFAULT_NUM_EXPERTS, + help="Number of experts for MoE LoRA operations", + ) + parser = FlexibleArgumentParser( description=f""" Benchmark LoRA kernels: diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index 318a0e58805d..91ab4a87c65f 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -158,6 +158,8 @@ def use_fused_moe_lora_kernel( "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, + "NUM_WARPS": 4, + "NUM_STAGES": 3, "SPLIT_K": 1, } @@ -182,6 +184,15 @@ def use_fused_moe_lora_kernel( config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"], config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], + config["SPLIT_K"], + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], config["SPLIT_K"], mul_routed_weight, ) diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 7711f5c3208b..f5a766dd5e45 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -13,6 +13,7 @@ get_tensor_model_parallel_world_size, ) from vllm.lora.layers.base import BaseLayerWithLoRA +from vllm.lora.ops.triton_ops.utils import 
get_lora_op_configs from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import ( _get_config_dtype_str, @@ -39,6 +40,64 @@ def __init__(self, base_layer: FusedMoE) -> None: self.device = base_layer.w2_weight.device self._inject_lora_into_fused_moe() + def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]: + normalized_config = {} + for key, value in config.items(): + if key.islower(): + if key.startswith("block_"): + normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper() + else: + normalized_key = key.upper() + else: + normalized_key = key + normalized_config[normalized_key] = value + return normalized_config + + def _get_lora_moe_configs( + self, + op_prefix: str, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + num_slices: int, + M: int, + layer: FusedMoE, + top_k: int, + config_dtype: str, + ): + if envs.VLLM_TUNED_CONFIG_FOLDER: + shrink_config = get_lora_op_configs( + op_type=f"fused_moe_lora_{op_prefix}_shrink", + max_loras=lora_a_stacked.shape[0], + batch=M, + hidden_size=lora_a_stacked.shape[-1], + rank=lora_a_stacked.shape[-2], + num_slices=num_slices, + moe_intermediate_size=lora_b_stacked.shape[-2], + ) + expand_config = get_lora_op_configs( + op_type=f"fused_moe_lora_{op_prefix}_expand", + max_loras=lora_a_stacked.shape[0], + batch=M, + hidden_size=lora_a_stacked.shape[-1], + rank=lora_a_stacked.shape[-2], + num_slices=num_slices, + moe_intermediate_size=lora_b_stacked.shape[-2], + ) + else: # fall back to the default config + get_config_func = functools.partial( + try_get_optimal_moe_config, + layer.w13_weight.size(), + layer.w2_weight.size(), + top_k, + config_dtype, + block_shape=layer.quant_method.moe_quant_config.block_shape, + ) + shrink_config = get_config_func(M) + expand_config = get_config_func(M) + shrink_config = self._normalize_keys(shrink_config) + expand_config = self._normalize_keys(expand_config) + return shrink_config, expand_config + def _inject_lora_into_fused_moe(self): moe_state_dict = {} top_k = self.base_layer.top_k @@ -90,17 +149,19 @@ def wrapper(*args, **kwargs): num_tokens = hidden_states.size(0) M = min(num_tokens, CHUNK_SIZE) - get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - layer.w2_weight.size(), - top_k, - config_dtype, - block_shape=layer.quant_method.moe_quant_config.block_shape, + shrink_config, expand_config = self._get_lora_moe_configs( + op_prefix="w13", + lora_a_stacked=self.w1_lora_a_stacked, + lora_b_stacked=self.w1_lora_b_stacked, + num_slices=2, + M=M, + layer=layer, + top_k=top_k, + config_dtype=config_dtype, ) + # get the block size of m from customized config or default config max_loras = self.w1_lora_a_stacked.shape[0] - config = get_config_func(M) ( sorted_token_ids_lora, expert_ids_lora, @@ -108,7 +169,7 @@ def wrapper(*args, **kwargs): ) = self.punica_wrapper.moe_lora_align_block_size( curr_topk_ids, num_tokens, - config["BLOCK_SIZE_M"], + shrink_config["BLOCK_SIZE_M"], self.base_layer.local_num_experts, max_loras, self.adapter_enabled, @@ -138,7 +199,8 @@ def wrapper(*args, **kwargs): num_tokens_post_padded_lora, max_lora_rank, top_k, - config, + shrink_config, ## pass the shrink config + expand_config, ## pass the expand config self.adapter_enabled, ) @@ -164,17 +226,17 @@ def wrapper(*args, **kwargs): num_tokens = hidden_states.size(0) M = min(num_tokens, CHUNK_SIZE) - get_config_func = functools.partial( - try_get_optimal_moe_config, - layer.w13_weight.size(), - 
layer.w2_weight.size(), - top_k, - config_dtype, - block_shape=layer.quant_method.moe_quant_config.block_shape, + shrink_config, expand_config = self._get_lora_moe_configs( + op_prefix="w2", + lora_a_stacked=self.w2_lora_a_stacked, + lora_b_stacked=self.w2_lora_b_stacked, + num_slices=1, + M=M, + layer=layer, + top_k=top_k, + config_dtype=config_dtype, ) - config = get_config_func(M) - sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"] expert_ids_lora = moe_state_dict["expert_ids_lora"] num_tokens_post_padded_lora = moe_state_dict[ @@ -197,7 +259,8 @@ def wrapper(*args, **kwargs): num_tokens_post_padded_lora, max_lora_rank, top_k, - config, + shrink_config, ## pass the shrink config + expand_config, ## pass the expand config self.adapter_enabled, True, ) diff --git a/vllm/lora/ops/triton_ops/README_TUNING.md b/vllm/lora/ops/triton_ops/README_TUNING.md index fda95ea71891..d576e261557a 100644 --- a/vllm/lora/ops/triton_ops/README_TUNING.md +++ b/vllm/lora/ops/triton_ops/README_TUNING.md @@ -44,8 +44,17 @@ For `shrink`, the config file is named as `{gpu_name}_SHRINK.json`, e.g. `NVIDIA For `expand`, the config fileis named as `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`. +For `fused_moe_lora_w13_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json`. + +For `fused_moe_lora_w13_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json`. + +For `fused_moe_lora_w2_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json`. + +For `fused_moe_lora_w2_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json`. + The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()` ### Json Structure -Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]` +Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]` +where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer. 
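
The README_TUNING.md hunk above describes tuned configs as nested JSON indexed by `config_data[max_loras][num_slices][m][k][n][i]`, where the trailing `i` (MoE intermediate size) is present only for the `fused_moe_lora` ops. Below is a minimal sketch of that layout together with a nearest-key lookup in the spirit of `get_lora_op_configs`; the file contents, key values, and the upper-case spelling of the kernel parameters are illustrative assumptions, not copied from a real tuned file:

```python
# Illustrative nesting, mirroring config_data[max_loras][num_slices][m][k][n][i]
# from README_TUNING.md. Every key and value here is made up for the example;
# real tuned files may spell the inner parameter keys differently.
config_data = {
    "1": {                          # max_loras
        "2": {                      # num_slices
            "16": {                 # m
                "4096": {           # k
                    "16": {         # n
                        "1408": {   # i: MoE intermediate size (fused_moe_lora only)
                            "BLOCK_SIZE_M": 64,
                            "BLOCK_SIZE_N": 64,
                            "BLOCK_SIZE_K": 32,
                            "GROUP_SIZE_M": 8,
                            "NUM_WARPS": 4,
                            "NUM_STAGES": 3,
                            "SPLIT_K": 1,
                        },
                    },
                },
            },
        },
    },
}


def nearest(d: dict, wanted: int):
    # Exact match first, otherwise the numerically closest key,
    # in the spirit of the fallback logic in get_lora_op_configs.
    return d.get(str(wanted)) or d[min(d, key=lambda x: abs(int(x) - wanted))]


cfg = config_data
for dim in (1, 2, 16, 4096, 16, 1408):  # max_loras, num_slices, m, k, n, i
    cfg = nearest(cfg, dim)
print(cfg["BLOCK_SIZE_M"])  # 64
```
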
diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index 436ea4ed00c8..7e8b9a79add3 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -1,7 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora + +from vllm.lora.ops.triton_ops.fused_moe_lora_op import ( + fused_moe_lora, + fused_moe_lora_expand, + fused_moe_lora_shrink, +) from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink @@ -11,4 +16,6 @@ "lora_shrink", "LoRAKernelMeta", "fused_moe_lora", + "fused_moe_lora_shrink", + "fused_moe_lora_expand", ] diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 539605c7c534..8f85f926aa4f 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -176,88 +176,50 @@ def _fused_moe_lora_kernel( @torch.inference_mode() -def _fused_moe_lora( - output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) +def _fused_moe_lora_shrink( + a_intermediate_cache1: torch.Tensor, + # (num_slices, num_tokens, top_k_num, max_lora_rank) qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) lora_a_stacked: list[ torch.Tensor ], # [(max_loras, num_experts, max_lora_rank, K,),...] - lora_b_stacked: list[ - torch.Tensor - ], # [(max_loras, num_experts, N, max_lora_rank,),...] topk_weights: torch.Tensor, # (num_tokens, top_k_num) sorted_token_ids: torch.Tensor, # (max_loras, _) expert_ids: torch.Tensor, # (max_loras, _ ,) num_tokens_post_padded: torch.Tensor, # (max_loras, ) - max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, block_size_m: int, block_size_n: int, block_size_k: int, group_size_m: int, + num_warps: int, + num_stages: int, split_k: int, mul_routed_weight: bool = False, ) -> None: - assert len(lora_a_stacked) == len(lora_b_stacked) > 0 - assert ( - sorted_token_ids.dim() - == expert_ids.dim() - == topk_weights.dim() - == qcurr_hidden_states.dim() - == 2 - ) - assert ( - sorted_token_ids.shape[0] - == expert_ids.shape[0] - == num_tokens_post_padded.shape[0] - ) - assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] - assert output.shape[0] == topk_weights.shape[0] - assert top_k_num == topk_weights.shape[1] - - for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked): - assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype - assert lora_a.dtype in [torch.float16, torch.bfloat16] - - device = qcurr_hidden_states.device - num_slices = len(lora_a_stacked) + w1_lora_a_stacked = lora_a_stacked[0] - config = { + shrink_config = { "BLOCK_SIZE_M": block_size_m, "BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k, "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, "SPLIT_K": split_k, } - w1_lora_a_stacked = lora_a_stacked[0] - w1_lora_b_stacked = lora_b_stacked[0] - num_experts = lora_a_stacked[0].shape[1] - - N = max_lora_rank - M = topk_weights.shape[0] - EM = sorted_token_ids.shape[1] - K = qcurr_hidden_states.shape[1] - num_tokens = M * top_k_num - 
w1_output_dim_size = w1_lora_b_stacked.shape[2] - - lora_intermediate_cache1 = torch.zeros( - (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)), - dtype=output.dtype, - device=device, - ) - - # slices - a_intermediate_size = num_slices * M * top_k_num * max_lora_rank - a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view( - num_slices, M, top_k_num, max_lora_rank - ) - b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view( - num_slices, M, top_k_num, w1_output_dim_size - ) - b_ptr = _get_ptr(lora_a_stacked, device) grid = lambda META: ( @@ -299,19 +261,70 @@ def _fused_moe_lora( num_slice_c=num_slices, top_k=1 if mul_routed_weight else top_k_num, MUL_ROUTED_WEIGHT=False, - **config, + **shrink_config, ) + +@torch.inference_mode() +def _fused_moe_lora_expand( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank) + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: b_ptr = _get_ptr(lora_b_stacked, device) K = max_lora_rank N = w1_output_dim_size + w1_lora_b_stacked = lora_b_stacked[0] + a_intermediate_cache1 = a_intermediate_cache1.view( -1, a_intermediate_cache1.shape[3] ) - # Set split_k = 1 for expand calls - config["SPLIT_K"] = 1 + b_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, w1_output_dim_size), + dtype=output.dtype, + device=device, + ) + + expand_config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "SPLIT_K": split_k, # Set split_k = 1 for expand calls + } + grid = lambda META: ( triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), len(lora_b_stacked), @@ -348,12 +361,142 @@ def _fused_moe_lora( num_slice_c=num_slices, top_k=1, MUL_ROUTED_WEIGHT=mul_routed_weight, - **config, + **expand_config, ) for i in range(num_slices): output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i] +@torch.inference_mode() +def _fused_moe_lora( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, N, max_lora_rank,),...] 
+ topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + max_lora_rank: int, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, +) -> None: + assert len(lora_a_stacked) == len(lora_b_stacked) > 0 + assert ( + sorted_token_ids.dim() + == expert_ids.dim() + == topk_weights.dim() + == qcurr_hidden_states.dim() + == 2 + ) + assert ( + sorted_token_ids.shape[0] + == expert_ids.shape[0] + == num_tokens_post_padded.shape[0] + ) + assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1] + assert output.shape[0] == topk_weights.shape[0] + assert top_k_num == topk_weights.shape[1] + device = qcurr_hidden_states.device + num_slices = len(lora_a_stacked) + w1_lora_b_stacked = lora_b_stacked[0] + num_experts = lora_a_stacked[0].shape[1] + N = max_lora_rank + M = topk_weights.shape[0] + EM = sorted_token_ids.shape[1] + K = qcurr_hidden_states.shape[1] + num_tokens = M * top_k_num + w1_output_dim_size = w1_lora_b_stacked.shape[2] + + a_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, max_lora_rank), + dtype=output.dtype, + device=device, + ) + + _fused_moe_lora_shrink( + a_intermediate_cache1, + qcurr_hidden_states, + lora_a_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + lora_ids, + adapter_enabled, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + shrink_block_size_m, + shrink_block_size_n, + shrink_block_size_k, + shrink_group_size_m, + shrink_num_warps, + shrink_num_stages, + shrink_split_k, + mul_routed_weight, + ) + + _fused_moe_lora_expand( + output, + a_intermediate_cache1, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + lora_ids, + adapter_enabled, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + max_lora_rank, + w1_output_dim_size, + expand_block_size_m, + expand_block_size_n, + expand_block_size_k, + expand_group_size_m, + expand_num_warps, + expand_num_stages, + expand_split_k, + mul_routed_weight, + ) + + def _fused_moe_lora_fake( output: torch.Tensor, qcurr_hidden_states: torch.Tensor, @@ -367,10 +510,84 @@ def _fused_moe_lora_fake( top_k_num: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def _fused_moe_lora_shrink_fake( + a_intermediate_cache1: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + 
num_tokens_post_padded: torch.Tensor, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def _fused_moe_lora_expand_fake( + output: torch.Tensor, + a_intermediate_cache1: torch.Tensor, + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, block_size_m: int, block_size_n: int, block_size_k: int, group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, mul_routed_weight: bool = False, ) -> None: return @@ -383,7 +600,26 @@ def _fused_moe_lora_fake( mutates_args=["output"], fake_impl=_fused_moe_lora_fake, ) + + direct_register_custom_op( + op_name="fused_moe_lora_shrink", + op_func=_fused_moe_lora_shrink, + mutates_args=["a_intermediate_cache1"], + fake_impl=_fused_moe_lora_shrink_fake, + ) + + direct_register_custom_op( + op_name="fused_moe_lora_expand", + op_func=_fused_moe_lora_expand, + mutates_args=["output"], + fake_impl=_fused_moe_lora_expand_fake, + ) + fused_moe_lora = torch.ops.vllm.fused_moe_lora + fused_moe_lora_shrink = torch.ops.vllm.fused_moe_lora_shrink + fused_moe_lora_expand = torch.ops.vllm.fused_moe_lora_expand except AttributeError: fused_moe_lora = _fused_moe_lora + fused_moe_lora_shrink = _fused_moe_lora_shrink + fused_moe_lora_expand = _fused_moe_lora_expand diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 368c5037d2e4..bd413a6db26b 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -154,13 +154,13 @@ def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None: gpu_name = gpu_name.replace("-", "_") config_fname = None - if op_type == "shrink": - config_fname = f"{gpu_name}_{op_type.upper()}.json" - else: - assert op_type == "expand" + # only expand op needs to consider add_inputs + if op_type == "expand": config_fname = ( f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json" ) + else: + config_fname = f"{gpu_name}_{op_type.upper()}.json" config_path = Path(f"{user_defined_config_folder}/{config_fname}") if not config_path.exists(): @@ -186,8 +186,17 @@ def get_lora_op_configs( rank: int, num_slices: int, add_inputs: bool | None = None, + moe_intermediate_size: int | None = None, ) -> dict[str, int | None]: - assert op_type in ["shrink", "expand"] + # Add support for fused_moe_lora ops + assert op_type in [ + "shrink", + "expand", + "fused_moe_lora_w13_shrink", + "fused_moe_lora_w13_expand", + "fused_moe_lora_w2_shrink", + "fused_moe_lora_w2_expand", + ] # default config default = {} @@ -203,6 +212,22 @@ def get_lora_op_configs( "num_stages": 2, "max_nreg": None, } + # The default config for fused_moe_lora ops + elif op_type in [ + "fused_moe_lora_w13_shrink", + "fused_moe_lora_w13_expand", + "fused_moe_lora_w2_shrink", + "fused_moe_lora_w2_expand", + ]: + default = { + "block_m": 64, + "block_n": 64, + "block_k": 32, + "num_warps": 4, + "num_stages": 
3, + "group_size_m": 8, + "split_k": 1, + } else: default = { "block_m": 64, @@ -247,5 +272,13 @@ def get_lora_op_configs( or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))] ) + # slice by moe-intermediate-size if applicable + if moe_intermediate_size is not None: + i = moe_intermediate_size + config_data = ( + config_data.get(str(i)) + or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - i))] + ) + assert config_data is not None return config_data diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index c552412cfd62..b6186e856152 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -479,7 +479,8 @@ def add_lora_fused_moe( num_tokens_post_padded: torch.Tensor, max_lora_rank: int, top_k_num: int, - config, + shrink_config, + expand_config, adapter_enabled: torch.Tensor, mul_routed_weight=False, ): diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 30def90380db..1bb80e516d3f 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -367,7 +367,8 @@ def add_lora_fused_moe( num_tokens_post_padded: torch.Tensor, max_lora_rank: int, top_k_num: int, - config, + shrink_config, + expand_config, adapter_enabled: torch.Tensor, mul_routed_weight=False, ): @@ -388,10 +389,19 @@ def add_lora_fused_moe( top_k_num, lora_ids, adapter_enabled, - config["BLOCK_SIZE_M"], - config["BLOCK_SIZE_N"], - config["BLOCK_SIZE_K"], - config["GROUP_SIZE_M"], - config.get("SPLIT_K", 1), + shrink_config.get("BLOCK_SIZE_M", 64), + shrink_config.get("BLOCK_SIZE_N", 64), + shrink_config.get("BLOCK_SIZE_K", 32), + shrink_config.get("GROUP_SIZE_M", 8), + shrink_config.get("NUM_WARPS", 4), + shrink_config.get("NUM_STAGES", 3), + shrink_config.get("SPLIT_K", 1), + expand_config.get("BLOCK_SIZE_M", 64), + expand_config.get("BLOCK_SIZE_N", 64), + expand_config.get("BLOCK_SIZE_K", 32), + expand_config.get("GROUP_SIZE_M", 8), + expand_config.get("NUM_WARPS", 4), + expand_config.get("NUM_STAGES", 3), + expand_config.get("SPLIT_K", 1), mul_routed_weight, )
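
The `punica_gpu.py` hunk above reads both `shrink_config` and `expand_config` through upper-case keys with per-key defaults, which works whether the dicts come from tuned files via `get_lora_op_configs` or from the `try_get_optimal_moe_config` fallback after `_normalize_keys`. A standalone sketch of that normalization, mirroring the `_normalize_keys` method added in `vllm/lora/layers/fused_moe.py` (the sample values below are made up):

```python
def normalize_keys(config: dict[str, int | None]) -> dict[str, int | None]:
    # Lower-case keys produced by try_get_optimal_moe_config
    # ("block_m", "num_warps", ...) are mapped to the upper-case spelling the
    # fused_moe_lora kernels are driven with; keys that are already
    # upper-case pass through unchanged.
    normalized: dict[str, int | None] = {}
    for key, value in config.items():
        if key.islower():
            if key.startswith("block_"):
                new_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
            else:
                new_key = key.upper()
        else:
            new_key = key
        normalized[new_key] = value
    return normalized


# Example (values are illustrative only):
print(normalize_keys({"block_m": 64, "block_n": 64, "num_warps": 4, "SPLIT_K": 1}))
# {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'NUM_WARPS': 4, 'SPLIT_K': 1}
```
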