Skip to content

Commit 428291e

Browse files
[Feature] Add TritonBF16MoEMethod for BF16 MoE inference
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 78b5462 commit 428291e

4 files changed

Lines changed: 364 additions & 2 deletions

File tree

fastdeploy/model_executor/layers/moe/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
CutlassW4AFP8MoEMethod,
1818
CutlassWeightOnlyMoEMethod,
1919
)
20-
from .fused_moe_triton_backend import TritonWeightOnlyMoEMethod
20+
from .fused_moe_triton_backend import TritonBF16MoEMethod, TritonWeightOnlyMoEMethod
2121
from .moe import FusedMoE
2222

2323
__all__ = [
@@ -26,4 +26,5 @@
2626
CutlassW4AFP8MoEMethod,
2727
FusedMoE,
2828
TritonWeightOnlyMoEMethod,
29+
TritonBF16MoEMethod,
2930
]

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 239 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
from ..quantization.quant_base import QuantMethodBase
3636

3737
try:
38+
import triton.language as tl
3839
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func
3940

40-
from .triton_moe_kernels import fused_moe_kernel_paddle
41+
from .triton_moe_kernels import fused_moe_kernel_bf16, fused_moe_kernel_paddle
4142
except ImportError:
4243
pass
4344
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
@@ -1885,3 +1886,240 @@ def apply(
18851886
fc1_latent_proj,
18861887
fc2_latent_proj,
18871888
)
1889+
1890+
1891+
class TritonBF16MoEMethod(QuantMethodBase):
    """
    Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE.

    Activated via: export FD_MOE_BACKEND=triton
    Weight layout (CUDA path): [E, K, 2N] for up_gate_proj, [E, N, K] for down_proj.
    This matches UnquantizedFusedMoEMethod.create_weights layout on CUDA.
    """

    # Class-level flag: emit the "Triton BF16 MoE activated" warning only once
    # per process, no matter how many layers instantiate this method.
    _logged = False

    def __init__(self, quant_config=None):
        """
        Args:
            quant_config: unused for the BF16 (unquantized) path; kept so the
                constructor signature matches the other MoE quant methods.
        """
        self.quant_config = quant_config
        self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
        if not TritonBF16MoEMethod._logged:
            import logging

            # Fixed: the previous message claimed fused_moe_kernel_paddle was
            # used, but this backend launches fused_moe_kernel_bf16 (see apply()).
            logging.getLogger(__name__).warning(
                "[TritonBF16MoEMethod] Triton BF16 MoE backend is ACTIVE "
                "(FD_MOE_BACKEND=triton). Using fused_moe_kernel_bf16 for BF16."
            )
            TritonBF16MoEMethod._logged = True

    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
        # BF16 path has no pre-quantized weight format; nothing to do.
        pass

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        """
        Reuse UnquantizedFusedMoEMethod weight creation logic.
        Weight shapes on CUDA (non-torch format):
            up_gate_proj_weight: [E, hidden_size, moe_intermediate_size * 2]  (K-major)
            down_proj_weight:    [E, moe_intermediate_size, hidden_size]      (K-major)
        The Triton kernel reads B as [E, K, N] which maps directly to these shapes.
        """
        # Local import avoids a circular import at module load time.
        from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import UnquantizedFusedMoEMethod

        # Call unbound so we can borrow the implementation without inheriting.
        UnquantizedFusedMoEMethod.create_weights(self, layer, **extra_weight_attrs)

    def process_weights_after_loading(self, layer: nn.Layer):
        """Delegate post-load processing to the unquantized method (shared layout)."""
        from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import UnquantizedFusedMoEMethod

        UnquantizedFusedMoEMethod.process_weights_after_loading(self, layer)

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        """Stack individual expert weights into the stacked parameter."""
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0))
        layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0))

    def _get_default_config(self, M: int, N: int, K: int) -> dict:
        """
        Heuristic tile config for BF16 MoE, mirroring vLLM's get_default_config logic.

        Args:
            M: number of token-expert pairs BEFORE expert padding
               (callers pass token_num * top_k); used only to pick tile sizes.
            N: output dimension of the GEMM.
            K: input dimension of the GEMM.

        Returns:
            dict with BLOCK_SIZE_M/N/K and GROUP_SIZE_M for the Triton launch.
        """
        if M <= 32:
            # Decode-like tiny batches: small M tile to limit padding waste.
            block_m, block_n, block_k = 16, 64, 64
        elif M <= 512:
            block_m, block_n, block_k = 32, 128, 64
        else:
            # Prefill-like large batches: bigger tiles for throughput.
            block_m, block_n, block_k = 128, 128, 64
        return {
            "BLOCK_SIZE_M": block_m,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": 8,
        }

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
        shared_experts: nn.Layer = None,
    ) -> paddle.Tensor:
        """
        BF16 Triton Fused MoE forward.

        Pipeline:
          1. Gate + topk routing
          2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded
          3. fused_moe_kernel_bf16 GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N]
          4. SwiGLU activation
          5. fused_moe_kernel_bf16 GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K]
             (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication)
          6. Reshape + sum over topk dim

        NOTE(review): `shared_experts` is accepted but not used here — confirm
        whether shared-expert output is expected to be fused by this backend.
        """
        import fastdeploy

        token_num = x.shape[0]
        if token_num == 0:
            # Empty batch: nothing to route; return an empty tensor of the right shape.
            return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype)

        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size

        # --- 1. Routing ---
        gate_out = gate(x)
        gate_out = gate_out.cast("float32")

        if layer.topk_method == "noaux_tc":
            from fastdeploy.model_executor.layers.moe.moe import get_moe_scores

            _, topk_weights, topk_ids = get_moe_scores(
                gate_out,
                layer.n_group,
                layer.topk_group,
                top_k,
                layer.routed_scaling_factor,
                layer.gate_correction_bias,
                getattr(layer, "renormalize", True),
                topk_reduce_func=getattr(layer, "topk_reduce_func", None),
            )
        else:
            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                gate_out,
                layer.gate_correction_bias,
                top_k,
                True,  # apply_norm_weight
                False,
            )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        # --- 2. Preprocess: sort tokens by expert assignment ---
        # Choose BLOCK_SIZE_M based on decode vs prefill heuristic
        num_token_expert_pairs = token_num * top_k
        cfg = self._get_default_config(num_token_expert_pairs, moe_intermediate_size * 2, hidden_size)

        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
            topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"]
        )
        max_possible_num_post_padded = sorted_token_ids.shape[0]

        # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) ---
        # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn
        up_gate_proj_out = paddle.empty(
            [num_token_expert_pairs, moe_intermediate_size * 2],
            dtype=x.dtype,
        )
        grid1 = (
            ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"])
            * ceil_div(moe_intermediate_size * 2, cfg["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_bf16[grid1](
            x,
            layer.up_gate_proj_weight,
            up_gate_proj_out,
            None,  # topk_weights_ptr (no weight mul on GEMM1)
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            N=moe_intermediate_size * 2,
            K=hidden_size,
            EM=max_possible_num_post_padded,
            num_valid_tokens=num_token_expert_pairs,
            stride_am=x.strides[0],
            stride_ak=x.strides[1],
            stride_be=layer.up_gate_proj_weight.strides[0],
            stride_bk=layer.up_gate_proj_weight.strides[1],
            stride_bn=layer.up_gate_proj_weight.strides[2],
            stride_cm=up_gate_proj_out.strides[0],
            stride_cn=up_gate_proj_out.strides[1],
            BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"],
            GROUP_SIZE_M=cfg["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type=tl.bfloat16,
        )

        # --- 4. SwiGLU activation ---
        down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)

        # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication ---
        # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn
        down_proj_out = paddle.empty(
            (num_token_expert_pairs, hidden_size),
            dtype=x.dtype,
        )
        cfg2 = self._get_default_config(num_token_expert_pairs, hidden_size, moe_intermediate_size)
        grid2 = (
            ceil_div(max_possible_num_post_padded, cfg2["BLOCK_SIZE_M"])
            * ceil_div(hidden_size, cfg2["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_bf16[grid2](
            down_proj_input,
            layer.down_proj_weight,
            down_proj_out,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            N=hidden_size,
            K=moe_intermediate_size,
            EM=max_possible_num_post_padded,
            num_valid_tokens=num_token_expert_pairs,
            stride_am=down_proj_input.strides[0],
            stride_ak=down_proj_input.strides[1],
            stride_be=layer.down_proj_weight.strides[0],
            stride_bk=layer.down_proj_weight.strides[1],
            stride_bn=layer.down_proj_weight.strides[2],
            stride_cm=down_proj_out.strides[0],
            stride_cn=down_proj_out.strides[1],
            BLOCK_SIZE_M=cfg2["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=cfg2["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=cfg2["BLOCK_SIZE_K"],
            GROUP_SIZE_M=cfg2["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,  # fuse router weight * output
            # top_k=1: down_proj_input rows are indexed directly by sorted_token_ids,
            # so a_ptrs = base + offs_token * stride_am (no // top_k needed).
            top_k=1,
            compute_type=tl.bfloat16,
        )

        # --- 6. Reduce over topk ---
        down_proj_out.reshape_([token_num, top_k, hidden_size])
        out = down_proj_out.sum(axis=1)
        return out

    def apply_ep_prefill(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None):
        raise NotImplementedError("TritonBF16MoEMethod does not support EP prefill yet.")

    def apply_ep_decode(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None):
        raise NotImplementedError("TritonBF16MoEMethod does not support EP decode yet.")

    def apply_tp(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None):
        # Fixed: forward shared_experts instead of silently dropping it, so the
        # TP entry point stays argument-compatible with the other apply_* paths.
        return self.apply(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts)

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ def get_moe_method(layer=None):
5454
"""
5555

5656
if current_platform.is_cuda():
57+
moe_backend = envs.FD_MOE_BACKEND.lower()
58+
if moe_backend == "triton":
59+
from paddleformers.utils.log import logger
60+
61+
from .fused_moe_triton_backend import TritonBF16MoEMethod
62+
63+
logger.info("[get_moe_method] FD_MOE_BACKEND=triton -> TritonBF16MoEMethod")
64+
return TritonBF16MoEMethod(None)
5765
from .fused_moe_cutlass_backend import CutlassMoEMethod
5866

5967
return CutlassMoEMethod(None)

0 commit comments

Comments
 (0)