1 file changed, +19 -23 lines changed

@@ -386,34 +386,30 @@ impl Sdpa {
         // Batch matrix multiplication
         // Fuse softmax scale and attention_bias add
-        let mut attention_scores = cublaslt
-            .batch_matmul(
-                &k,
-                &q,
-                attention_bias.as_ref(),
-                Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
-                beta,
-                None,
-                None,
-            )
-            .unwrap();
+        let mut attention_scores = cublaslt.batch_matmul(
+            &k,
+            &q,
+            attention_bias.as_ref(),
+            Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
+            beta,
+            None,
+            None,
+        )?;
         if let Some(softcap) = sdpa_params.softcap {
             attention_scores = (attention_scores.tanh()? * softcap as f64)?;
         }
         candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;

-        let context_layer = cublaslt
-            .batch_matmul(
-                &v.t()?.contiguous().unwrap(),
-                &attention_scores,
-                // We save one allocation
-                Some(&q),
-                None,
-                None,
-                None,
-                None,
-            )
-            .unwrap();
+        let context_layer = cublaslt.batch_matmul(
+            &v.t()?.contiguous()?,
+            &attention_scores,
+            // We save one allocation
+            Some(&q),
+            None,
+            None,
+            None,
+            None,
+        )?;

         // Reshape to dims4
         context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
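For context, the change swaps `.unwrap()` (which panics if the cuBLASLt matmul fails) for the `?` operator, which propagates the error to the caller, since the surrounding function already returns a `Result`. Below is a minimal standalone sketch of the same pattern; the `batch_matmul` helper and `MatmulError` type here are hypothetical stand-ins, not the cublaslt API from the diff.

```rust
use std::fmt;

// Hypothetical error type standing in for the library's error; not part of the diff.
#[derive(Debug)]
struct MatmulError(String);

impl fmt::Display for MatmulError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "matmul failed: {}", self.0)
    }
}

impl std::error::Error for MatmulError {}

// Hypothetical fallible op standing in for cublaslt.batch_matmul.
fn batch_matmul(ok: bool) -> Result<Vec<f32>, MatmulError> {
    if ok {
        Ok(vec![1.0, 2.0, 3.0])
    } else {
        Err(MatmulError("shape mismatch".into()))
    }
}

// Before: `.unwrap()` aborts the whole process on failure.
fn attention_unwrap() -> Vec<f32> {
    batch_matmul(true).unwrap()
}

// After: `?` returns the error to the caller instead of panicking,
// mirroring the `)?;` endings introduced in the diff.
fn attention_propagate() -> Result<Vec<f32>, MatmulError> {
    let scores = batch_matmul(true)?;
    Ok(scores)
}

fn main() {
    println!("{:?}", attention_unwrap());
    println!("{:?}", attention_propagate());
}
```

The same idea applies to `&v.t()?.contiguous()?` in the second matmul call: the intermediate `.unwrap()` is replaced so a failed `contiguous` surfaces as an error rather than a panic.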