Problem Description
There appears to be an incorrect index mapping in the seqlens_expanded calculation for multi-request batches in the NSA backend: because of this mismatch, the TopK kernel accesses the wrong logits positions.
Code Location
python/sglang/srt/layers/attention/nsa_backend.py, lines 244-258:

```python
seqlens_expanded = torch.cat(
    [
        torch.arange(
            kv_len - qo_len + 1,  # Start position
            kv_len + 1,  # End position
            dtype=torch.int32,
            device=device,
        )
        for qo_len, kv_len in zip(
            forward_batch.extend_seq_lens_cpu,
            forward_batch.seq_lens_cpu.tolist(),
            strict=True,
        )
    ]
)
```

Concrete Example
Consider a batch with two requests:
- Request 1: 40 tokens (KV cache positions 0-39)
- Request 2: 37 tokens (KV cache positions 40-76)
Current Behavior (Incorrect)
seqlens_expanded calculation:
- Request 1: torch.arange(40 - 40 + 1, 40 + 1) = [1, 2, 3, ..., 40]
- Request 2: torch.arange(37 - 37 + 1, 37 + 1) = [1, 2, 3, ..., 37]
- Final: [1, 2, 3, ..., 40, 1, 2, 3, ..., 37]
Problem:
- Request 2's first token is at batch position 40, and seqlens_expanded[40] = 1
- The TopK kernel therefore accesses logits[40, 0:1]
- This corresponds to KV position [0], but it should be KV position [40]
Expected Behavior (Correct)
Request 2's first token should:
- Only attend to its own KV position [40]
- Have seqlens_expanded[40] indicate access to the correct logits range
- Have TopK select from the logits corresponding to KV position [40], not [0]
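To illustrate the expected mapping, here is a hypothetical sketch that adds each request's global KV-cache start offset (kv_start_offsets is an assumed name, not an SGLang variable), so that Request 2's first token maps to the half-open global range [40, 41) rather than [0, 1). This only illustrates the mapping I would expect, not an actual fix in the codebase:

```python
# Example batch: Request 1 occupies KV positions 0-39, Request 2 positions 40-76.
extend_seq_lens = [40, 37]   # qo_len per request
seq_lens = [40, 37]          # kv_len per request
kv_start_offsets = [0, 40]   # assumed: global KV-cache start of each request

# For each query token, the half-open range of *global* KV positions it may attend to.
global_ranges = [
    (offset, offset + s)
    for offset, qo_len, kv_len in zip(kv_start_offsets, extend_seq_lens, seq_lens)
    for s in range(kv_len - qo_len + 1, kv_len + 1)
]

assert global_ranges[39] == (0, 40)   # last token of Request 1 sees positions 0-39
assert global_ranges[40] == (40, 41)  # first token of Request 2 sees only position 40
```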
But why are the final results still correct? Am I missing something?