@@ -1,4 +1,4 @@
-#![allow(clippy::cast_precision_loss)]
+#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
 
 #[cfg(feature = "metal")]
 use std::sync::atomic::AtomicUsize;
@@ -53,10 +53,7 @@
     } else {
         return Err(candle_core::Error::Msg("Expected CPU storage for v".into()));
     };
-    let mask_guard = match mask {
-        Some(mask) => Some(mask.storage_and_layout().0),
-        None => None,
-    };
+    let mask_guard = mask.map(|mask| mask.storage_and_layout().0);
     let mask_data: Option<&[T]> = if let Some(mask_guard) = &mask_guard {
         let mask = mask.as_ref().unwrap();
 
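The `match` that wraps the `Some(...)` arm and passes `None` through is exactly what clippy's `manual_map` lint targets; `Option::map` is the drop-in replacement. A minimal standalone sketch (values illustrative, not from this file):

```rust
fn main() {
    let mask: Option<Vec<u8>> = Some(vec![1, 0, 1]);
    // The old shape: match, re-wrap the Some arm, pass None through.
    let via_match = match &mask {
        Some(m) => Some(m.len()),
        None => None,
    };
    // The new shape: Option::map does the re-wrapping for us.
    let via_map = mask.as_ref().map(|m| m.len());
    assert_eq!(via_match, via_map);
}
```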
@@ -89,9 +86,9 @@
         q.shape().dims(),
         k.shape().dims(),
         v.shape().dims(),
-        &q_stride,
-        &k_stride,
-        &v_stride,
+        q_stride,
+        k_stride,
+        v_stride,
         sdpa_params.softmax_scale,
         0.0,
         sdpa_params.softcap.unwrap_or(0.0),
@@ -106,9 +103,9 @@ where
         q.shape().dims(),
         k.shape().dims(),
         v.shape().dims(),
-        &q_stride,
-        &k_stride,
-        &v_stride,
+        q_stride,
+        k_stride,
+        v_stride,
         sdpa_params.softmax_scale,
         0.0,
         sdpa_params.softcap.unwrap_or(0.0),
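Dropping the `&` on the stride arguments in the two hunks above looks like clippy's `needless_borrow`: assuming the strides are already `&[usize]` slices (their declarations are not in this diff), the extra borrow builds a `&&[usize]` that only compiles via auto-deref. A sketch under that assumption:

```rust
fn takes_slice(s: &[usize]) -> usize {
    s.iter().product()
}

fn main() {
    let backing = vec![64usize, 16, 4, 1]; // illustrative stride values
    let stride: &[usize] = &backing; // already a slice reference
    // Both calls compile, but the first goes through &&[usize] auto-deref
    // and is what clippy::needless_borrow flags.
    assert_eq!(takes_slice(&stride), takes_slice(stride));
}
```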
@@ -135,14 +132,12 @@ fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
 ) -> Result<Tensor> {
     // Shapes: (B, 1, H, D)
     let (b, _q_len, h, d) = (
-        qshape[0] as usize,
-        qshape[1] as usize, // == 1
-        qshape[2] as usize,
-        qshape[3] as usize,
+        qshape[0], qshape[1], // == 1
+        qshape[2], qshape[3],
     );
-    let kv_len = kshape[1] as usize;
-    let k_h = kshape[2] as usize;
-    let v_h = vshape[2] as usize;
+    let kv_len = kshape[1];
+    let k_h = kshape[2];
+    let v_h = vshape[2];
     let rk2 = h / k_h;
     let rv2 = h / v_h;
     let dv = d;
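The `as usize` casts removed here were no-ops: indexing a `[usize]` shape already yields `usize`, which clippy reports as `unnecessary_cast`. For example, with an illustrative shape:

```rust
fn main() {
    let qshape: &[usize] = &[2, 1, 8, 64]; // (B, 1, H, D), values illustrative
    let b = qshape[0]; // already usize; `qshape[0] as usize` casts usize to usize
    assert_eq!(b, 2);
}
```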
@@ -156,7 +151,7 @@ fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
 
     // Expose a second dimension of work: split the KV axis into tiles that
     // fit in the last‑level cache and let Rayon schedule them.
-    let kv_tiles = (kv_len + TILE_KV - 1) / TILE_KV;
+    let kv_tiles = kv_len.div_ceil(TILE_KV);
 
     // SAFETY: `par_chunks_mut` hands out non‑overlapping &mut slices, so no two
     // threads write the same output area.
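`usize::div_ceil` (stable since Rust 1.73) is a one-for-one replacement for the manual `(n + d - 1) / d` ceiling idiom, and also sidesteps the intermediate `n + d - 1` overflow hazard near `usize::MAX`. A quick check with an illustrative tile size:

```rust
fn main() {
    const TILE_KV: usize = 128; // illustrative; the real constant is defined elsewhere
    let kv_len = 1000usize;
    // Manual ceiling division and div_ceil agree: both give 8 tiles.
    assert_eq!((kv_len + TILE_KV - 1) / TILE_KV, kv_len.div_ceil(TILE_KV));
}
```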
@@ -169,10 +164,10 @@ fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
 
         // Positional‑bias (same as before)
         let slope = if max_bias > 0.0 {
-            if (h_i as u32) < n_head_log2 {
+            if h_i < n_head_log2 as usize {
                 m0.powi((h_i + 1) as i32)
             } else {
-                m1.powi((2 * (h_i as i32 - n_head_log2 as i32) + 1) as i32)
+                m1.powi(2 * (h_i as i32 - n_head_log2 as i32) + 1)
             }
         } else {
             1.0
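For context on the slope expression this hunk simplifies (the outer `as i32` was redundant because the operand is already `i32`, and the comparison now widens `n_head_log2` to `usize` instead of truncating `h_i` to `u32`): the formula follows the usual ALiBi scheme, where heads below the largest power of two take powers of `m0` and the rest take odd powers of `m1`. A hedged sketch assuming llama.cpp-style definitions of `m0`, `m1`, and `n_head_log2`, none of which appear in this hunk:

```rust
fn main() {
    let (n_head, max_bias) = (12u32, 8.0f32); // illustrative values
    // Largest power of two <= n_head, per the usual ALiBi recipe.
    let n_head_log2 = 1u32 << (31 - n_head.leading_zeros());
    let m0 = 2.0f32.powf(-max_bias / n_head_log2 as f32);
    let m1 = 2.0f32.powf(-max_bias / (2.0 * n_head_log2 as f32));
    for h_i in 0..n_head as usize {
        let slope = if h_i < n_head_log2 as usize {
            m0.powi((h_i + 1) as i32)
        } else {
            m1.powi(2 * (h_i as i32 - n_head_log2 as i32) + 1)
        };
        println!("head {h_i}: slope = {slope}");
    }
}
```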
@@ -292,11 +287,12 @@ fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
     });
 
     let out_shape = (b, h, 1usize, dv);
-    Ok(Tensor::from_vec(out, out_shape, &Device::Cpu)?)
+    Tensor::from_vec(out, out_shape, &Device::Cpu)
 }
 
 /// Main forward flash-attention CPU routine.
 /// Shapes follow Candle convention: (B, S, H, D)
+#[allow(clippy::too_many_arguments)]
 pub fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
     q_data: &[T],
     k_data: &[T],
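Returning `Tensor::from_vec(...)` directly is equivalent to the old `Ok(...?)`: the `?` unwrapped the `Result` only for `Ok` to re-wrap it, which clippy flags as `needless_question_mark`. The same shape in miniature:

```rust
fn parse_port(s: &str) -> Result<u16, std::num::ParseIntError> {
    // Equivalent to `Ok(s.trim().parse()?)`, minus the unwrap/re-wrap round trip.
    s.trim().parse()
}

fn main() {
    assert_eq!(parse_port(" 8080 "), Ok(8080));
}
```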
@@ -312,23 +308,13 @@ pub fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
     max_bias: f32,
     logit_softcap: f32,
 ) -> Result<Tensor> {
-    // Shapes: (B, S, H, D)
-    let qshape = qshape;
-    let kshape = kshape;
-    let vshape = vshape;
-
-    let (b, q_len, h, d) = (
-        qshape[0] as usize,
-        qshape[1] as usize,
-        qshape[2] as usize,
-        qshape[3] as usize,
-    );
-    let kv_len = kshape[1] as usize;
+    let (b, q_len, h, d) = (qshape[0], qshape[1], qshape[2], qshape[3]);
+    let kv_len = kshape[1];
     // --- Head broadcasting factors ----------------------------------------------------
     // Allows K and V to have fewer heads than Q (grouped‑KV); the ratio is an
     // integer factor. rk2 = #Q‑heads / #K‑heads, rv2 = #Q‑heads / #V‑heads.
-    let k_h = kshape[2] as usize;
-    let v_h = vshape[2] as usize;
+    let k_h = kshape[2];
+    let v_h = vshape[2];
     let rk2 = h / k_h; // must divide exactly; panic otherwise
     let rv2 = h / v_h;
     let dv = d; // value dim = key dim in this kernel
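The grouped‑KV factors in the context above are easiest to see with numbers: with, say, 8 query heads and 2 K heads (illustrative, not from this file), `rk2 = 4` consecutive query heads all read the same K head. A minimal sketch:

```rust
fn main() {
    let (h, k_h) = (8usize, 2usize); // Q heads, K heads; illustrative values
    assert_eq!(h % k_h, 0); // "must divide exactly; panic otherwise"
    let rk2 = h / k_h;
    for q_head in 0..h {
        println!("Q head {q_head} -> K head {}", q_head / rk2);
    }
}
```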
@@ -363,7 +349,7 @@ pub fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
             if (h_i as u32) < n_head_log2 {
                 m0.powi((h_i + 1) as i32)
             } else {
-                m1.powi((2 * (h_i as i32 - n_head_log2 as i32) + 1) as i32)
+                m1.powi(2 * (h_i as i32 - n_head_log2 as i32) + 1)
             }
         } else {
             1.0
@@ -450,7 +436,7 @@ pub fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
 
     // Build output tensor with shape (B, H, S, D) to match standard (permute 0,2,1,3)
     let out_shape = (b, h, q_len, dv);
-    Ok(Tensor::from_vec(out, out_shape, &Device::Cpu)?)
+    Tensor::from_vec(out, out_shape, &Device::Cpu)
 }
 
 #[cfg(feature = "metal")]