@@ -1188,9 +1188,15 @@ torch::Tensor warp_deepseek_r1_attn_f16xf8_block_scal(
11881188 key_cache.narrow (0 , pos, 1 ).copy_ (std::get<1 >(inputs).permute ({1 , 0 , 2 }));
11891189
11901190 auto Q = std::get<0 >(inputs), C = key_cache.narrow (0 , 0 , pos + seqlen); // S2, B, (512 + 64)
1191- auto scores = at::einsum (" bshc,tbc->bsht" , {Q, C}) * 0 .1352337788608801f ;
1192- Q = at::einsum (" bsht,tbc->bshc" , {at::softmax (scores, -1 ), C}).narrow (-1 , 0 , 512 );
1193- Q = at::einsum (" bshc,hdc->bshd" , {Q, wvc}).contiguous ();
1191+ if (batch == 1 && seqlen == 1 ) {
1192+ auto scores = antares::ops::call (" scores_bf16" , {Q.view ({n_heads, Q.size (-1 )}).view (torch::kInt32 ), C.squeeze (1 ).view (torch::kInt32 )}, {0 .1352337788608801f });
1193+ Q = torch::matmul (at::softmax (scores, -1 ), C.squeeze (1 ));
1194+ Q = antares::ops::call (" logits_bf16" , {Q.view (torch::kInt32 ), wvc.view (torch::kInt32 )}, {});
1195+ } else {
1196+ auto scores = at::einsum (" bshc,tbc->bsht" , {Q, C}) * 0 .1352337788608801f ;
1197+ Q = at::einsum (" bsht,tbc->bshc" , {at::softmax (scores, -1 ), C}).narrow (-1 , 0 , 512 );
1198+ Q = at::einsum (" bshc,hdc->bshd" , {Q, wvc}).contiguous ();
1199+ }
11941200 Q = warp_gemm_nt_bf16xfp8_block_scal (Q.view ({batch, seqlen, -1 }), o_proj, o_proj_scal);
11951201 return Q;
11961202 }
0 commit comments