@@ -1126,9 +1126,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> warp_multi_head_latent_r
11261126 CHECK_EQ (w_q_b_proj.dim (), 4 );
11271127 CHECK_EQ (k_b_proj.dtype (), torch::kBFloat16 );
11281128 CHECK_EQ (k_b_proj.dim (), 3 );
1129+ CHECK_CONTIGUOUS (k_b_proj.transpose (1 , 2 ));
11291130
1130- auto qh = torch::matmul (w_q_b_proj, q.view ({batch * seqlen, q.size (2 )}).t ()).view ({2 , n_local_heads, 128 , batch * seqlen});
1131- auto q_output = torch::matmul (qh[0 ].transpose (1 , 2 ), k_b_proj);
1131+ auto qh = ((batch * seqlen <= 4 ) ? \
1132+ antares::ops::call (" gmv_bf16" , {w_q_b_proj.view ({-1 , w_q_b_proj.size (-1 )}).view (torch::kInt32 ), q.view ({batch * seqlen, q.size (2 )}).view (torch::kInt32 )}, {}) : \
1133+ torch::matmul (w_q_b_proj, q.view ({batch * seqlen, q.size (2 )}).t ())).view ({2 , n_local_heads, 128 , batch * seqlen});
1134+ auto q_output = (batch * seqlen == 1 ) ? \
1135+ antares::ops::call (" bmv_bf16" , {qh[0 ].squeeze (-1 ).view (torch::kInt32 ), k_b_proj.transpose (1 , 2 ).view (torch::kInt32 )}, {}).unsqueeze (1 ) : (batch * seqlen <= 4 ? \
1136+ antares::ops::call (" bmm_bf16" , {qh[0 ], k_b_proj.transpose (1 , 2 )}, {}) : torch::matmul (qh[0 ].transpose (1 , 2 ), k_b_proj));
11321137
11331138 q_output = antares::ops::call (" rope_q_bf16" , {q_output.view ({n_local_heads, batch * seqlen, 16 , 32 }), cos_sin, qh[1 ].view ({n_local_heads, 2 , 32 , 2 , batch * seqlen}), positions}, {}).view ({batch, seqlen, n_local_heads, 576 });
11341139 return {q_output, k_output, kv};
@@ -1171,19 +1176,27 @@ torch::Tensor warp_deepseek_r1_attn_f16xf8_block_scal(
11711176 if (it2 == wkv_b_.end ()) {
11721177 auto _ = antares::ops::call (" to_bfloat16_3d" , {kv_b_proj.unsqueeze (0 ), kv_b_proj_scal.unsqueeze (0 )}, {}).
11731178 view ({n_local_heads, 2 , -1 , kv_b_proj.size (-1 )}).permute ({1 , 0 , 2 , 3 }).contiguous (); // 2, H, 128, 512
1174- wkv_b_[kv_b_proj_scal.data_ptr ()] = {_.select (0 , 0 ), _. select ( 0 , 1 ). transpose (1 , 2 )}; // H, D(128), C(512)
1179+ wkv_b_[kv_b_proj_scal.data_ptr ()] = {_.select (0 , 0 ). transpose ( 1 , 2 ). contiguous (). transpose (1 , 2 ), _. select ( 0 , 1 )}; // H, D(128), C(512)
11751180 it2 = wkv_b_.find (kv_b_proj_scal.data_ptr ());
11761181 }
11771182 auto wkc = std::get<0 >(it2->second ), wvc = std::get<1 >(it2->second );
11781183 auto qkv = warp_gemm_nt_bf16xfp8_block_scal (data, qkv_a_proj, qkv_a_proj_scal); // [B, S, 1536 + 512 + 64]
11791184
1180- auto positions = torch::full ({batch}, pos, torch::TensorOptions ().dtype (torch::kInt64 ).device (data.device ()));
1185+ static torch::Tensor posperm = torch::arange (0 , cos_sin.size (0 ), torch::TensorOptions ().dtype (torch::kInt64 ).device (data.device ()));
1186+ auto positions = batch == 1 ? posperm.narrow (0 , pos, 1 ) : torch::full ({batch}, pos, torch::TensorOptions ().dtype (torch::kInt64 ).device (data.device ()));
11811187 auto inputs = warp_multi_head_latent_rope_bf16 (qkv, cos_sin, positions, q_a_norm, kv_a_norm, it->second , wkc, n_local_heads);
11821188 key_cache.narrow (0 , pos, 1 ).copy_ (std::get<1 >(inputs).permute ({1 , 0 , 2 }));
1189+
11831190 auto Q = std::get<0 >(inputs), C = key_cache.narrow (0 , 0 , pos + seqlen); // S2, B, (512 + 64)
1184- auto scores = at::einsum (" bshc,tbc->bsht" , {Q, C}) * 0 .1352337788608801f ;
1185- Q = at::einsum (" bsht,tbc->bshc" , {at::softmax (scores, -1 ), C}).narrow (-1 , 0 , 512 );
1186- Q = at::einsum (" bshc,hcd->bshd" , {Q, wvc}).contiguous ();
1191+ if (batch == 1 && seqlen == 1 ) {
1192+ auto scores = antares::ops::call (" scores_bf16" , {Q.view ({n_heads, Q.size (-1 )}).view (torch::kInt32 ), C.squeeze (1 ).view (torch::kInt32 )}, {0 .1352337788608801f });
1193+ Q = torch::matmul (at::softmax (scores, -1 ), C.squeeze (1 ));
1194+ Q = antares::ops::call (" logits_bf16" , {Q.view (torch::kInt32 ), wvc.view (torch::kInt32 )}, {});
1195+ } else {
1196+ auto scores = at::einsum (" bshc,tbc->bsht" , {Q, C}) * 0 .1352337788608801f ;
1197+ Q = at::einsum (" bsht,tbc->bshc" , {at::softmax (scores, -1 ), C}).narrow (-1 , 0 , 512 );
1198+ Q = at::einsum (" bshc,hdc->bshd" , {Q, wvc}).contiguous ();
1199+ }
11871200 Q = warp_gemm_nt_bf16xfp8_block_scal (Q.view ({batch, seqlen, -1 }), o_proj, o_proj_scal);
11881201 return Q;
11891202 }
@@ -1200,48 +1213,22 @@ torch::Tensor warp_deepseek_r1_attn_f16xf8_block_scal(
12001213 auto q_pe_out = torch::empty_like (q_pe);
12011214 antares::ops::call (" rotary_lookup_bf16" , {cos_sin.select (0 , pos).select (0 , 0 ), cos_sin.select (0 , pos).select (0 , 1 ), q_pe.view ({-1 , 32 , 2 }), q_pe_out.view ({-1 , 2 , 32 })}, {}, false , 0 , 3 );
12021215
1203- if (val_cache.numel () > 1 ) {
1204- kv = warp_gemm_nt_bf16xfp8_block_scal (warp_rmsnorm_bf16 (kv, kv_a_norm, 1e-6f ), kv_b_proj, kv_b_proj_scal);
1205-
1206- antares::ops::call (" cache_fill_bf16" , {q_pe_out, k_pe_out, query_states, key_cache.select (0 , pos)}, {128 }, false , 0 , 3 );
1216+ kv = warp_gemm_nt_bf16xfp8_block_scal (warp_rmsnorm_bf16 (kv, kv_a_norm, 1e-6f ), kv_b_proj, kv_b_proj_scal);
1217+ antares::ops::call (" cache_fill_bf16" , {q_pe_out, k_pe_out, query_states, key_cache.select (0 , pos)}, {128 }, false , 0 , 3 );
12071218 // [B,S,H,64] [B,S,64] [B,S,H,128:] [B,H,128:]
12081219
1209- antares::ops::call (" cache_move_bf16" , {kv.view ({batch, seqlen, n_heads, 2 , 128 }), key_cache.narrow (0 , pos, seqlen), val_cache.narrow (0 , pos, seqlen)}, {}, false , 0 , 2 );
1220+ antares::ops::call (" cache_move_bf16" , {kv.view ({batch, seqlen, n_heads, 2 , 128 }), key_cache.narrow (0 , pos, seqlen), val_cache.narrow (0 , pos, seqlen)}, {}, false , 0 , 2 );
12101221 // [B,S,H,2,M] [S,B,H,:128] [S,B,H,:128]
12111222
1212- auto key_states = key_cache.narrow (0 , 0 , pos + seqlen).view ({1 , pos + seqlen, batch * n_heads, 192 });
1213- auto value_states = val_cache.narrow (0 , 0 , pos + seqlen).view ({1 , pos + seqlen, batch * n_heads, 128 });
1214- query_states = query_states.permute ({1 , 0 , 2 , 3 }).view ({1 , seqlen, -1 , 192 });
1215- CHECK_EQ (query_states.size (1 ), 1 );
1223+ auto key_states = key_cache.narrow (0 , 0 , pos + seqlen).view ({1 , pos + seqlen, batch * n_heads, 192 });
1224+ auto value_states = val_cache.narrow (0 , 0 , pos + seqlen).view ({1 , pos + seqlen, batch * n_heads, 128 });
1225+ query_states = query_states.permute ({1 , 0 , 2 , 3 }).view ({1 , seqlen, -1 , 192 });
1226+ CHECK_EQ (query_states.size (1 ), 1 );
12161227
1217- auto lm = torch::empty ({2 , batch * n_heads, 64 }, torch::TensorOptions ().dtype (torch::kBFloat16 ).device (query_states.device ()));
1218- auto attn_output = antares::ops::call (" self_attn_infer_bf16" , {query_states.squeeze (0 ).squeeze (0 ), key_states.squeeze (0 ), value_states.squeeze (0 ), lm}, {0 .1352337788608801f });
1219- xb = torch::matmul (antares::ops::call (" self_attn_reduce_bf16" , {lm}, {}).unsqueeze (1 ), attn_output).to (query_states.dtype ());
1220-
1221- // xb = std::get<0>(at::native::_scaled_dot_product_attention_math(query_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), key_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), value_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), {}, 0, false, {}, 0.1352337788608801)).permute({0, 2, 1, 3}).to(query_states.dtype());
1222- } else {
1223- kv = torch::cat ({warp_rmsnorm_bf16 (kv, kv_a_norm, 1e-6f ), k_pe_out}, -1 ); // [B, S, 512]
1224- key_cache.narrow (0 , pos, seqlen).narrow (1 , 0 , batch) = kv.permute ({1 , 0 , 2 }); // [S, B, 512 + 64]
1225-
1226- static std::unordered_map<void *, torch::Tensor> wkv_b_;
1227- auto it = wkv_b_.find (kv_b_proj_scal.data_ptr ());
1228- if (it == wkv_b_.end ()) {
1229- wkv_b_[kv_b_proj_scal.data_ptr ()] = antares::ops::call (" to_bfloat16_3d" , {kv_b_proj.unsqueeze (0 ), kv_b_proj_scal.unsqueeze (0 )}, {}).
1230- view ({n_heads, 2 , -1 , kv_b_proj.size (-1 )}).permute ({1 , 0 , 2 , 3 }).contiguous (); // 2, H, 128, 512
1231- it = wkv_b_.find (kv_b_proj_scal.data_ptr ());
1232- }
1233- auto _0 = it->second .select (0 , 0 ), _1 = it->second .select (0 , 1 ); // H, D(128), C(512)
1234- // k_pe_out, q_pe_out -- 1, 1, 64 | 1, 1, 16, 64
1235- auto q_nope = query_states.narrow (-1 , 0 , 128 ).contiguous (); // B, S, H, D
1236- q_nope = at::einsum (" bshd,hdc->bshc" , {q_nope, _0}).contiguous (); // B, S, H, C(512)
1237- q_nope = torch::cat ({q_nope, q_pe_out}, -1 );
1238-
1239- auto R = key_cache.narrow (0 , 0 , pos + seqlen); // S2, B, (512 + 64)
1240- auto scores_ = at::einsum (" bshc,tbc->bsht" , {q_nope, R}) * 0 .1352337788608801f ;
1241- _0 = at::einsum (" bsht,tbc->bshc" , {at::softmax (scores_, -1 ), R});
1242- _0 = _0.narrow (-1 , 0 , 512 ).contiguous ();
1243- xb = at::einsum (" bshc,hdc->bshd" , {_0, _1}).contiguous ();
1244- }
1228+ auto lm = torch::empty ({2 , batch * n_heads, 64 }, torch::TensorOptions ().dtype (torch::kBFloat16 ).device (query_states.device ()));
1229+ auto attn_output = antares::ops::call (" self_attn_infer_bf16" , {query_states.squeeze (0 ).squeeze (0 ), key_states.squeeze (0 ), value_states.squeeze (0 ), lm}, {0 .1352337788608801f });
1230+ xb = torch::matmul (antares::ops::call (" self_attn_reduce_bf16" , {lm}, {}).unsqueeze (1 ), attn_output).to (query_states.dtype ());
1231+ // xb = std::get<0>(at::native::_scaled_dot_product_attention_math(query_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), key_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), value_states.permute({0, 2, 1, 3}).to(torch::kBFloat16), {}, 0, false, {}, 0.1352337788608801)).permute({0, 2, 1, 3}).to(query_states.dtype());
12451232 xb = warp_gemm_nt_bf16xfp8_block_scal (xb.view ({batch, seqlen, -1 }), o_proj, o_proj_scal);
12461233 return xb;
12471234}
@@ -1405,7 +1392,7 @@ void warp_deepseek_r1_prepare_weights(
14051392 ::cos_sin = cos_sin;
14061393
14071394 int n_layers = o_projs.size ();
1408- bool use_lora = getenv (" LORA" ) ? (std::atoi (getenv (" LORA" )) == 1 ) : (batch >= 4 ) ;
1395+ bool use_lora = getenv (" LORA" ) ? (std::atoi (getenv (" LORA" )) == 1 ) : true ;
14091396 if (use_lora) {
14101397 // kv_lora_rank + qk_rope_head_dim
14111398 ::key_cache = torch::zeros ({n_layers, max_seq_len, batch, 512 + 64 }, torch::TensorOptions ().dtype (token_emb.dtype ()).device (token_emb.device ()));
0 commit comments