47 changes: 20 additions & 27 deletions csrc/moe/dynamic_4bit_int_moe_cpu.cpp
@@ -21,6 +21,8 @@ inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
 #endif
 }
 
+extern void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
+
 enum ActivationKind : int64_t {
   SwiGLU_Gu = 0,  // act = SiLU(g) * u
   SwiGLUOAI = 1,  // act = SiLU(u) * g
@@ -87,30 +89,23 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
   const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
   const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
 
-  // Per-expert outputs filled in parallel
-  std::vector<torch::Tensor> y_list(E);
-  y_list.resize(E);
+  auto X_all = x_c.index_select(/*dim=*/0, expert_tokens);
+  if (apply_router_weight_on_input) {
+    X_all = X_all.mul(expert_gates.unsqueeze(1));
+  }
+  auto Y_all = at::empty({offsets[E], H}, x_c.options());
 
   at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
     c10::InferenceMode guard;
     for (int64_t e = e_begin; e < e_end; ++e) {
       const int64_t te = counts[e];
       if (te == 0) {
-        y_list[e] = at::empty({0, H}, x_c.options());
         continue;
       }
 
       const int64_t start = offsets[e];
 
-      auto sel_tokens =
-          expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
-      auto gates_e =
-          expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
-
-      auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
-
-      if (apply_router_weight_on_input) {
-        x_e = x_e.mul(gates_e.unsqueeze(1));
-      }
+      auto x_e = X_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
 
       auto w13_e = w13_packed.select(/*dim=*/0, e);
       auto w2_e = w2_packed.select(/*dim=*/0, e);
@@ -119,35 +114,33 @@ torch::Tensor dynamic_4bit_int_moe_cpu(
       auto y13 =
          mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);
 
-      auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
-      auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
-
       torch::Tensor act;
       if (activation_kind == ActivationKind::SwiGLUOAI) {  // SwiGLUOAI
-        constexpr double kAlpha = 1.702; // GPT-OSS default
-        constexpr double kLimit = 7.0; // GPT-OSS default
+        auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
+        auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
+        constexpr double kAlpha = 1.702;  // GPT-OSS default
+        constexpr double kLimit = 7.0;    // GPT-OSS default
        auto gate_c = at::clamp_max(g_part, kLimit);
        auto up_c = at::clamp(u_part, -kLimit, kLimit);
        auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
        act = up_c.add(1.0).mul(glu);
      } else {  // SiLU, SwiGLU_Gu; vLLM maps silu to SiluAndMul()
-        act = at::silu(g_part).mul(u_part);
+        act = at::empty({te, I}, y13.options());
+        silu_and_mul(act, y13);
      }
 
      // W2
      auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
 
-      if (!apply_router_weight_on_input) {
-        y = y.mul(gates_e.unsqueeze(1));
-      }
-
-      // Store per-expert result
-      y_list[e] = y;
+      Y_all.narrow(/*dim=*/0, /*start=*/start, /*length=*/te).copy_(y);
     }
   });
 
-  // Concatenate all expert outputs to match expert_tokens order
-  auto Y_all = at::cat(y_list, /*dim=*/0);
+  if (!apply_router_weight_on_input) {
+    Y_all = Y_all.mul(expert_gates.unsqueeze(1));
+  }
 
  auto out = at::zeros({T, H}, x.options());
  out =
      at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
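For context, a minimal standalone sketch of the data flow the new version uses: gather the routed tokens once, run each expert on its contiguous slice, write the result into a preallocated buffer, apply the router weights once, then scatter-add back into token order. This is illustrative only — the function name, the dense `matmul` stand-in for the packed 4-bit `mm`, the plain `torch::silu` in place of the fused `silu_and_mul` kernel, and the assumed shapes of `expert_tokens` / `offsets` are assumptions, not the file's actual API.

```cpp
#include <torch/torch.h>
#include <vector>

// Sketch (hypothetical signature): dense stand-in for the quantized MoE path.
torch::Tensor moe_forward_sketch(
    const torch::Tensor& x,                 // [T, H] input tokens
    const torch::Tensor& expert_tokens,     // [N] token indices, grouped by expert
    const torch::Tensor& expert_gates,      // [N] router weights, same order
    const std::vector<int64_t>& offsets,    // [E + 1] segment boundaries per expert
    const std::vector<torch::Tensor>& w13,  // per-expert [H, 2*I] (dense stand-in)
    const std::vector<torch::Tensor>& w2) { // per-expert [I, H]  (dense stand-in)
  const int64_t T = x.size(0);
  const int64_t H = x.size(1);
  const int64_t E = static_cast<int64_t>(w13.size());

  auto X_all = x.index_select(/*dim=*/0, expert_tokens);    // gather once
  auto Y_all = torch::empty({offsets[E], H}, x.options());  // preallocated output

  for (int64_t e = 0; e < E; ++e) {
    const int64_t start = offsets[e];
    const int64_t te = offsets[e + 1] - offsets[e];
    if (te == 0) continue;

    auto x_e = X_all.narrow(/*dim=*/0, start, te);  // [te, H] contiguous slice
    auto y13 = x_e.matmul(w13[e]);                  // [te, 2*I] gate/up projections
    const int64_t I = y13.size(1) / 2;
    auto act = torch::silu(y13.narrow(1, 0, I)).mul(y13.narrow(1, I, I));
    auto y = act.matmul(w2[e]);                     // [te, H] down projection
    Y_all.narrow(/*dim=*/0, start, te).copy_(y);    // write into this expert's slot
  }

  Y_all = Y_all.mul(expert_gates.unsqueeze(1));     // apply router weights once
  auto out = torch::zeros({T, H}, x.options());
  return out.index_add(/*dim=*/0, expert_tokens, Y_all);  // back to token order
}
```

Because each expert writes to a disjoint `narrow` of `Y_all`, the loop body has no shared mutable state and can run under `at::parallel_for` without a final `at::cat`.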