
Commit 6c4e0e8

Merge pull request vllm-project#2 from vllm-project/fix-2-tilelang
Fixes for support_materials/2-tilelang/
2 parents 9a317f7 + 4729f66 · commit 6c4e0e8

File tree: 2 files changed (+5, -6 lines)


support_materials/2-tilelang/1.index_attn_dynamic_qpack_varlen_fp8.py

Lines changed: 4 additions & 5 deletions
@@ -220,12 +220,10 @@ def index_attn_return_logits_interface(q,
                                        weights,
                                        cu_seqlen_ks,
                                        cu_seqlen_ke,
-                                       clean_logits=True):
+                                       should_clean_logits=True):
     seq_len, heads, index_dim = q.shape
     seq_len_kv = kv.shape[0]
 
-    clean_logits_kernel = clean_logits()
-
     index_attn_return_logits_kernel = index_attn_return_logits(
         heads=heads, index_dim=index_dim)
     logits = torch.empty([seq_len, seq_len_kv],
@@ -239,7 +237,8 @@ def index_attn_return_logits_interface(q,
         cu_seqlen_ks,
         cu_seqlen_ke,
     )
-    if clean_logits:
+    if should_clean_logits:
+        clean_logits_kernel = clean_logits()
         clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke)
     return logits
 
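
The rename matters because the original keyword argument `clean_logits` shadowed the `clean_logits()` kernel factory inside the function body, and the kernel was built even when it was never used. A minimal sketch of the pitfall, using a placeholder factory (the real `clean_logits()` is the TileLang kernel defined elsewhere in this file):

# Placeholder for the compiled kernel factory defined elsewhere in the file.
def clean_logits():
    return lambda logits, ks, ke: logits

def broken(logits, ks, ke, clean_logits=True):
    # The bool parameter shadows the factory above, so this raises
    # TypeError: 'bool' object is not callable.
    kernel = clean_logits()
    kernel(logits, ks, ke)

def fixed(logits, ks, ke, should_clean_logits=True):
    if should_clean_logits:
        kernel = clean_logits()  # factory stays reachable and is only built when needed
        kernel(logits, ks, ke)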

@@ -262,7 +261,7 @@ def ref_fp8_mqa_logits(
                < cu_seqlen_ke[:, None])
     mask = mask_lo & mask_hi
 
-    score = torch.einsum("mhd,and->hmn", q, k)
+    score = torch.einsum("mhd,nd->hmn", q, k)
     logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
     logits = logits.masked_fill(~mask, float("-inf"))
 
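
The corrected subscripts follow from the shapes used in the reference: `q` is 3-D `[seq_len, heads, dim]` and `k` is 2-D `[seq_len_kv, dim]`, so the spec needs exactly two labels for `k`; the stray `a` in `"mhd,and->hmn"` declares three subscripts for the 2-D `k`, which PyTorch rejects. A quick shape check under those assumed shapes (toy sizes):

import torch

m, h, d, n = 4, 2, 8, 6  # seq_len, heads, index_dim, seq_len_kv (toy sizes)
q = torch.randn(m, h, d)
k = torch.randn(n, d)
weights = torch.randn(m, h)

score = torch.einsum("mhd,nd->hmn", q, k)  # [heads, seq_len, seq_len_kv]
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)  # [seq_len, seq_len_kv]
assert logits.shape == (m, n)

# The old spec fails because "and" declares 3 subscripts for the 2-D k:
# torch.einsum("mhd,and->hmn", q, k)  # RuntimeError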

support_materials/2-tilelang/2.sparse_attn_mla.py

Lines changed: 1 addition & 1 deletion
@@ -314,7 +314,7 @@ def test_sparse_attn_mla_fwd():
     def fn():
         return sparse_attention_fwd_interface(q, kv, indices)
 
-    from tilelang.testing import do_bench
+    from tilelang.profiler import do_bench
 
     ms = do_bench(
         fn,
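
For reference, `do_bench` lives in `tilelang.profiler`, not `tilelang.testing`. A hedged usage sketch, assuming only what the test above shows, that `do_bench(fn)` runs the callable and returns a latency in milliseconds, and that the extra keyword arguments passed in the test have defaults and can be omitted; the matmul below is a stand-in workload for the real `sparse_attention_fwd_interface(q, kv, indices)` call:

import torch
from tilelang.profiler import do_bench  # corrected import path from this commit

# Requires a CUDA device, as the original test does.
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

def fn():
    return a @ b  # stand-in workload for sparse_attention_fwd_interface(q, kv, indices)

ms = do_bench(fn)
print(f"latency: {ms:.3f} ms")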
