Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX)

# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")

message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

Expand Down Expand Up @@ -681,6 +680,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_PERMUTE_SRC
"csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
"csrc/moe/moe_permute_unpermute_op.cu")

set_gencode_flags_for_srcs(
SRCS "${MARLIN_PERMUTE_SRC}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused by this - should it be MOE_PERMUTE_SRC?

Copy link
Contributor Author

@CalebDu CalebDu May 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line should be removed with bnell‘s review. I'll remove it.

CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")

list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
endif()
message(STATUS "Enabling moe extension.")
define_gpu_extension_target(
_moe_C
Expand All @@ -689,6 +699,8 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

Expand Down
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,

score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, renormalize=False)

def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def run():
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
if use_deep_gemm:
topk_weights, topk_ids = fused_topk(x, input_gating, topk,
False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, False)
return fused_experts(
x,
w1,
Expand Down
Loading