Skip to content

Commit 1151fd1

Browse files
xiangze-arm authored and juliendenize committed
[CPU]Improve dynamic 4bit moe performance (vllm-project#27240)
Signed-off-by: Zhang Xiangze <[email protected]>
1 parent d2f65d7 commit 1151fd1

File tree

1 file changed: 12 additions, 21 deletions

1 file changed

+12
-21
lines changed

csrc/moe/dynamic_4bit_int_moe_cpu.cpp

Lines changed: 12 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -87,30 +87,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
8787
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
8888
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
8989

90-
// Per-expert outputs filled in parallel
91-
std::vector<torch::Tensor> y_list(E);
92-
y_list.resize(E);
90+
auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
91+
if (apply_router_weight_on_input) {
92+
X_all = X_all.mul(expert_gates.unsqueeze(1));
93+
}
94+
auto Y_all = at::empty({offsets[E], H}, x_c.options());
9395

9496
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
97+
c10::InferenceMode guard;
9598
for (int64_t e = e_begin; e < e_end; ++e) {
9699
const int64_t te = counts[e];
97100
if (te == 0) {
98-
y_list[e] = at::empty({0, H}, x_c.options());
99101
continue;
100102
}
101103

102104
const int64_t start = offsets[e];
103105

104-
auto sel_tokens =
105-
expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
106-
auto gates_e =
107-
expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
108-
109-
auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
110-
111-
if (apply_router_weight_on_input) {
112-
x_e = x_e.mul(gates_e.unsqueeze(1));
113-
}
106+
auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
114107

115108
auto w13_e = w13_packed.select(/*dim=*/0, e);
116109
auto w2_e = w2_packed.select(/*dim=*/0, e);
@@ -137,17 +130,15 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
137130
// W2
138131
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
139132

140-
if (!apply_router_weight_on_input) {
141-
y = y.mul(gates_e.unsqueeze(1));
142-
}
143-
144133
// Store per-expert result
145-
y_list[e] = y;
134+
Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
146135
}
147136
});
148137

149-
// Concatenate all expert outputs to match expert_tokens order
150-
auto Y_all = at::cat(y_list, /*dim=*/0);
138+
if (!apply_router_weight_on_input) {
139+
Y_all = Y_all.mul(expert_gates.unsqueeze(1));
140+
}
141+
151142
auto out = at::zeros({T, H}, x.options());
152143
out =
153144
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);

0 commit comments

Comments
 (0)