File tree Expand file tree Collapse file tree 1 file changed +4
-5
lines changed
Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -89,15 +89,14 @@ def fused_experts_with_mc2(
8989 0 :5 ]
9090
9191 w1 = w1 .transpose (1 , 2 )
92- expert_token_nums = torch .cumsum (expert_token_nums ,
93- dim = 0 ,
94- dtype = torch .int64 )
92+
9593 group_list = expert_token_nums .to (torch .int64 )
9694 gate_up_out_list = torch_npu .npu_grouped_matmul (
9795 x = [expand_x ],
9896 weight = [w1 ],
9997 split_item = 2 ,
100- group_list_type = 0 ,
98+ # 1 means count mode, to avoid cumulative operation of the group list
99+ group_list_type = 1 ,
101100 group_type = 0 ,
102101 group_list = group_list ,
103102 )
@@ -111,7 +110,7 @@ def fused_experts_with_mc2(
111110 x = [gate_up_out ],
112111 weight = [w2 ],
113112 split_item = 2 ,
114- group_list_type = 0 ,
113+ group_list_type = 1 ,
115114 group_type = 0 ,
116115 group_list = group_list ,
117116 )
You can’t perform that action at this time.
0 commit comments