Fix private memory size too large in sample_recovered_tokens_kernel #115
Conversation
Code Review
This pull request aims to fix a `memory size or pointer value too large` error in a Triton kernel. The main change introduces a new, blocked implementation of sample_recovered_tokens_kernel to better manage private memory usage. My review focuses on ensuring the new kernel is robust against potential integer overflows during pointer arithmetic, which is a likely cause of the original error. I've identified two places where 32-bit integer overflows could still occur and have provided suggestions to cast to 64-bit integers, improving the correctness and reliability of the fix.
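To make the overflow concern concrete, here is a small, self-contained sketch in plain Python. The vocabulary size and token index below are assumptions chosen purely for illustration (they are not taken from this PR); the sketch only shows how a product like `(start_idx + pos) * vocab_size` wraps around once it exceeds 2**31 - 1 when held in a signed 32-bit integer:

# Assumed sizes, for illustration only: a ~152k-entry vocabulary and a token
# index late in a large batch (playing the role of start_idx + pos).
vocab_size = 152_064
token_idx = 15_000

product = token_idx * vocab_size          # 2,280,960,000 > 2**31 - 1

# Reinterpret the low 32 bits as a signed (two's-complement) value, i.e. what
# a 32-bit offset register would actually hold after the multiplication.
MASK32 = (1 << 32) - 1
wrapped = ((product & MASK32) ^ (1 << 31)) - (1 << 31)

print(product)   # 2280960000
print(wrapped)   # -2014007296: the 32-bit offset has overflowed

A 64-bit offset (the `tl.int64` cast suggested below) holds the full product and avoids the wraparound.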
max_prob = -float('inf')
best_token_id = 0

for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
    block_end = min(block_start + BLOCK_SIZE, vocab_size)

    vocab_offset = tl.arange(0, BLOCK_SIZE)
    mask = vocab_offset < block_end - block_start

    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        prob = tl.load(
            target_probs_ptr + (start_idx + pos) * vocab_size +
            block_start + vocab_offset,
            mask=(mask & (vocab_offset + block_start != draft_token_id)),
            other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr +
                             (start_idx + pos) * vocab_size + block_start +
                             vocab_offset,
                             mask=mask,
                             other=0)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size +
                              block_start + vocab_offset,
                              mask=mask,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)
The pointer offset calculation `(start_idx + pos) * vocab_size` is repeated and can suffer from 32-bit integer overflow if `start_idx` is loaded as a 32-bit integer and multiplied by a large `vocab_size`. This can lead to incorrect memory accesses and is a likely cause of the `pointer value too large to fit in 32 bit` error.
To ensure correctness and improve readability, it is better to compute the base offset once outside the loop, explicitly casting to `tl.int64` to prevent any potential overflow:
token_idx = start_idx + pos
# Cast to int64 to prevent overflow when calculating pointer offsets.
base_offset = token_idx.to(tl.int64) * vocab_size

max_prob = -float('inf')
best_token_id = 0
for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
    block_end = min(block_start + BLOCK_SIZE, vocab_size)
    vocab_offset = tl.arange(0, BLOCK_SIZE)
    mask = vocab_offset < block_end - block_start
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
        prob = tl.load(
            target_probs_ptr + base_offset + block_start + vocab_offset,
            mask=(mask & (vocab_offset + block_start != draft_token_id)),
            other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr + base_offset + block_start +
                             vocab_offset,
                             mask=mask,
                             other=0)
        target_prob = tl.load(target_probs_ptr + base_offset +
                              block_start + vocab_offset,
                              mask=mask,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)

q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
            mask=mask,
            other=float("-inf"))
For consistency and to prevent potential 32-bit integer overflows, the offset calculation for q_ptr should also use 64-bit integers. While req_idx is likely small, multiplying by a large vocab_size could still pose a risk on some platforms or with very large batches. Using tl.int64 ensures the calculation is safe.
Suggested change:
-q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
-            mask=mask,
-            other=float("-inf"))
+q = tl.load(q_ptr + req_idx.to(tl.int64) * vocab_size + block_start + vocab_offset,
+            mask=mask,
+            other=float("-inf"))
Code Review
This pull request aims to fix a Triton Error related to memory size and pointer values in sample_recovered_tokens_kernel. The approach of iterating over the vocabulary in blocks is a good solution for the private memory size issue. However, the fix seems incomplete as it doesn't address the potential for 32-bit integer overflow in pointer offset calculations. I've added a suggestion to explicitly use 64-bit integers for these calculations to make the fix robust. The other changes in the pull request are correct.
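As rough context for why the blocked loop addresses the private-memory problem, the sketch below compares the per-program storage needed to materialize a full-vocabulary probability vector against one BLOCK_SIZE tile per loop iteration. The sizes are assumptions for illustration only, not values taken from this PR or the kernel launch:

# Back-of-the-envelope sketch with assumed sizes; actual values depend on the
# model's vocabulary and the kernel's launch configuration.
PADDED_VOCAB_SIZE = 152_064   # hypothetical padded vocabulary size
BLOCK_SIZE = 8_192            # hypothetical per-iteration block
BYTES_PER_FP32 = 4

full_vocab_bytes = PADDED_VOCAB_SIZE * BYTES_PER_FP32  # held all at once
per_block_bytes = BLOCK_SIZE * BYTES_PER_FP32          # reused each iteration

print(full_vocab_bytes)  # 608256 bytes (~594 KiB) of private storage per program
print(per_block_bytes)   # 32768 bytes (32 KiB) per program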
max_prob = -float('inf')
best_token_id = 0

for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
    block_end = min(block_start + BLOCK_SIZE, vocab_size)

    vocab_offset = tl.arange(0, BLOCK_SIZE)
    mask = vocab_offset < block_end - block_start

    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        prob = tl.load(
            target_probs_ptr + (start_idx + pos) * vocab_size +
            block_start + vocab_offset,
            mask=(mask & (vocab_offset + block_start != draft_token_id)),
            other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr +
                             (start_idx + pos) * vocab_size + block_start +
                             vocab_offset,
                             mask=mask,
                             other=0)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size +
                              block_start + vocab_offset,
                              mask=mask,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)

    # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
    # `tl.argmax` will select the maximum value.

    q = tl.load(q_ptr + req_idx * vocab_size + block_start + vocab_offset,
                mask=mask,
                other=float("-inf"))
The PR title and description mention fixing `RuntimeError: Triton Error [MACA]: memory size or pointer value too large to fit in 32 bit`. While iterating over the vocabulary in blocks fixes the "private memory size too large" issue, the "pointer value too large" issue can still occur.
The pointer offset calculations such as `(start_idx + pos) * vocab_size` and `req_idx * vocab_size` can overflow a 32-bit integer if the total number of tokens or the vocabulary size is large, leading to the error.
This suggestion ensures 64-bit arithmetic for the offset calculations by explicitly casting to `tl.int64` before the multiplication. This prevents potential overflows, fully resolves the reported error, and makes the kernel more robust:
token_idx = start_idx + pos
# Ensure 64-bit arithmetic for large offsets to prevent overflow.
probs_base_offset = token_idx.to(tl.int64) * vocab_size
q_base_offset = req_idx.to(tl.int64) * vocab_size

max_prob = -float('inf')
best_token_id = 0
for block_start in range(0, PADDED_VOCAB_SIZE, BLOCK_SIZE):
    block_end = min(block_start + BLOCK_SIZE, vocab_size)
    vocab_offset = tl.arange(0, BLOCK_SIZE)
    mask = vocab_offset < block_end - block_start
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
        prob = tl.load(
            target_probs_ptr + probs_base_offset + block_start +
            vocab_offset,
            mask=(mask & (vocab_offset + block_start != draft_token_id)),
            other=0)
    else:
        draft_prob = tl.load(draft_probs_ptr + probs_base_offset +
                             block_start + vocab_offset,
                             mask=mask,
                             other=0)
        target_prob = tl.load(target_probs_ptr + probs_base_offset +
                              block_start + vocab_offset,
                              mask=mask,
                              other=0)
        prob = tl.maximum(target_prob - draft_prob, 0)

    # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
    # `tl.argmax` will select the maximum value.
    q = tl.load(q_ptr + q_base_offset + block_start + vocab_offset,
                mask=mask,
                other=float("-inf"))
[fix] private memory size too large in `sample_recovered_tokens_kernel` (#115)

* [fix] fix sample_recovered_tokens_kernel use too much private memory
Signed-off-by: Xin Li <[email protected]>
* [fix] fix type error in bf16_paged_mqa_logits
Signed-off-by: Xin Li <[email protected]>
* [chore] change file directory
Signed-off-by: Xin Li <[email protected]>
---------
Signed-off-by: Xin Li <[email protected]>
Co-authored-by: Xin Li <[email protected]>
Signed-off-by: leex404 <[email protected]>
* support platform and remove kernel copy Signed-off-by: Hank <[email protected]>
* update pre-commit Signed-off-by: Hank <[email protected]>
* update version and requirements Signed-off-by: Hank <[email protected]>
* update flashinfer Signed-off-by: Hank <[email protected]>
* update build requirements Signed-off-by: Hank <[email protected]>
* update attention backends Signed-off-by: Hank <[email protected]>
* update patch Signed-off-by: Hank <[email protected]>
* update quant_method Signed-off-by: Hank <[email protected]>
* update fuse_moe (todo: fix mypy) Signed-off-by: Hank <[email protected]>
* update `deepseek_v2.py` (todo: fix indexer kernel) Signed-off-by: Hank <[email protected]>
* [feat] support bf16 cp_gather_indexer_k_cache kernel Signed-off-by: Xin Li <[email protected]>
* [fix] fix type error in bf16_paged_mqa_logits Signed-off-by: leex404 <[email protected]>
* [feat] add topk logits ops Signed-off-by: leex404 <[email protected]>
* [fix] private memory size too large in `sample_recovered_tokens_kernel` (#115)
  * [fix] fix sample_recovered_tokens_kernel use too much private memory Signed-off-by: Xin Li <[email protected]>
  * [fix] fix type error in bf16_paged_mqa_logits Signed-off-by: Xin Li <[email protected]>
  * [chore] change file directory Signed-off-by: Xin Li <[email protected]>
  --------- Signed-off-by: Xin Li <[email protected]> Co-authored-by: Xin Li <[email protected]> Signed-off-by: leex404 <[email protected]>
* [fix] fix missing topk logits custom ops definition Signed-off-by: leex404 <[email protected]>
* [fix] add custom gptq_shuffle ops Signed-off-by: leex404 <[email protected]>
* [fix] fix compile error Signed-off-by: leex404 <[email protected]>
* platform config update Signed-off-by: Hank <[email protected]>
* update qwen2.5_vl model Signed-off-by: Hank <[email protected]>
* [fix] fix torch not found maca device Signed-off-by: leex404 <[email protected]>
* remove hotfixes patch for torch2.8 Signed-off-by: Hank <[email protected]>
* remove needless patch related: vllm-project/vllm/pull/27322 Signed-off-by: Hank <[email protected]>
* [feat] topk_softmax support renormalize and bf16 Signed-off-by: leex404 <[email protected]>
* [fix] update fused_moe to fit v0.11.1 Signed-off-by: leex404 <[email protected]>
* [fix] fix fused moe config log missing Signed-off-by: leex404 <[email protected]>
* use flash_attn as vit attn backend on qwen_vl Signed-off-by: Hank <[email protected]>
* update quant_conf registry Signed-off-by: Hank <[email protected]>
* fix and apply latest pre-commit of v0.11.1 Signed-off-by: Hank <[email protected]>
* [feat] Keep all AITER kernels in _aiter_ops Signed-off-by: leex404 <[email protected]>
* fix pre-commit on type casting Signed-off-by: Hank <[email protected]>
* [fix] fix DeepSeek import error Signed-off-by: leex404 <[email protected]>
* [feat] update deepseek_v2 to fit v0.11.1 Signed-off-by: leex404 <[email protected]>
---------
Signed-off-by: Hank <[email protected]>
Signed-off-by: Xin Li <[email protected]>
Signed-off-by: leex404 <[email protected]>
Co-authored-by: Xin Li <[email protected]>
Co-authored-by: leex404 <[email protected]>
Co-authored-by: leex404 <[email protected]>
Purpose
Fix the following error:
RuntimeError: Triton Error [MACA]: memory size or pointer value too large to fit in 32 bit
Test Plan
Test Result
(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.