
Commit 6965ef4

Varun Sundar Rabindranath authored
[Performance][DeepGEMM] Estimate expected_m (#28694)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent c9e6658 commit 6965ef4

File tree

3 files changed, +73 -17 lines changed

tests/kernels/moe/test_deepep_deepgemm_moe.py

Lines changed: 34 additions & 12 deletions
@@ -7,13 +7,15 @@
 """
 
 import dataclasses
+from contextlib import contextmanager
 
 import pytest
 import torch.distributed
 from torch.distributed import ProcessGroup
 from typing_extensions import ParamSpec
 
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
@@ -61,6 +63,23 @@
 P = ParamSpec("P")
 
 
+@contextmanager
+def with_dp_metadata(M: int, world_size: int):
+    num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)
+
+    vllm_config = VllmConfig()
+    vllm_config.parallel_config.data_parallel_size = world_size
+    vllm_config.parallel_config.enable_expert_parallel = True
+
+    with set_forward_context(
+        None,
+        vllm_config,
+        num_tokens=M,
+        num_tokens_across_dp=num_tokens_across_dp,
+    ):
+        yield
+
+
 def next_power_of_2(x):
     import math
 
@@ -285,18 +304,21 @@ def build_expert_map():
         quant_config=quant_config,
     )
 
-    out = mk.forward(
-        hidden_states=test_tensors.rank_tokens,
-        w1=w1,
-        w2=w2,
-        topk_weights=test_tensors.topk_weights,
-        topk_ids=test_tensors.topk,
-        inplace=False,
-        activation="silu",
-        global_num_experts=num_experts,
-        expert_map=build_expert_map(),
-        apply_router_weight_on_input=False,
-    )
+    with with_dp_metadata(
+        M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
+    ):
+        out = mk.forward(
+            hidden_states=test_tensors.rank_tokens,
+            w1=w1,
+            w2=w2,
+            topk_weights=test_tensors.topk_weights,
+            topk_ids=test_tensors.topk,
+            inplace=False,
+            activation="silu",
+            global_num_experts=num_experts,
+            expert_map=build_expert_map(),
+            apply_router_weight_on_input=False,
+        )
     return out

vllm/forward_context.py

Lines changed: 4 additions & 0 deletions
@@ -221,6 +221,10 @@ def get_forward_context() -> ForwardContext:
     return _forward_context
 
 
+def is_forward_context_available() -> bool:
+    return _forward_context is not None
+
+
 def create_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,

vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

Lines changed: 35 additions & 5 deletions
@@ -5,6 +5,7 @@
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -19,7 +20,7 @@
     get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
 )
-from vllm.utils.math_utils import cdiv
+from vllm.utils.math_utils import cdiv, round_up
 
 logger = init_logger(__name__)
 
@@ -313,6 +314,33 @@ def workspace_shapes(
         output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output)
 
+    def estimate_expected_m(
+        self, global_num_experts: int, max_tokens_per_expert: int, topk: int
+    ) -> int:
+        dp_meta = (
+            get_forward_context().dp_metadata
+            if is_forward_context_available()
+            else None
+        )
+        if dp_meta is None:
+            logger.warning_once(
+                "DPMetadata unavailable. Defaulting expected_m to "
+                f"{max_tokens_per_expert}.",
+                scope="local",
+            )
+            return max_tokens_per_expert
+
+        total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
+        total_num_tokens_replicated = total_num_tokens * topk
+
+        # Assume even load balancing
+        assert global_num_experts != 0
+        estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
+        # clamp estimate
+        estimate = max(estimate, 16)
+        estimate = min(max_tokens_per_expert, estimate)
+        return estimate
+
     def apply(
         self,
         output: torch.Tensor,
@@ -348,10 +376,12 @@ def apply(
 
         workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
 
-        # (from deepgemm docs) : A value hint (which is a value on CPU)
-        # for the M expectation of each batch, correctly setting this value
-        # may lead to better performance.
-        expected_m = max_num_tokens
+        expected_m = self.estimate_expected_m(
+            global_num_experts=global_num_experts,
+            max_tokens_per_expert=max_num_tokens,
+            topk=topk_ids.size(-1),
+        )
+
         fp8_m_grouped_gemm_nt_masked(
             (a1q, a1q_scale),
             (w1, self.w1_scale),
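
For intuition, the estimation heuristic above can be exercised in isolation. The sketch below is illustrative only: the world size, per-rank token counts, topk, expert count, and buffer size are made-up values, and round_up is re-implemented locally as a stand-in for vllm.utils.math_utils.round_up.

# Standalone sketch of the expected_m estimate with hypothetical numbers.
def round_up(x: int, multiple: int) -> int:
    # Local stand-in for vllm.utils.math_utils.round_up.
    return ((x + multiple - 1) // multiple) * multiple

# Hypothetical setup: 4 DP ranks with 48 tokens each, topk=8 routing,
# 256 global experts, and room for 128 tokens per expert.
num_tokens_across_dp = [48, 48, 48, 48]
topk = 8
global_num_experts = 256
max_tokens_per_expert = 128

total_num_tokens = sum(num_tokens_across_dp)               # 192 tokens overall
total_replicated = total_num_tokens * topk                 # 1536 routed copies
# Assume even load balancing across experts, then align up to 16.
estimate = round_up(total_replicated // global_num_experts, 16)  # 6 -> 16
estimate = max(estimate, 16)                               # lower clamp
estimate = min(max_tokens_per_expert, estimate)            # upper clamp
print(estimate)  # 16, versus the previous fixed hint of max_num_tokens (128 here)

With DP metadata available, the hint tracks the actual per-expert load instead of always assuming the worst case, which is the performance win this commit targets.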
