@@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
8787 const int64_t g_eff_13 = (group_size != -1 ) ? group_size : H;
8888 const int64_t g_eff_2 = (group_size != -1 ) ? group_size : I;
8989
90- // Per-expert outputs filled in parallel
91- std::vector<torch::Tensor> y_list (E);
92- y_list.resize (E);
90+ auto X_all = x_c.index_select (/* dim=*/ 0 , expert_tokens);
91+ if (apply_router_weight_on_input) {
92+ X_all = X_all.mul (expert_gates.unsqueeze (1 ));
93+ }
94+ auto Y_all = at::empty ({offsets[E], H}, x_c.options ());
9395
9496 at::parallel_for (0 , E, 1 , [&](int64_t e_begin, int64_t e_end) {
97+ c10::InferenceMode guard;
9598 for (int64_t e = e_begin; e < e_end; ++e) {
9699 const int64_t te = counts[e];
97100 if (te == 0 ) {
98- y_list[e] = at::empty ({0 , H}, x_c.options ());
99101 continue ;
100102 }
101103
102104 const int64_t start = offsets[e];
103105
104- auto sel_tokens =
105- expert_tokens.narrow (/* dim=*/ 0 , /* start=*/ start, /* length=*/ te);
106- auto gates_e =
107- expert_gates.narrow (/* dim=*/ 0 , /* start=*/ start, /* length=*/ te);
108-
109- auto x_e = x_c.index_select (/* dim=*/ 0 , sel_tokens);
110-
111- if (apply_router_weight_on_input) {
112- x_e = x_e.mul (gates_e.unsqueeze (1 ));
113- }
106+ auto x_e = X_all.narrow (/* dim=*/ 0 , /* start=*/ start, /* length=*/ te);
114107
115108 auto w13_e = w13_packed.select (/* dim=*/ 0 , e);
116109 auto w2_e = w2_packed.select (/* dim=*/ 0 , e);
@@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
137130 // W2
138131 auto y = mm (act, w2_e, g_eff_2, /* in_features=*/ I, /* out_features=*/ H);
139132
140- if (!apply_router_weight_on_input) {
141- y = y.mul (gates_e.unsqueeze (1 ));
142- }
143-
144133 // Store per-expert result
145- y_list[e] = y ;
134+ Y_all. narrow ( /* dim= */ 0 , /* start= */ start, /* length= */ te). copy_ (y) ;
146135 }
147136 });
148137
149- // Concatenate all expert outputs to match expert_tokens order
150- auto Y_all = at::cat (y_list, /* dim=*/ 0 );
138+ if (!apply_router_weight_on_input) {
139+ Y_all = Y_all.mul (expert_gates.unsqueeze (1 ));
140+ }
141+
151142 auto out = at::zeros ({T, H}, x.options ());
152143 out =
153144 at::index_add (out, /* dim=*/ 0 , /* index=*/ expert_tokens, /* source=*/ Y_all);
0 commit comments