Commit 6150ab4

Merge pull request vllm-project#3 from vllm-project/yewentao256-patch-1
[Bug] Fix Einsum in DeepGEMM tests
2 parents 6c4e0e8 + 8b0cb11 commit 6150ab4

File tree

1 file changed (+1, -1)

support_materials/3-cuda-kernels/indexer/tests/test_attention.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def ref_fp8_mqa_logits(q: torch.Tensor,
                          device='cuda')[None, :] < cu_seqlen_ke[:, None]
     mask = mask_lo & mask_hi

-    score = torch.einsum('mhd,and->hmn', q, k)
+    score = torch.einsum('mhd,nd->hmn', q, k)
     logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
     logits = logits.masked_fill(~mask, float('-inf'))
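The fix changes the einsum subscript from `'mhd,and->hmn'` to `'mhd,nd->hmn'`: the stray `a` axis described `k` as a 3-D tensor, while in this reference function `k` is 2-D. A minimal sketch of the corrected contraction, using NumPy in place of PyTorch for illustration (the shapes `q: (m, h, d)` and `k: (n, d)` are assumptions inferred from the subscripts, not taken from the test file):

```python
import numpy as np

# Assumed shapes for illustration: q is (m, h, d), k is (n, d)
m, h, d, n = 4, 2, 8, 5
rng = np.random.default_rng(0)
q = rng.random((m, h, d))
k = rng.random((n, d))

# The corrected einsum: contract over d, yielding per-head scores (h, m, n)
score = np.einsum('mhd,nd->hmn', q, k)
assert score.shape == (h, m, n)

# Equivalent explicit form: for each head i, q[:, i, :] @ k.T
ref = np.stack([q[:, i, :] @ k.T for i in range(h)])
assert np.allclose(score, ref)
```

With the old spec `'mhd,and->hmn'`, einsum would expect a 3-D `k` of shape `(a, n, d)` and raise a dimension-mismatch error on a 2-D input, which is what the patch corrects.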
