14 changes: 11 additions & 3 deletions tests/attention/test_xqa.py
@@ -15,6 +15,7 @@ def set_random_seed(seed=42):
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.random.manual_seed_all(seed)
Contributor (medium):

The function set_random_seed already calls torch.cuda.manual_seed_all(seed) on line 14. The added line torch.cuda.random.manual_seed_all(seed) is a duplicate, as torch.cuda.random.manual_seed_all is an alias for torch.cuda.manual_seed_all. This redundant call should be removed.

Collaborator:

In recent PyTorch versions, torch.manual_seed(seed) should cover the semantics of torch.cuda.random.manual_seed_all, so there is no need to set the GPU seed explicitly: https://docs.pytorch.org/docs/stable/notes/randomness.html#pytorch-random-number-generator
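A minimal sketch of the consolidated helper under that assumption (recent PyTorch, reproducing only the calls visible in this hunk):

import numpy as np
import torch


def set_random_seed(seed=42):
    # On recent PyTorch, torch.manual_seed also seeds the RNG of every CUDA device,
    # so no explicit torch.cuda.manual_seed_all / torch.cuda.random.manual_seed_all call is needed.
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False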



def round_up(a, b):
@@ -268,7 +269,9 @@ def test_xqa(

# Shuffle page indices
flattened = page_list_arg.flatten()
indices = torch.randperm(flattened.numel(), device="cuda")
generator = torch.Generator(device="cuda")
generator.manual_seed(42)
Contributor (medium):

The seed 42 is hardcoded here. It's also used in the set_random_seed(42) call at the beginning of this test on line 177. To improve maintainability and avoid magic numbers, consider defining a constant for the seed (e.g., SEED = 42) at the module level and using it in both places.
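A sketch of that suggestion; SEED and the helper below are hypothetical names used only for illustration:

import torch

SEED = 42  # single source of truth for RNG seeding in this test module


def set_random_seed(seed=SEED):
    torch.manual_seed(seed)
    # ... remaining seeding/determinism settings as in the existing helper ...


def make_shuffle_generator(device="cuda"):
    # Hypothetical helper so the randperm generator reuses the shared seed.
    generator = torch.Generator(device=device)
    generator.manual_seed(SEED)
    return generator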

indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
Comment on lines +276 to +278
Contributor:

⚠️ Potential issue | 🟑 Minor

🧩 Analysis chain

Verify determinism in the MLA test.

The xqa test now uses a seeded generator for torch.randperm (good for determinism), but the xqa_mla test at line 522 still uses the default CUDA RNG without an explicit generator. This inconsistency could lead to non-deterministic behavior in the MLA test.


🏁 Script executed:

#!/bin/bash
# Verify that xqa_mla test uses deterministic RNG for randperm
rg -n -A2 -B2 'def test_xqa_mla' tests/attention/test_xqa.py
rg -n 'torch\.randperm' tests/attention/test_xqa.py

Length of output: 322


Add seeded generator to test_xqa_mla for deterministic RNG.

The xqa test uses a seeded generator for torch.randperm (line 273), but the xqa_mla test at line 522 still calls torch.randperm without an explicit generator, falling back to the default CUDA RNG. Apply the same generator pattern to test_xqa_mla to ensure deterministic behavior across both tests.

πŸ€– Prompt for AI Agents
In tests/attention/test_xqa.py around lines 271-273 and at the xqa_mla call near
line 522, the CUDA RNG is not consistently seeded; create a seeded CUDA
generator (generator = torch.Generator(device="cuda");
generator.manual_seed(42)) and pass it into torch.randperm as
generator=generator (keeping device="cuda") in the xqa_mla test so both tests
use the same deterministic RNG source.
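A sketch of the same pattern applied inside test_xqa_mla, assuming it shuffles page_list_arg the same way as test_xqa (the variable names below mirror the xqa test and may differ in the MLA test):

# Shuffle page indices with an explicitly seeded CUDA generator
generator = torch.Generator(device="cuda")
generator.manual_seed(42)  # same seed as set_random_seed(42)
flattened = page_list_arg.flatten()
indices = torch.randperm(flattened.numel(), generator=generator, device="cuda")
shuffled_flat = flattened[indices]
page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)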

shuffled_flat = flattened[indices]
page_list_arg = shuffled_flat.view(batch_size, nb_pages_per_seq)

@@ -335,6 +338,9 @@ def test_xqa(

rcp_out_scale = 4.0 if use_fp8_output else 1.0

torch.cuda.synchronize()
semaphores.zero_()

Contributor:

⚠️ Potential issue | 🟠 Major

Critical: MLA test missing synchronization.

The xqa test now includes torch.cuda.synchronize() and semaphores.zero_() before the kernel call; these are critical additions for ensuring proper ordering and a clean state. However, the xqa_mla test (starting line 565) does not include these synchronization calls. Given that this PR aims to fix flaky xqa tests, the missing synchronization in the MLA test is a significant oversight that could cause flakiness.

Apply similar synchronization to the MLA test:

# Add before line 565 (before xqa_mla call)
torch.cuda.synchronize()
semaphores.zero_()
πŸ€– Prompt for AI Agents
In tests/attention/test_xqa.py around lines 340-342 and specifically for the
xqa_mla test starting at line 565, the MLA variant is missing the GPU
synchronization and semaphore reset that were added for the xqa test; before
calling xqa_mla at ~line 565 add a torch.cuda.synchronize() call followed by
semaphores.zero_() (using the same semaphores variable used elsewhere) to ensure
proper ordering and a clean semaphore state before launching the kernel.
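In context, the suggested ordering around the MLA launch might look like the excerpt below (the xqa_mla argument list is elided; semaphores is the same workspace tensor the test already allocates):

torch.cuda.synchronize()  # ensure prior setup work on the GPU has finished
semaphores.zero_()        # reset multi-block semaphores to a clean state

xqa_mla(
    ...,  # arguments unchanged from the existing test
)

torch.cuda.synchronize()  # wait for the kernel before validating results on the host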

xqa(
q_heads,
cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads,
@@ -347,15 +353,17 @@ def test_xqa(
nb_k_heads,
tokens_per_page,
sinks=attention_sinks,
q_scale=q_scale,
kv_scale=kv_cache_scale,
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
Comment on lines +357 to +358
Contributor:

⚠️ Potential issue | πŸ”΄ Critical

Critical: MLA test not updated with tensor scales.

The xqa test now passes q_scale and kv_scale as CUDA tensors, aligning with the kernel changes (array indexing in mla_sm120.cu). However, the xqa_mla test at lines 575-576 still passes these as Python scalars. This inconsistency could cause runtime errors or incorrect behavior in the MLA path.

Update the xqa_mla test to use tensor scales:

# Update lines 575-576 in xqa_mla call
q_scale=torch.tensor(q_scale, device="cuda"),
kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
πŸ€– Prompt for AI Agents
In tests/attention/test_xqa.py around lines 575 to 576, the xqa_mla test still
passes q_scale and kv_scale as Python scalars while the rest of the tests (and
kernel changes) expect CUDA tensors; update the xqa_mla call to wrap both scales
with torch.tensor(..., device="cuda") so q_scale and kv_scale are passed as CUDA
tensors (matching the change at lines 355-356 and preventing MLA path
runtime/type errors).
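Applied to the MLA call site, the scale arguments would look like this excerpt (all other arguments left as they are in the existing test):

xqa_mla(
    ...,  # positional and other keyword arguments unchanged
    q_scale=torch.tensor(q_scale, device="cuda"),
    kv_scale=torch.tensor(kv_cache_scale, device="cuda"),
)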

sliding_win_size=sliding_win_size,
kv_layout=kv_layout,
sm_count=sm_count,
enable_pdl=enable_pdl,
rcp_out_scale=rcp_out_scale,
)

torch.cuda.synchronize()

for req in range(batch_size):
for b in range(beam_width):
for idx_k_head in range(nb_k_heads):