|
35 | 35 | from ..quantization.quant_base import QuantMethodBase |
36 | 36 |
|
37 | 37 | try: |
| 38 | + import triton.language as tl |
38 | 39 | from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess_func |
39 | 40 |
|
40 | | - from .triton_moe_kernels import fused_moe_kernel_paddle |
| 41 | + from .triton_moe_kernels import fused_moe_kernel_bf16, fused_moe_kernel_paddle |
41 | 42 | except ImportError: |
42 | 43 | pass |
43 | 44 | from fastdeploy.model_executor.layers.moe.moe import get_moe_scores |
@@ -1885,3 +1886,240 @@ def apply( |
1885 | 1886 | fc1_latent_proj, |
1886 | 1887 | fc2_latent_proj, |
1887 | 1888 | ) |
| 1889 | + |
| 1890 | + |
class TritonBF16MoEMethod(QuantMethodBase):
    """
    Use Triton Group Gemm (BF16 unquantized) to compute Fused MoE.

    Activated via: export FD_MOE_BACKEND=triton
    Weight layout (CUDA path): [E, K, 2N] for up_gate_proj, [E, N, K] for down_proj.
    This matches UnquantizedFusedMoEMethod.create_weights layout on CUDA.
    """

    # Class-level flag: print the "Triton BF16 MoE activated" message only once.
    _logged = False

    def __init__(self, quant_config=None):
        self.quant_config = quant_config
        self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"]
        if not TritonBF16MoEMethod._logged:
            import logging

            # NOTE: message names fused_moe_kernel_bf16, the kernel actually
            # launched in apply() below (previously it wrongly claimed
            # fused_moe_kernel_paddle).
            logging.getLogger(__name__).warning(
                "[TritonBF16MoEMethod] Triton BF16 MoE backend is ACTIVE "
                "(FD_MOE_BACKEND=triton). Using fused_moe_kernel_bf16 for BF16."
            )
            TritonBF16MoEMethod._logged = True

    def process_prequanted_weights(self, layer: nn.Layer, state_dict, is_rearrange: bool = False) -> None:
        # BF16 path has no pre-quantized weights to process.
        pass

    def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
        """
        Reuse UnquantizedFusedMoEMethod weight creation logic.
        Weight shapes on CUDA (non-torch format):
            up_gate_proj_weight: [E, hidden_size, moe_intermediate_size * 2]  (K-major)
            down_proj_weight:    [E, moe_intermediate_size, hidden_size]      (K-major)
        The Triton kernel reads B as [E, K, N] which maps directly to these shapes.
        """
        from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import UnquantizedFusedMoEMethod

        UnquantizedFusedMoEMethod.create_weights(self, layer, **extra_weight_attrs)

    def process_weights_after_loading(self, layer: nn.Layer):
        # Delegate post-load processing (e.g. format conversion) to the unquantized path.
        from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import UnquantizedFusedMoEMethod

        UnquantizedFusedMoEMethod.process_weights_after_loading(self, layer)

    def process_loaded_weights(self, layer: nn.Layer, state_dict):
        """Stack individual expert weights into the stacked parameter."""
        up_gate_proj_weights, down_proj_weights, _, _ = layer.extract_moe_ffn_weights(state_dict)
        layer.up_gate_proj_weight.set_value(paddle.stack(up_gate_proj_weights, axis=0))
        layer.down_proj_weight.set_value(paddle.stack(down_proj_weights, axis=0))

    def _get_default_config(self, M: int, N: int, K: int) -> dict:
        """
        Heuristic tile config for BF16 MoE, mirroring vLLM's get_default_config logic.

        Args:
            M: number of token-expert pairs (token_num * top_k, before padding).
            N: output dimension of the GEMM.
            K: input dimension of the GEMM.

        Returns:
            dict with BLOCK_SIZE_M / BLOCK_SIZE_N / BLOCK_SIZE_K / GROUP_SIZE_M.
        """
        if M <= 32:
            # Small-batch decode: tiny M tile to limit padding waste.
            block_m, block_n, block_k = 16, 64, 64
        elif M <= 512:
            block_m, block_n, block_k = 32, 128, 64
        else:
            # Large prefill batches: bigger M tile for better compute density.
            block_m, block_n, block_k = 128, 128, 64
        return {
            "BLOCK_SIZE_M": block_m,
            "BLOCK_SIZE_N": block_n,
            "BLOCK_SIZE_K": block_k,
            "GROUP_SIZE_M": 8,
        }

    def apply(
        self,
        layer: nn.Layer,
        x: paddle.Tensor,
        gate: nn.Layer,
        topk_ids_hookfunc: Callable = None,
        shared_experts: nn.Layer = None,
    ) -> paddle.Tensor:
        """
        BF16 Triton Fused MoE forward.

        Pipeline:
            1. Gate + topk routing
            2. tritonmoe_preprocess -> sorted_token_ids, expert_ids, num_tokens_post_padded
            3. fused_moe_kernel_bf16 GEMM1: [tokens*topk, K] x [E, K, 2N] -> [tokens*topk, 2N]
            4. SwiGLU activation
            5. fused_moe_kernel_bf16 GEMM2: [tokens*topk, N] x [E, N, K] -> [tokens*topk, K]
               (with MUL_ROUTED_WEIGHT=True to fuse router weight multiplication)
            6. Reshape + sum over topk dim
        """
        import fastdeploy

        token_num = x.shape[0]
        if token_num == 0:
            # Nothing to route; return an empty tensor of the right shape/dtype.
            return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype)

        top_k = layer.top_k
        num_local_experts = layer.num_local_experts
        moe_intermediate_size = layer.moe_intermediate_size
        hidden_size = layer.hidden_size

        # --- 1. Routing ---
        gate_out = gate(x)
        gate_out = gate_out.cast("float32")

        if layer.topk_method == "noaux_tc":
            from fastdeploy.model_executor.layers.moe.moe import get_moe_scores

            _, topk_weights, topk_ids = get_moe_scores(
                gate_out,
                layer.n_group,
                layer.topk_group,
                top_k,
                layer.routed_scaling_factor,
                layer.gate_correction_bias,
                getattr(layer, "renormalize", True),
                topk_reduce_func=getattr(layer, "topk_reduce_func", None),
            )
        else:
            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
                gate_out,
                layer.gate_correction_bias,
                top_k,
                True,  # apply_norm_weight
                False,
            )

        if topk_ids_hookfunc is not None:
            topk_ids_hookfunc(topk_ids=topk_ids)

        # --- 2. Preprocess: sort tokens by expert assignment ---
        # Choose BLOCK_SIZE_M based on decode vs prefill heuristic
        num_token_expert_pairs = token_num * top_k
        cfg = self._get_default_config(num_token_expert_pairs, moe_intermediate_size * 2, hidden_size)

        sorted_token_ids, expert_ids, num_tokens_post_padded = tritonmoe_preprocess_func(
            topk_ids, num_local_experts, cfg["BLOCK_SIZE_M"]
        )
        max_possible_num_post_padded = sorted_token_ids.shape[0]

        # --- 3. GEMM1: hidden -> up_gate (BF16 x BF16 -> BF16) ---
        # up_gate_proj_weight layout: [E, hidden_size, inter*2] => stride_be, stride_bk, stride_bn
        up_gate_proj_out = paddle.empty(
            [num_token_expert_pairs, moe_intermediate_size * 2],
            dtype=x.dtype,
        )
        grid1 = (
            ceil_div(max_possible_num_post_padded, cfg["BLOCK_SIZE_M"])
            * ceil_div(moe_intermediate_size * 2, cfg["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_bf16[grid1](
            x,
            layer.up_gate_proj_weight,
            up_gate_proj_out,
            None,  # topk_weights_ptr (no weight mul on GEMM1)
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            N=moe_intermediate_size * 2,
            K=hidden_size,
            EM=max_possible_num_post_padded,
            num_valid_tokens=num_token_expert_pairs,
            stride_am=x.strides[0],
            stride_ak=x.strides[1],
            stride_be=layer.up_gate_proj_weight.strides[0],
            stride_bk=layer.up_gate_proj_weight.strides[1],
            stride_bn=layer.up_gate_proj_weight.strides[2],
            stride_cm=up_gate_proj_out.strides[0],
            stride_cn=up_gate_proj_out.strides[1],
            BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=cfg["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=cfg["BLOCK_SIZE_K"],
            GROUP_SIZE_M=cfg["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=False,
            top_k=top_k,
            compute_type=tl.bfloat16,
        )

        # --- 4. SwiGLU activation ---
        down_proj_input = paddle.incubate.nn.functional.swiglu(up_gate_proj_out)

        # --- 5. GEMM2: inter -> hidden, fuse router weight multiplication ---
        # down_proj_weight layout: [E, moe_intermediate_size, hidden_size] => stride_be, stride_bk, stride_bn
        down_proj_out = paddle.empty(
            (num_token_expert_pairs, hidden_size),
            dtype=x.dtype,
        )
        cfg2 = self._get_default_config(num_token_expert_pairs, hidden_size, moe_intermediate_size)
        grid2 = (
            ceil_div(max_possible_num_post_padded, cfg2["BLOCK_SIZE_M"])
            * ceil_div(hidden_size, cfg2["BLOCK_SIZE_N"]),
        )
        fused_moe_kernel_bf16[grid2](
            down_proj_input,
            layer.down_proj_weight,
            down_proj_out,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            N=hidden_size,
            K=moe_intermediate_size,
            EM=max_possible_num_post_padded,
            num_valid_tokens=num_token_expert_pairs,
            stride_am=down_proj_input.strides[0],
            stride_ak=down_proj_input.strides[1],
            stride_be=layer.down_proj_weight.strides[0],
            stride_bk=layer.down_proj_weight.strides[1],
            stride_bn=layer.down_proj_weight.strides[2],
            stride_cm=down_proj_out.strides[0],
            stride_cn=down_proj_out.strides[1],
            BLOCK_SIZE_M=cfg2["BLOCK_SIZE_M"],
            BLOCK_SIZE_N=cfg2["BLOCK_SIZE_N"],
            BLOCK_SIZE_K=cfg2["BLOCK_SIZE_K"],
            GROUP_SIZE_M=cfg2["GROUP_SIZE_M"],
            MUL_ROUTED_WEIGHT=True,  # fuse router weight * output
            # top_k=1: down_proj_input rows are indexed directly by sorted_token_ids,
            # so a_ptrs = base + offs_token * stride_am (no // top_k needed).
            top_k=1,
            compute_type=tl.bfloat16,
        )

        # --- 6. Reduce over topk ---
        down_proj_out.reshape_([token_num, top_k, hidden_size])
        out = down_proj_out.sum(axis=1)
        return out

    def apply_ep_prefill(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None):
        raise NotImplementedError("TritonBF16MoEMethod does not support EP prefill yet.")

    def apply_ep_decode(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None):
        raise NotImplementedError("TritonBF16MoEMethod does not support EP decode yet.")

    def apply_tp(self, layer, x, gate, topk_ids_hookfunc=None, shared_experts=None):
        # Forward shared_experts too (previously silently dropped); apply()
        # accepts it, keeping this wrapper consistent with the other apply_* signatures.
        return self.apply(layer, x, gate, topk_ids_hookfunc=topk_ids_hookfunc, shared_experts=shared_experts)
0 commit comments