File tree Expand file tree Collapse file tree
fastdeploy/model_executor/layers/backends/metax/moe Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -97,6 +97,8 @@ def apply_tp(
9797 x : paddle .Tensor ,
9898 gate : nn .Layer ,
9999 topk_ids_hookfunc : Callable = None ,
100+ fc1_latent_proj : nn .Layer = None ,
101+ fc2_latent_proj : nn .Layer = None ,
100102 ) -> paddle .Tensor :
101103 """
102104 Paddle Cutlass compute Fused MoE.
@@ -150,13 +152,19 @@ def apply_tp(
150152 x : paddle .Tensor ,
151153 gate : nn .Layer ,
152154 topk_ids_hookfunc : Callable = None ,
155+ fc1_latent_proj : nn .Layer = None ,
156+ fc2_latent_proj : nn .Layer = None ,
153157 ) -> paddle .Tensor :
154158 """
155159 Paddle Cutlass compute Fused MoE.
156160 """
157- if layer .topk_method == "noaux_tc" :
158- gate_out = gate (x .cast ("float32" ))
159161
162+ gate_out = gate (x .cast ("float32" ))
163+
164+ if fc1_latent_proj is not None :
165+ x = fc1_latent_proj (x )
166+
167+ if layer .topk_method == "noaux_tc" :
160168 gate_out , topk_weights , topk_idx = get_moe_scores (
161169 gate_out ,
162170 layer .n_group ,
@@ -229,6 +237,9 @@ def apply_tp(
229237 False ,
230238 )
231239
240+ if fc2_latent_proj is not None :
241+ fused_moe_out = fc2_latent_proj (fused_moe_out )
242+
232243 return fused_moe_out
233244
234245
You can’t perform that action at this time.
0 commit comments