Skip to content

Commit f5164df

Browse files
committed
fix black code stype
1 parent ca9a01e commit f5164df

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

aiter/mla.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,15 @@ def mla_decode_fwd(
196196
MAYBE_FINAL_OUT = False
197197

198198
logits = (
199-
o.view((total_s, num_kv_splits, nhead, v_head_dim)) if (num_kv_splits == 1
200-
and (q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4))) else
201-
torch.empty(
199+
o.view((total_s, num_kv_splits, nhead, v_head_dim))
200+
if (
201+
num_kv_splits == 1
202+
and (
203+
q.dtype == dtypes.fp8
204+
or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
205+
)
206+
)
207+
else torch.empty(
202208
(total_s, num_kv_splits, nhead, v_head_dim),
203209
dtype=dtypes.fp32,
204210
device=device,
@@ -230,7 +236,9 @@ def mla_decode_fwd(
230236
kv_scale,
231237
)
232238

233-
if (num_kv_splits == 1 and (q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4))):
239+
if num_kv_splits == 1 and (
240+
q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
241+
):
234242
return logits.view(total_s, nhead, v_head_dim), attn_lse
235243

236244
Lv = v_head_dim

op_tests/test_mla_persistent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def test_mla(
277277
max_seqlen_qo=int(max_seqlen_qo),
278278
uni_seqlen_qo=decode_qlen,
279279
fast_mode=True,
280-
max_split_per_batch = max_split_per_batch,
280+
max_split_per_batch=max_split_per_batch,
281281
)
282282

283283
def test_absorb_decode_bf16():

0 commit comments

Comments
 (0)