Merged
Changes from 33 commits
Commits
37 commits
567ef62
Adding config loading and benchmarking for fused_moe_lora shrink and …
yugong333 Oct 28, 2025
d3364e9
fix some bugs
yugong333 Oct 21, 2025
9a5f9e0
your message
yugong333 Oct 21, 2025
7d3071d
fix bugs
yugong333 Oct 21, 2025
1724bdb
fix bugs
yugong333 Oct 21, 2025
e8d144f
Fixed the bugs
yugong333 Oct 21, 2025
64537e2
Adding pre-commit-config.yaml
yugong333 Oct 21, 2025
b09466c
clean the code
yugong333 Oct 23, 2025
d9cb741
fix bugs
yugong333 Oct 23, 2025
0bf6a53
Adding support in benchmark_lora for fused_moe_lora expand and shrink…
yugong333 Oct 26, 2025
94508e4
Adding data generation for fused_moe_lora
yugong333 Oct 26, 2025
6c8c97b
Fix bugs
yugong333 Oct 28, 2025
8e85f95
Adding accuracy test
yugong333 Oct 28, 2025
0ee933b
fix bugs
yugong333 Oct 28, 2025
f9f0f8e
fix bugs
yugong333 Oct 28, 2025
e11030e
clean code
yugong333 Oct 28, 2025
82635d5
fix pre-commit
yugong333 Oct 28, 2025
5553bda
fix bugs
yugong333 Oct 28, 2025
0d8fa61
clean code
yugong333 Oct 29, 2025
3b1f04a
clean code
yugong333 Oct 31, 2025
3ad93dd
clean code
yugong333 Nov 1, 2025
3f6357f
clean code
yugong333 Nov 1, 2025
3acf93b
restore pre-commit-config.yaml
yugong333 Nov 1, 2025
ff518b3
restore .pre-commit-config.yaml
yugong333 Nov 1, 2025
dfb9dd1
clean code
yugong333 Nov 1, 2025
d950b3d
clean code
yugong333 Nov 1, 2025
65c11e9
clean code
yugong333 Nov 2, 2025
f451ca7
rename the config
yugong333 Nov 3, 2025
22faf7e
clean code
yugong333 Nov 3, 2025
0b439f7
fix format issue
yugong333 Nov 3, 2025
a1ec116
Rabase PR
yugong333 Nov 3, 2025
d73f410
Renaming kernel
yugong333 Nov 3, 2025
51f00b2
renaming
yugong333 Nov 3, 2025
50afb56
Normalize key name as uppercase
yugong333 Nov 4, 2025
221b287
fix bugs
yugong333 Nov 4, 2025
1542c93
fix bugs
yugong333 Nov 4, 2025
623278e
Merge branch 'main' into restore-pr
jeejeelee Nov 4, 2025
476 changes: 446 additions & 30 deletions benchmarks/kernels/benchmark_lora.py

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions tests/lora/test_fused_moe_lora_kernel.py
@@ -158,6 +158,8 @@ def use_fused_moe_lora_kernel(
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"NUM_WARPS": 4,
"NUM_STAGES": 3,
"SPLIT_K": 1,
}

@@ -182,6 +184,15 @@ def use_fused_moe_lora_kernel(
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
config["NUM_WARPS"],
config["NUM_STAGES"],
config["SPLIT_K"],
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
config["GROUP_SIZE_M"],
config["NUM_WARPS"],
config["NUM_STAGES"],
config["SPLIT_K"],
mul_routed_weight,
)
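
The duplicated parameter list above reflects that the fused MoE LoRA kernel path now takes one set of launch parameters for the shrink GEMM and one for the expand GEMM. A minimal sketch of that idea (the variable names below are illustrative, not the test's actual code):

```python
# Illustrative only: the same tuned values can be reused for both phases,
# or replaced by separately tuned shrink/expand configs.
config = {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 32,
    "BLOCK_SIZE_K": 64,
    "GROUP_SIZE_M": 1,
    "NUM_WARPS": 4,
    "NUM_STAGES": 3,
    "SPLIT_K": 1,
}

kernel_keys = [
    "BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K",
    "GROUP_SIZE_M", "NUM_WARPS", "NUM_STAGES", "SPLIT_K",
]
shrink_params = [config[k] for k in kernel_keys]  # passed first (shrink)
expand_params = [config[k] for k in kernel_keys]  # passed second (expand)
```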
92 changes: 72 additions & 20 deletions vllm/lora/layers/fused_moe.py
@@ -13,6 +13,7 @@
get_tensor_model_parallel_world_size,
)
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str,
@@ -39,6 +40,50 @@ def __init__(self, base_layer: FusedMoE) -> None:
self.device = base_layer.w2_weight.device
self._inject_lora_into_fused_moe()

def _get_lora_moe_configs(
self,
op_prefix: str,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
num_slices: int,
M: int,
layer: FusedMoE,
top_k: int,
config_dtype: str,
):
if envs.VLLM_TUNED_CONFIG_FOLDER:
shrink_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_shrink",
max_loras=lora_a_stacked.shape[0],
batch=M,
hidden_size=lora_a_stacked.shape[-1],
rank=lora_a_stacked.shape[-2],
num_slices=num_slices,
moe_intermediate_size=lora_b_stacked.shape[-2],
)
expand_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_expand",
max_loras=lora_a_stacked.shape[0],
batch=M,
hidden_size=lora_a_stacked.shape[-1],
rank=lora_a_stacked.shape[-2],
num_slices=num_slices,
moe_intermediate_size=lora_b_stacked.shape[-2],
)
else: # fall back to the default config
get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
shrink_config = get_config_func(M)
expand_config = get_config_func(M)

return shrink_config, expand_config
Contributor:

I think it is better to lower-case all keys of the configs here so we don't have to do the config.get("UPPER_CASE", None) or config.get("lower_case") check everywhere?

Contributor Author:

Thank you @varun-sundar-rabindranath. I added a function to normalize the key names to upper case, consistent with the MoE layer config names, which also removes the need for that check everywhere.
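
A minimal sketch of the normalization idea described in this reply (the helper name and key mapping are assumptions for illustration, not the PR's actual code):

```python
# Hypothetical helper: upper-case tuned LoRA config keys so they line up with
# the MoE layer's config key style (e.g. "block_m" -> "BLOCK_SIZE_M").
_KEY_ALIASES = {
    "block_m": "BLOCK_SIZE_M",
    "block_n": "BLOCK_SIZE_N",
    "block_k": "BLOCK_SIZE_K",
}

def _normalize_config_keys(config: dict) -> dict:
    return {_KEY_ALIASES.get(k, k.upper()): v for k, v in config.items()}

# {"block_m": 64, "num_warps": 4} -> {"BLOCK_SIZE_M": 64, "NUM_WARPS": 4}
```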


def _inject_lora_into_fused_moe(self):
moe_state_dict = {}
top_k = self.base_layer.top_k
@@ -90,25 +135,30 @@ def wrapper(*args, **kwargs):
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)

get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
block_shape=layer.quant_method.moe_quant_config.block_shape,
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w13",
lora_a_stacked=self.w1_lora_a_stacked,
lora_b_stacked=self.w1_lora_b_stacked,
num_slices=2,
M=M,
layer=layer,
top_k=top_k,
config_dtype=config_dtype,
)

# get the block size of m from customized config or default config
max_loras = self.w1_lora_a_stacked.shape[0]
config = get_config_func(M)
block_size = shrink_config.get("BLOCK_SIZE_M") or shrink_config.get(
"block_m", 64
)
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
) = self.punica_wrapper.moe_lora_align_block_size(
curr_topk_ids,
num_tokens,
config["BLOCK_SIZE_M"],
block_size,
self.base_layer.local_num_experts,
max_loras,
self.adapter_enabled,
@@ -138,7 +188,8 @@ def wrapper(*args, **kwargs):
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
config,
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
)

@@ -164,17 +215,17 @@ def wrapper(*args, **kwargs):
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)

get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
block_shape=layer.quant_method.moe_quant_config.block_shape,
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2",
lora_a_stacked=self.w2_lora_a_stacked,
lora_b_stacked=self.w2_lora_b_stacked,
num_slices=1,
M=M,
layer=layer,
top_k=top_k,
config_dtype=config_dtype,
)

config = get_config_func(M)

sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
expert_ids_lora = moe_state_dict["expert_ids_lora"]
num_tokens_post_padded_lora = moe_state_dict[
Expand All @@ -197,7 +248,8 @@ def wrapper(*args, **kwargs):
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
config,
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
True,
)
11 changes: 10 additions & 1 deletion vllm/lora/ops/triton_ops/README_TUNING.md
@@ -44,8 +44,17 @@ For `shrink`, the config file is named as `{gpu_name}_SHRINK.json`, e.g. `NVIDIA

For `expand`, the config file is named as `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.

For `fused_moe_lora_w13_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json`.

For `fused_moe_lora_w13_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json`.

For `fused_moe_lora_w2_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json`.

For `fused_moe_lora_w2_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json`.

The `gpu_name` can be detected automatically by calling `torch.cuda.get_device_name()`.
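
For example, the expected file name could be built like this (a sketch; the assumption that spaces in the device name map to underscores follows from the example file names above):

```python
import torch

# e.g. "NVIDIA H200" -> "NVIDIA_H200" (assumed normalization)
gpu_name = torch.cuda.get_device_name().replace(" ", "_")
config_file = f"{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json"
```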

### Json Structure

Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]`
Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]`
where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer.
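
For illustration, a `fused_moe_lora` config file with this structure might look as follows once loaded (all keys and values below are hypothetical, not tuned results):

```python
config_data = {
    "1": {                          # max_loras
        "2": {                      # num_slices
            "256": {                # m
                "4096": {           # k
                    "16": {         # n
                        "1408": {   # i: MoE intermediate size (optional level)
                            "BLOCK_SIZE_M": 64,
                            "BLOCK_SIZE_N": 32,
                            "BLOCK_SIZE_K": 64,
                            "GROUP_SIZE_M": 1,
                            "NUM_WARPS": 4,
                            "NUM_STAGES": 3,
                            "SPLIT_K": 1,
                        }
                    }
                }
            }
        }
    }
}
```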
9 changes: 8 additions & 1 deletion vllm/lora/ops/triton_ops/__init__.py
@@ -1,7 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora

from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
fused_moe_lora,
fused_moe_lora_expand,
fused_moe_lora_shrink,
)
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
@@ -11,4 +16,6 @@
"lora_shrink",
"LoRAKernelMeta",
"fused_moe_lora",
"fused_moe_lora_shrink",
"fused_moe_lora_expand",
]
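
With these exports in place, the split kernels can be imported from the package alongside the existing fused op (a usage note, not part of the diff):

```python
from vllm.lora.ops.triton_ops import (
    fused_moe_lora,
    fused_moe_lora_expand,
    fused_moe_lora_shrink,
)
```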