Commit f882400

Apply some clippy

1 parent: 9a7a71e
1 file changed: +24 -38

mistralrs-core/src/attention.rs (24 additions, 38 deletions)

@@ -1,4 +1,4 @@
-#![allow(clippy::cast_precision_loss)]
+#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
 
 #[cfg(feature = "metal")]
 use std::sync::atomic::AtomicUsize;
@@ -53,10 +53,7 @@ where
     } else {
         return Err(candle_core::Error::Msg("Expected CPU storage for v".into()));
     };
-    let mask_guard = match mask {
-        Some(mask) => Some(mask.storage_and_layout().0),
-        None => None,
-    };
+    let mask_guard = mask.map(|mask| mask.storage_and_layout().0);
     let mask_data: Option<&[T]> = if let Some(mask_guard) = &mask_guard {
         let mask = mask.as_ref().unwrap();
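
A match that only wraps the `Some` value and forwards `None` is what clippy's `manual_map` lint targets; `Option::map` collapses it into one call, as the hunk above does. A minimal standalone sketch of the rewrite (the `Vec<u8>` mask and `guard_len` helper are stand-ins, not candle's storage guard):

    // Stand-in for the storage-guard mapping; illustrates the clippy::manual_map rewrite.
    fn guard_len(mask: Option<&Vec<u8>>) -> Option<usize> {
        // Before: match mask { Some(mask) => Some(mask.len()), None => None }
        // After: the closure runs only when the Option is Some, None passes through.
        mask.map(|mask| mask.len())
    }

    fn main() {
        let mask = vec![0u8, 1, 1];
        assert_eq!(guard_len(Some(&mask)), Some(3));
        assert_eq!(guard_len(None), None);
    }

The mapped value lives exactly as long as it did under the match, so borrow lifetimes at the call site are unchanged.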

@@ -89,9 +86,9 @@ where
         q.shape().dims(),
         k.shape().dims(),
         v.shape().dims(),
-        &q_stride,
-        &k_stride,
-        &v_stride,
+        q_stride,
+        k_stride,
+        v_stride,
         sdpa_params.softmax_scale,
         0.0,
         sdpa_params.softcap.unwrap_or(0.0),
@@ -106,9 +103,9 @@ where
         q.shape().dims(),
         k.shape().dims(),
         v.shape().dims(),
-        &q_stride,
-        &k_stride,
-        &v_stride,
+        q_stride,
+        k_stride,
+        v_stride,
         sdpa_params.softmax_scale,
         0.0,
         sdpa_params.softcap.unwrap_or(0.0),
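
Both call sites drop the leading `&` on the stride arguments. Assuming the strides are already borrowed slices (as `Layout::stride()` returns in candle), the extra reference is only auto-dereferenced again, which is what clippy's `needless_borrow` lint flags. A small sketch under that assumption (`sum_strides` and the stride values are made up for illustration):

    // Hypothetical callee; the real one is the inner flash-attention kernel.
    fn sum_strides(strides: &[usize]) -> usize {
        strides.iter().sum()
    }

    fn main() {
        let backing: Vec<usize> = vec![24, 8, 4, 1];
        let q_stride: &[usize] = &backing; // already a borrowed slice, as in the diff

        // clippy::needless_borrow: `sum_strides(&q_stride)` builds a &&[usize] that the
        // compiler immediately dereferences again; passing the slice directly is identical.
        assert_eq!(sum_strides(q_stride), 37);
    }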
@@ -135,14 +132,12 @@ fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
 ) -> Result<Tensor> {
     // Shapes: (B, 1, H, D)
     let (b, _q_len, h, d) = (
-        qshape[0] as usize,
-        qshape[1] as usize, // == 1
-        qshape[2] as usize,
-        qshape[3] as usize,
+        qshape[0], qshape[1], // == 1
+        qshape[2], qshape[3],
     );
-    let kv_len = kshape[1] as usize;
-    let k_h = kshape[2] as usize;
-    let v_h = vshape[2] as usize;
+    let kv_len = kshape[1];
+    let k_h = kshape[2];
+    let v_h = vshape[2];
     let rk2 = h / k_h;
     let rv2 = h / v_h;
     let dv = d;
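
Indexing a `&[usize]` already yields `usize`, so the `as usize` casts removed here were no-ops flagged by clippy's `unnecessary_cast` lint. A minimal sketch, assuming the shape slices are `&[usize]` as in the surrounding signatures (the `dims` helper is illustrative only):

    // Shows why `as usize` on an element that is already usize is redundant.
    fn dims(qshape: &[usize]) -> (usize, usize, usize, usize) {
        // Before: (qshape[0] as usize, qshape[1] as usize, ...)
        // After: indexing &[usize] already yields usize, so the casts disappear.
        (qshape[0], qshape[1], qshape[2], qshape[3])
    }

    fn main() {
        assert_eq!(dims(&[2, 1, 8, 64]), (2, 1, 8, 64));
        println!("shape unpacked without casts");
    }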
@@ -156,7 +151,7 @@ fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
 
     // Expose a second dimension of work: split the KV axis into tiles that
     // fit in the last‑level cache and let Rayon schedule them.
-    let kv_tiles = (kv_len + TILE_KV - 1) / TILE_KV;
+    let kv_tiles = kv_len.div_ceil(TILE_KV);
 
     // SAFETY: `par_chunks_mut` hands out non‑overlapping &mut slices, so no two
     // threads write the same output area.
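
The manual `(kv_len + TILE_KV - 1) / TILE_KV` round-up is what the standard library's `usize::div_ceil` expresses directly, and what clippy suggests replacing. A quick equivalence check, with `TILE_KV` assumed to be 32 purely for illustration:

    fn main() {
        const TILE_KV: usize = 32; // assumed tile size, for illustration only

        for kv_len in [1usize, 31, 32, 33, 1000] {
            let manual = (kv_len + TILE_KV - 1) / TILE_KV; // old round-up idiom
            let tidy = kv_len.div_ceil(TILE_KV); // std replacement suggested by clippy
            assert_eq!(manual, tidy);
        }
        println!("div_ceil matches the manual round-up on every sampled length");
    }

`div_ceil` also avoids the overflow the manual form can hit when the numerator is near `usize::MAX`, though that is not a concern at these tile counts.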
@@ -169,10 +164,10 @@ fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
 
         // Positional‑bias (same as before)
         let slope = if max_bias > 0.0 {
-            if (h_i as u32) < n_head_log2 {
+            if h_i < n_head_log2 as usize {
                 m0.powi((h_i + 1) as i32)
             } else {
-                m1.powi((2 * (h_i as i32 - n_head_log2 as i32) + 1) as i32)
+                m1.powi(2 * (h_i as i32 - n_head_log2 as i32) + 1)
             }
         } else {
             1.0
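
The comparison now widens `n_head_log2` to `usize` instead of narrowing `h_i` to `u32`; narrowing is what the newly allowed `cast_possible_truncation` lint warns about, while widening can never lose a value. (The second change just drops an `as i32` on an expression that is already `i32`.) A tiny sketch with assumed values:

    fn main() {
        // Assumed values; mirrors the slope-selection comparison in the diff.
        let h_i: usize = 40;
        let n_head_log2: u32 = 32;

        // Old form: (h_i as u32) < n_head_log2 — narrowing usize to u32 can truncate
        // on 64-bit targets (clippy::cast_possible_truncation).
        // New form: widening u32 to usize is always lossless, so compare in usize.
        if h_i < n_head_log2 as usize {
            println!("head {h_i} takes the m0 slope");
        } else {
            println!("head {h_i} takes the m1 slope");
        }
    }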
@@ -292,11 +287,12 @@ fn flash_attn_cpu_single_q<T: WithDType + Sum + num_traits::real::Real>(
     });
 
     let out_shape = (b, h, 1usize, dv);
-    Ok(Tensor::from_vec(out, out_shape, &Device::Cpu)?)
+    Tensor::from_vec(out, out_shape, &Device::Cpu)
 }
 
 /// Main forward flash-attention CPU routine.
 /// Shapes follow Candle convention: (B, S, H, D)
+#[allow(clippy::too_many_arguments)]
 pub fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
     q_data: &[T],
     k_data: &[T],
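
`Ok(expr?)` unwraps a `Result` only to re-wrap it, which clippy's `needless_question_mark` lint flags; returning the inner expression is equivalent, as both changed return sites now do. The new `#[allow(clippy::too_many_arguments)]` is the other pattern in this hunk: silencing a lint at the item level when a wide signature is deliberate. A hedged sketch of both, using placeholder types rather than candle's `Result` or `Tensor`:

    // Placeholder alias; the real functions return candle_core::Result<Tensor>.
    type Result<T> = std::result::Result<T, String>;

    fn build(values: Vec<f32>) -> Result<Vec<f32>> {
        if values.is_empty() {
            return Err("empty input".to_string());
        }
        Ok(values)
    }

    // Before: Ok(build(values)?) — the ? unwraps the Result and Ok re-wraps it.
    // After: hand the inner Result back unchanged.
    fn forward(values: Vec<f32>) -> Result<Vec<f32>> {
        build(values)
    }

    // Item-level allowance mirroring the one added above flash_attn_cpu:
    // eight parameters trip clippy's default too_many_arguments threshold of seven.
    #[allow(clippy::too_many_arguments)]
    fn describe(b: usize, s: usize, h: usize, d: usize, scale: f32, bias: f32, softcap: f32, kv: usize) -> String {
        format!("B={b} S={s} H={h} D={d} scale={scale} bias={bias} softcap={softcap} KV={kv}")
    }

    fn main() {
        assert!(forward(vec![1.0]).is_ok());
        assert!(forward(Vec::new()).is_err());
        println!("{}", describe(1, 1, 8, 64, 0.125, 0.0, 0.0, 1024));
    }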
@@ -312,23 +308,13 @@ pub fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
     max_bias: f32,
     logit_softcap: f32,
 ) -> Result<Tensor> {
-    // Shapes: (B, S, H, D)
-    let qshape = qshape;
-    let kshape = kshape;
-    let vshape = vshape;
-
-    let (b, q_len, h, d) = (
-        qshape[0] as usize,
-        qshape[1] as usize,
-        qshape[2] as usize,
-        qshape[3] as usize,
-    );
-    let kv_len = kshape[1] as usize;
+    let (b, q_len, h, d) = (qshape[0], qshape[1], qshape[2], qshape[3]);
+    let kv_len = kshape[1];
     // --- Head broadcasting factors ----------------------------------------------------
     // Allows K and V to have fewer heads than Q (grouped‑KV); the ratio is an
     // integer factor. rk2 = #Q‑heads / #K‑heads, rv2 = #Q‑heads / #V‑heads.
-    let k_h = kshape[2] as usize;
-    let v_h = vshape[2] as usize;
+    let k_h = kshape[2];
+    let v_h = vshape[2];
     let rk2 = h / k_h; // must divide exactly; panic otherwise
     let rv2 = h / v_h;
     let dv = d; // value dim = key dim in this kernel
@@ -363,7 +349,7 @@ pub fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
             if (h_i as u32) < n_head_log2 {
                 m0.powi((h_i + 1) as i32)
             } else {
-                m1.powi((2 * (h_i as i32 - n_head_log2 as i32) + 1) as i32)
+                m1.powi(2 * (h_i as i32 - n_head_log2 as i32) + 1)
             }
         } else {
             1.0
@@ -450,7 +436,7 @@ pub fn flash_attn_cpu<T: WithDType + Sum + num_traits::real::Real>(
 
     // Build output tensor with shape (B, H, S, D) to match standard (permute 0,2,1,3)
     let out_shape = (b, h, q_len, dv);
-    Ok(Tensor::from_vec(out, out_shape, &Device::Cpu)?)
+    Tensor::from_vec(out, out_shape, &Device::Cpu)
 }
 
 #[cfg(feature = "metal")]
