Skip to content

Commit abdea90

Browse files
committed
update non-fused MQA attention
1 parent ab4fc91 commit abdea90

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

tutel/custom/custom_kernel.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,9 +1188,15 @@ torch::Tensor warp_deepseek_r1_attn_f16xf8_block_scal(
11881188
key_cache.narrow(0, pos, 1).copy_(std::get<1>(inputs).permute({1, 0, 2}));
11891189

11901190
auto Q = std::get<0>(inputs), C = key_cache.narrow(0, 0, pos + seqlen); // S2, B, (512 + 64)
1191-
auto scores = at::einsum("bshc,tbc->bsht", {Q, C}) * 0.1352337788608801f;
1192-
Q = at::einsum("bsht,tbc->bshc", {at::softmax(scores, -1), C}).narrow(-1, 0, 512);
1193-
Q = at::einsum("bshc,hdc->bshd", {Q, wvc}).contiguous();
1191+
if (batch == 1 && seqlen == 1) {
1192+
auto scores = antares::ops::call("scores_bf16", {Q.view({n_heads, Q.size(-1)}).view(torch::kInt32), C.squeeze(1).view(torch::kInt32)}, {0.1352337788608801f});
1193+
Q = torch::matmul(at::softmax(scores, -1), C.squeeze(1));
1194+
Q = antares::ops::call("logits_bf16", {Q.view(torch::kInt32), wvc.view(torch::kInt32)}, {});
1195+
} else {
1196+
auto scores = at::einsum("bshc,tbc->bsht", {Q, C}) * 0.1352337788608801f;
1197+
Q = at::einsum("bsht,tbc->bshc", {at::softmax(scores, -1), C}).narrow(-1, 0, 512);
1198+
Q = at::einsum("bshc,hdc->bshd", {Q, wvc}).contiguous();
1199+
}
11941200
Q = warp_gemm_nt_bf16xfp8_block_scal(Q.view({batch, seqlen, -1}), o_proj, o_proj_scal);
11951201
return Q;
11961202
}

0 commit comments

Comments
 (0)