Skip to content

Commit ca35035

Browse files
Guanyu Chen (i26275) and Tryorish
authored and committed
[Metax][FIX] fix ci error caused by pr#7428
1 parent 1334520 commit ca35035

1 file changed

Lines changed: 13 additions & 2 deletions

File tree

fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def apply_tp(
9797
x: paddle.Tensor,
9898
gate: nn.Layer,
9999
topk_ids_hookfunc: Callable = None,
100+
fc1_latent_proj: nn.Layer = None,
101+
fc2_latent_proj: nn.Layer = None,
100102
) -> paddle.Tensor:
101103
"""
102104
Paddle Cutlass compute Fused MoE.
@@ -150,13 +152,19 @@ def apply_tp(
150152
x: paddle.Tensor,
151153
gate: nn.Layer,
152154
topk_ids_hookfunc: Callable = None,
155+
fc1_latent_proj: nn.Layer = None,
156+
fc2_latent_proj: nn.Layer = None,
153157
) -> paddle.Tensor:
154158
"""
155159
Paddle Cutlass compute Fused MoE.
156160
"""
157-
if layer.topk_method == "noaux_tc":
158-
gate_out = gate(x.cast("float32"))
159161

162+
gate_out = gate(x.cast("float32"))
163+
164+
if fc1_latent_proj is not None:
165+
x = fc1_latent_proj(x)
166+
167+
if layer.topk_method == "noaux_tc":
160168
gate_out, topk_weights, topk_idx = get_moe_scores(
161169
gate_out,
162170
layer.n_group,
@@ -229,6 +237,9 @@ def apply_tp(
229237
False,
230238
)
231239

240+
if fc2_latent_proj is not None:
241+
fused_moe_out = fc2_latent_proj(fused_moe_out)
242+
232243
return fused_moe_out
233244

234245

0 commit comments

Comments (0)