7 changes: 5 additions & 2 deletions csrc/moe/marlin_moe_ops.cu
@@ -1749,6 +1749,9 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
                        has_act_order, is_k_full, max_shared_mem);
   }

+  int group_tensor_size =
+      (!is_k_full && has_act_order) ? prob_k / num_groups : group_size;
+
   TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
                   is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
                                   prob_m, prob_n, prob_k, num_bits, group_size,
@@ -1826,8 +1829,8 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
     const int* sorted_ids_ptr = (const int*)sorted_ids;
     const int4* s_ptr =
         (const int4*)s +
-        (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
-         prob_n / 8) *
+        ((group_tensor_size == -1 ? 1 : prob_k / group_tensor_size) * prob_n /
+         8) *
             expert_idx;
     const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
     const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
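Note on the kernel change: with act_order enabled on a partial K (`is_k_full == false`), the scales tensor holds `num_groups` rows regardless of the nominal `group_size`, so the per-expert scales offset must be derived from the tensor shape. A minimal Python mirror of the pointer arithmetic above, for illustration only (names follow the kernel arguments; this is not a vLLM API):

```python
def scales_offset_int4(prob_k: int, prob_n: int, group_size: int,
                       num_groups: int, has_act_order: bool,
                       is_k_full: bool, expert_idx: int) -> int:
    # Effective group span: with act_order on a partial K, derive it from the
    # number of scale rows actually present, not the nominal group_size.
    if not is_k_full and has_act_order:
        group_tensor_size = prob_k // num_groups
    else:
        group_tensor_size = group_size
    # -1 means channelwise quantization: a single row of scales per expert.
    rows = 1 if group_tensor_size == -1 else prob_k // group_tensor_size
    # Each row holds prob_n fp16 scales, packed 8 per int4 (16 bytes).
    return rows * prob_n // 8 * expert_idx
```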
32 changes: 23 additions & 9 deletions tests/kernels/test_moe.py
@@ -142,6 +142,7 @@ def compute_max_diff(output, output_ref):
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("num_bits", [4, 8])
+@pytest.mark.parametrize("is_k_full", [True, False])
 def test_fused_marlin_moe(
     m: int,
     n: int,
@@ -151,6 +152,7 @@ def test_fused_marlin_moe(
     group_size: int,
     act_order: bool,
     num_bits: int,
+    is_k_full: bool,
 ):
     seed_everything(7)

@@ -163,6 +165,9 @@ def test_fused_marlin_moe(
             return
         if group_size in (k, n):
             return
+    else:
+        if not is_k_full:
+            return

     quant_type = (scalar_types.uint4b8
                   if num_bits == 4 else scalar_types.uint8b128)
@@ -243,6 +248,7 @@ def test_fused_marlin_moe(
         w1_scale=scales1,
         w2_scale=scales2,
         num_bits=num_bits,
+        is_k_full=is_k_full,
     )

     assert compute_max_diff(marlin_output, triton_output) < 4e-2
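The new skip logic encodes one constraint: `is_k_full=False` is only a meaningful configuration when `act_order=True`, because a partial K only arises on the activation-reordering path. Restated as a standalone predicate (a hypothetical helper, not part of the PR; the single-expert test below applies the same rule but only filters `group_size` in `(-1, k)`):

```python
def is_valid_combination(act_order: bool, is_k_full: bool,
                         group_size: int, k: int, n: int) -> bool:
    if act_order:
        # act_order needs finite groups that do not span a whole dimension.
        return group_size not in (-1, k, n)
    # Without act_order, K is always full, so only is_k_full=True is tested.
    return is_k_full
```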
@@ -258,6 +264,7 @@ def test_fused_marlin_moe(
 @pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
 @pytest.mark.parametrize("act_order", [True, False])
 @pytest.mark.parametrize("num_bits", [4, 8])
+@pytest.mark.parametrize("is_k_full", [True, False])
 def test_single_marlin_moe_multiply(
     m: int,
     n: int,
@@ -267,6 +274,7 @@ def test_single_marlin_moe_multiply(
     group_size: int,
     act_order: bool,
     num_bits: int,
+    is_k_full: bool,
 ):
     if topk > e:
         return
@@ -277,6 +285,9 @@ def test_single_marlin_moe_multiply(
             return
         if group_size == k:
             return
+    else:
+        if not is_k_full:
+            return

     quant_type = (scalar_types.uint4b8
                   if num_bits == 4 else scalar_types.uint8b128)
@@ -307,15 +318,18 @@ def test_single_marlin_moe_multiply(
     sort_indices = stack_and_dev(sort_indices_l)

     score = torch.randn((m, e), device="cuda", dtype=dtype)
-    marlin_output = single_marlin_moe(a,
-                                      qweight,
-                                      scales,
-                                      score,
-                                      g_idx,
-                                      sort_indices,
-                                      topk,
-                                      renormalize=False,
-                                      num_bits=num_bits)
+    marlin_output = single_marlin_moe(
+        a,
+        qweight,
+        scales,
+        score,
+        g_idx,
+        sort_indices,
+        topk,
+        renormalize=False,
+        num_bits=num_bits,
+        is_k_full=is_k_full,
+    )
     torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

     assert compute_max_diff(marlin_output, torch_output) < 1e-2
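For background on where callers would get this flag: `is_k_full` mirrors what vLLM's non-MoE GPTQ-Marlin path derives from the checkpoint's act_order setting plus tensor-parallel sharding. A sketch of that derivation, stated here as an assumption since it is outside this diff:

```python
def compute_is_k_full(desc_act: bool, input_size_per_partition: int,
                      input_size: int) -> bool:
    # K is "full" unless act_order (desc_act) is enabled and this rank holds
    # only a shard of the input dimension (tensor-parallel split along K).
    return (not desc_act) or (input_size_per_partition == input_size)
```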
8 changes: 5 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -21,6 +21,7 @@ def single_marlin_moe(
     renormalize: bool,
     override_config: Optional[Dict[str, Any]] = None,
     num_bits: int = 8,
+    is_k_full: bool = True,
 ) -> torch.Tensor:
     """
     This function computes the multiplication of hidden_states with expert
@@ -86,7 +87,7 @@ def single_marlin_moe(

     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk,
+        g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk,
         block_size_m, True, False)

     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@@ -107,6 +108,7 @@ def fused_marlin_moe(
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     num_bits: int = 8,
+    is_k_full: bool = True,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -199,7 +201,7 @@ def fused_marlin_moe(
         M,
         2 * N,
         K,
-        True,
+        is_k_full,
         E,
         topk,
         block_size_m,
@@ -223,7 +225,7 @@ def fused_marlin_moe(
         M,
         K,
         N,
-        True,
+        is_k_full,
         E,
         topk,
         block_size_m,
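With both the kernel and the Python wrappers threaded through, a caller can propagate a partial-K setup end to end. A hedged usage sketch; the argument layout mirrors the updated test call in tests/kernels/test_moe.py, and all input tensors (qweight*, scales*, g_idx*, sort_indices*, topk_*) are assumed to come from Marlin MoE quantization and top-k routing as in that test:

```python
marlin_output = fused_marlin_moe(
    a,              # [m, k] fp16 activations
    qweight1,       # Marlin-packed w1 weights, stacked per expert
    qweight2,       # Marlin-packed w2 weights, stacked per expert
    score,          # [m, e] router logits
    g_idx1,
    g_idx2,
    sort_indices1,
    sort_indices2,
    topk_weights,
    topk_ids,
    w1_scale=scales1,
    w2_scale=scales2,
    num_bits=4,
    is_k_full=False,  # act_order weights on a rank holding only a K-shard
)
```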