fix: Sampling: CUDA Graph fix (#2432)
Conversation
📝 Walkthrough: Replaced scalar philox_seed/philox_offset with optional (tensor, scalar) seed/offset tuples across the Python API, C++ bindings/implementation, and CUDA kernels; added validation and forwarding so kernels receive either a device pointer (tensor) or nullptr plus scalar fallback values.
Summary of Changes

Hello @IzzyPutterman, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances FlashInfer's sampling functionality by introducing a more flexible and robust mechanism for handling random number generator (RNG) seeds and offsets. Previously, only scalar seed and offset values were supported, limiting the ability to manage randomness on a per-batch basis. The updated implementation now allows users to provide optional seed/offset tensors in addition to scalar values.
Code Review
This pull request introduces a valuable change by allowing seed and offset to be passed as tensors to sampling functions, which is crucial for CUDA graph compatibility. The overall approach is sound. However, I've identified a critical bug in the CUDA kernels where per-request seeds/offsets from a tensor are not correctly indexed, always using the first element. I've provided a suggestion to fix this in the CUDA kernels. Additionally, I've suggested an improvement in the Python validation logic to robustly handle both broadcast (size-1) and per-request (size-batch_size) tensors for seeds and offsets, which complements the kernel-side fix.
```cuda
uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val;
uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val;
```
This implementation incorrectly uses seed_arr[0] and offset_arr[0] for all items in the batch. To support per-request seeds/offsets when a tensor is provided, you should use the block index bx to access the corresponding element for each batch item. This would be seed_arr[bx] and offset_arr[bx].
This same issue exists in all the updated sampling kernels in this file (SamplingFromProbKernel, TopKSamplingFromProbKernel, etc.) and should be fixed in all of them.
```cuda
uint64_t philox_seed = seed_arr ? seed_arr[bx] : seed_val;
uint64_t philox_offset = offset_arr ? offset_arr[bx] : offset_val;
```
```python
if maybe_seed_arr is not None:
    if maybe_seed_arr.device != device:
        raise ValueError(f"seed tensor must be on {device}")
    if maybe_seed_arr.dtype not in [torch.int64, torch.uint64]:
        raise ValueError("seed tensor must be int64/uint64")
    if maybe_seed_arr.ndim != 1:
        raise ValueError("seed tensor must be 1D")
    if maybe_seed_arr.size(0) not in [1, batch_size]:
        raise ValueError(f"seed tensor length must be 1 or {batch_size}")
if maybe_offset_arr is not None:
    if maybe_offset_arr.device != device:
        raise ValueError(f"offset tensor must be on {device}")
    if maybe_offset_arr.dtype not in [torch.int64, torch.uint64]:
        raise ValueError("offset tensor must be int64/uint64")
    if maybe_offset_arr.ndim != 1:
        raise ValueError("offset tensor must be 1D")
    if maybe_offset_arr.size(0) not in [1, batch_size]:
        raise ValueError(f"offset tensor length must be 1 or {batch_size}")
```
To robustly support both broadcast (size-1) and per-request (size-batch_size) tensors for seed and offset, you can expand the size-1 tensors to match the batch size. This creates a view with a stride of 0, so when the CUDA kernel accesses seed_arr[bx], it will correctly get the same value for all batch items. This makes the Python-side behavior consistent with the proposed change in the CUDA kernels to use [bx] for per-request access.
```python
if maybe_seed_arr is not None:
    if maybe_seed_arr.device != device:
        raise ValueError(f"seed tensor must be on {device}")
    if maybe_seed_arr.dtype not in [torch.int64, torch.uint64]:
        raise ValueError("seed tensor must be int64/uint64")
    if maybe_seed_arr.ndim != 1:
        raise ValueError("seed tensor must be 1D")
    if maybe_seed_arr.size(0) == 1:
        maybe_seed_arr = maybe_seed_arr.expand(batch_size)
    elif maybe_seed_arr.size(0) != batch_size:
        raise ValueError(f"seed tensor length must be 1 or {batch_size}")
if maybe_offset_arr is not None:
    if maybe_offset_arr.device != device:
        raise ValueError(f"offset tensor must be on {device}")
    if maybe_offset_arr.dtype not in [torch.int64, torch.uint64]:
        raise ValueError("offset tensor must be int64/uint64")
    if maybe_offset_arr.ndim != 1:
        raise ValueError("offset tensor must be 1D")
    if maybe_offset_arr.size(0) == 1:
        maybe_offset_arr = maybe_offset_arr.expand(batch_size)
    elif maybe_offset_arr.size(0) != batch_size:
        raise ValueError(f"offset tensor length must be 1 or {batch_size}")
```
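As a standalone check (not code from the PR), the stride-0 view behavior this suggestion relies on looks like this in PyTorch:

```python
import torch

seed = torch.tensor([1234], dtype=torch.int64)  # broadcast (size-1) seed
batch_size = 4

# expand() returns a view with stride 0: no data is copied, and every
# index maps back to the single underlying element.
expanded = seed.expand(batch_size)

assert expanded.stride() == (0,)
assert expanded.tolist() == [1234, 1234, 1234, 1234]
assert expanded.data_ptr() == seed.data_ptr()  # same storage, no copy
```

One caveat worth verifying: a stride-0 view is non-contiguous, so a kernel that indexes a raw device pointer by `bx` would not observe the stride; whether the C++ side's input checks accept such views (or require a materialized copy) should be confirmed before relying on this trick.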
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/sampling.py (2)
91-124: Update fake-op signatures to accept seed/offset.

The custom ops now take `seed`/`offset`, but the registered fake ops still use the old parameter list. In FakeTensor / `torch.compile` paths this will raise due to extra args. Please add optional `seed`/`offset` parameters (default `None`) to each `_fake_*` sampling op to keep schemas in sync.

🔧 Example fix (apply similarly to all _fake_* sampling ops)

```diff
 @register_fake_op("flashinfer::sampling_from_logits")
 def _fake_sampling_from_logits(
     logits: torch.Tensor,
     indices: Optional[torch.Tensor],
     deterministic: bool,
-    generator: Optional[torch.Generator],
+    generator: Optional[torch.Generator],
+    seed: Optional[Union[int, torch.Tensor]] = None,
+    offset: Optional[Union[int, torch.Tensor]] = None,
 ) -> torch.Tensor:
     batch_size = indices.size(0) if indices is not None else logits.size(0)
     out_dtype = indices.dtype if indices is not None else torch.int32
     return torch.empty(batch_size, dtype=out_dtype, device=logits.device)
```

Also applies to: 141-171, 188-227, 244-281, 298-336, 340-384, 503-552
710-746: Docstrings still document seed/offset as `Optional[int]`.

The signatures now accept `Union[int, torch.Tensor]`, but the parameter docs still say `Optional[int]`. Please update the docs (and include expected tensor shape/dtype) for each affected API to avoid user confusion.

✏️ Example doc update (apply across affected APIs)

```diff
-    seed: Optional[int]
-        seed value to use for the rng during the sampling operation.
-    offset: Optional[int]
-        offset value to use for the rng during the sampling operation.
+    seed: Optional[Union[int, torch.Tensor]]
+        Seed value or 1D int64/uint64 tensor (length 1 or batch_size).
+    offset: Optional[Union[int, torch.Tensor]]
+        Offset value or 1D int64/uint64 tensor (length 1 or batch_size).
```

Also applies to: 777-813, 850-896, 947-993, 1044-1091, 1137-1193, 1270-1327, 1611-1660
🤖 Fix all issues with AI agents
In `@csrc/sampling.cu`:
- Around line 25-45: validate_seed_offset_tensors currently checks
dtype/ndim/device but not tensor length, allowing zero-length or mismatched
batch-length tensors which can cause OOB reads; update
validate_seed_offset_tensors to also check that when maybe_seed_arr/
maybe_offset_arr is present their size(0) is either 1 or equals the expected
batch size (pass the expected batch size or the output tensor and compare
against output.size(0) when indices are supplied), and update call sites that
invoke validate_seed_offset_tensors to supply the output tensor or explicit
batch size so the function can enforce length == 1 || length ==
expected_batch_size.
In `@flashinfer/sampling.py`:
- Around line 591-649: The code uses torch.uint64 (e.g., in
_validate_and_convert_seed_offset) without guarding for PyTorch versions prior
to 2.3; either declare PyTorch>=2.3 in project docs/requirements or add runtime
guards: check torch.__version__ or use getattr(torch, "uint64", None) and only
include torch.uint64 in dtype checks when present, otherwise fall back to
validating against torch.int64 only; apply the same pattern where torch.uint64
is referenced (flashinfer/sampling.py::_validate_and_convert_seed_offset,
flashinfer/comm/dlpack_utils.py, flashinfer/jit/utils.py,
flashinfer/deep_gemm.py) and update validation error messages to reflect the
supported dtype set dynamically.
In `@include/flashinfer/sampling.cuh`:
- Around line 703-711: The kernel SamplingFromLogitsKernel (and the other
sampling kernels listed) currently always reads philox_seed/philox_offset from
seed_arr[0]/offset_arr[0], which ignores per-batch seeds; change the resolution
logic to use the batch index (e.g., philox_seed = seed_arr ? seed_arr[bx] :
seed_val and philox_offset = offset_arr ? offset_arr[bx] : offset_val) or, if
the intended semantics are per-row, use the appropriate row index (row_idx)
instead of 0; apply the same fix to all sampling kernels referenced (lines
around 750-759, 803-813, 926-936, 1043-1052, 1129-1140, 1788-1801) so
per-request seeds/offsets are honored.
```cpp
// Helper function to validate seed/offset tensors for sampling operations
inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
                                         const Optional<TensorView>& maybe_offset_arr,
                                         const TensorView& reference_tensor) {
  if (maybe_seed_arr.has_value()) {
    CHECK_INPUT(maybe_seed_arr.value());
    CHECK_DIM(1, maybe_seed_arr.value());
    TVM_FFI_ICHECK(maybe_seed_arr.value().dtype() == dl_int64 ||
                   maybe_seed_arr.value().dtype() == dl_uint64)
        << "seed tensor must be int64 or uint64";
    CHECK_DEVICE(maybe_seed_arr.value(), reference_tensor);
  }
  if (maybe_offset_arr.has_value()) {
    CHECK_INPUT(maybe_offset_arr.value());
    CHECK_DIM(1, maybe_offset_arr.value());
    TVM_FFI_ICHECK(maybe_offset_arr.value().dtype() == dl_int64 ||
                   maybe_offset_arr.value().dtype() == dl_uint64)
        << "offset tensor must be int64 or uint64";
    CHECK_DEVICE(maybe_offset_arr.value(), reference_tensor);
  }
}
```
Validate seed/offset tensor lengths to prevent OOB reads.
The helper checks dtype/device/ndim but not length. A zero-length tensor or mismatched batch length will pass and kernels read element 0 (or bx), causing OOB. Please enforce length == 1 or == output batch size (use output.size(0) when indices is supplied).
🔧 Suggested validation (and call-site update)

```diff
-inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
-                                         const Optional<TensorView>& maybe_offset_arr,
-                                         const TensorView& reference_tensor) {
+inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
+                                         const Optional<TensorView>& maybe_offset_arr,
+                                         const TensorView& reference_tensor,
+                                         int64_t batch_size) {
   if (maybe_seed_arr.has_value()) {
     CHECK_INPUT(maybe_seed_arr.value());
     CHECK_DIM(1, maybe_seed_arr.value());
+    TVM_FFI_ICHECK(maybe_seed_arr.value().size(0) == 1 ||
+                   maybe_seed_arr.value().size(0) == batch_size)
+        << "seed tensor length must be 1 or batch_size";
     TVM_FFI_ICHECK(maybe_seed_arr.value().dtype() == dl_int64 ||
                    maybe_seed_arr.value().dtype() == dl_uint64)
         << "seed tensor must be int64 or uint64";
     CHECK_DEVICE(maybe_seed_arr.value(), reference_tensor);
   }
   if (maybe_offset_arr.has_value()) {
     CHECK_INPUT(maybe_offset_arr.value());
     CHECK_DIM(1, maybe_offset_arr.value());
+    TVM_FFI_ICHECK(maybe_offset_arr.value().size(0) == 1 ||
+                   maybe_offset_arr.value().size(0) == batch_size)
+        << "offset tensor length must be 1 or batch_size";
     TVM_FFI_ICHECK(maybe_offset_arr.value().dtype() == dl_int64 ||
                    maybe_offset_arr.value().dtype() == dl_uint64)
         << "offset tensor must be int64 or uint64";
     CHECK_DEVICE(maybe_offset_arr.value(), reference_tensor);
   }
 }
```

Call-site update:

```diff
-  validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, logits);
+  validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, logits, output.size(0));
```
```python
def _validate_and_convert_seed_offset(
    seed: Union[int, torch.Tensor],
    offset: Union[int, torch.Tensor],
    device: torch.device,
    batch_size: int,
) -> Tuple[Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
    """Validate and convert seed/offset to tensor/scalar tuples for sampling kernels.

    Parameters
    ----------
    seed : Union[int, torch.Tensor]
        Seed value or tensor.
    offset : Union[int, torch.Tensor]
        Offset value or tensor.
    device : torch.device
        Expected device for tensor inputs.
    batch_size : int
        Expected batch size for tensor length validation.

    Returns
    -------
    Tuple[Optional[torch.Tensor], int, Optional[torch.Tensor], int]
        (maybe_seed_arr, seed_val, maybe_offset_arr, offset_val)

    Raises
    ------
    ValueError
        If seed and offset are not both tensors or both scalars, or if tensor
        properties (device, dtype, ndim, size) are invalid.
    """
    # Validate tensor/scalar consistency
    if isinstance(seed, torch.Tensor) != isinstance(offset, torch.Tensor):
        raise ValueError("seed and offset must both be tensors or both be scalars")

    # Convert to tensor/scalar tuple
    maybe_seed_arr, seed_val = _to_tensor_scalar_tuple(seed)
    maybe_offset_arr, offset_val = _to_tensor_scalar_tuple(offset)

    # Validate tensor properties
    if maybe_seed_arr is not None:
        if maybe_seed_arr.device != device:
            raise ValueError(f"seed tensor must be on {device}")
        if maybe_seed_arr.dtype not in [torch.int64, torch.uint64]:
            raise ValueError("seed tensor must be int64/uint64")
        if maybe_seed_arr.ndim != 1:
            raise ValueError("seed tensor must be 1D")
        if maybe_seed_arr.size(0) not in [1, batch_size]:
            raise ValueError(f"seed tensor length must be 1 or {batch_size}")
    if maybe_offset_arr is not None:
        if maybe_offset_arr.device != device:
            raise ValueError(f"offset tensor must be on {device}")
        if maybe_offset_arr.dtype not in [torch.int64, torch.uint64]:
            raise ValueError("offset tensor must be int64/uint64")
        if maybe_offset_arr.ndim != 1:
            raise ValueError("offset tensor must be 1D")
        if maybe_offset_arr.size(0) not in [1, batch_size]:
            raise ValueError(f"offset tensor length must be 1 or {batch_size}")

    return maybe_seed_arr, seed_val, maybe_offset_arr, offset_val
```
🧩 Analysis chain

🌐 Web query: torch.uint64 PyTorch dtype availability version

💡 Result: `torch.uint64` (along with `torch.uint16` and `torch.uint32`) became available in PyTorch 2.3 as an unsigned "shell dtype": it exists as a dtype value but has limited operator/backend support (mainly tensor creation and shape-only/view-like ops; many elementwise/math ops may be unsupported or only partially supported). [1], [2]

So:

- Availability: `torch.uint64` is available starting PyTorch 2.3. [1]
- Status: it is a shell dtype with limited eager-mode support, primarily to aid `torch.compile`/interoperability rather than full eager kernel coverage. [2]

Sources: [1] GitHub issue discussing the PyTorch 2.3 introduction; [2] PyTorch "Tensor Attributes" docs listing torch.uint64 as a shell dtype.
🏁 Script executed:
# Check for PyTorch version requirements in the repository
find . -type f \( -name "setup.py" -o -name "pyproject.toml" -o -name "requirements.txt" -o -name "setup.cfg" \) | head -20Repository: flashinfer-ai/flashinfer
Length of output: 199
🏁 Script executed:
# Search for torch.uint64 usage in the codebase
rg "torch\.uint64" --type pyRepository: flashinfer-ai/flashinfer
Length of output: 1524
🏁 Script executed:
# Search for hasattr guards related to torch types
rg "hasattr\(torch" --type pyRepository: flashinfer-ai/flashinfer
Length of output: 417
🏁 Script executed:
# Check the main pyproject.toml for PyTorch version requirements
cat pyproject.toml | head -100Repository: flashinfer-ai/flashinfer
Length of output: 2660
🏁 Script executed:
# Also check requirements.txt
cat requirements.txtRepository: flashinfer-ai/flashinfer
Length of output: 255
🏁 Script executed:
# Look for torch version pins in setup.py or other config
rg "torch.*=|PyTorch" --type py --type toml | head -20Repository: flashinfer-ai/flashinfer
Length of output: 2271
🏁 Script executed:
# Check for any CI configuration files that might specify PyTorch versions
find . -type f \( -name "*.yml" -o -name "*.yaml" \) | head -20Repository: flashinfer-ai/flashinfer
Length of output: 605
🏁 Script executed:
# Check if there's a minimum PyTorch version requirement in CI files
find .github -type f -name "*.yml" -o -name "*.yaml" 2>/dev/null | xargs grep -l "torch" 2>/dev/null | head -5Repository: flashinfer-ai/flashinfer
Length of output: 135
🏁 Script executed:
# Look for documentation or installation instructions that mention PyTorch version
find . -type f \( -name "README.md" -o -name "INSTALL.md" -o -name "*.md" \) | xargs grep -i "pytorch.*version\|torch.*>=\|torch.*2\." 2>/dev/null | head -20Repository: flashinfer-ai/flashinfer
Length of output: 1383
🏁 Script executed:
# Check the release notes or documentation for minimum PyTorch version
cat README.md | head -200Repository: flashinfer-ai/flashinfer
Length of output: 7734
🏁 Script executed:
# Check CI/CD configurations to see what PyTorch versions are tested
cat .github/workflows/pr-test.yml | grep -A 5 -B 5 "torch\|python" | head -50Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look for any version constraint discussions in the source code
rg "torch\.__version__|version_info" --type py -A 2Repository: flashinfer-ai/flashinfer
Length of output: 639
🏁 Script executed:
# Check the CI test configurations to see what PyTorch versions are tested
cat .github/workflows/pr-test.ymlRepository: flashinfer-ai/flashinfer
Length of output: 13548
🏁 Script executed:
# Check if there are any version constraints or compatibility notes for PyTorch
rg "PyTorch|torch.*2\." --type md | head -30Repository: flashinfer-ai/flashinfer
Length of output: 2431
🏁 Script executed:
# Look for any documentation about supported PyTorch versions
find . -name "*.md" -type f -exec grep -l "version\|support\|compatible" {} \; 2>/dev/null | xargs grep -i "torch\|pytorch" 2>/dev/null | grep -i "version\|support\|2\." | head -20Repository: flashinfer-ai/flashinfer
Length of output: 2568
🏁 Script executed:
# Check if there's a minimum PyTorch version anywhere in setup.py, pyproject.toml, or other config
grep -r "torch.*>=" --include="*.py" --include="*.toml" --include="*.txt" | grep -v ".git" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 2257
🏁 Script executed:
# Check the CHANGELOG more carefully for version support claims
grep -i "torch\|pytorch" CHANGELOG.md | head -40Repository: flashinfer-ai/flashinfer
Length of output: 2682
🏁 Script executed:
# Check if torch.uint64 is used elsewhere in the code and if there are any guards
rg "uint64" --type py -B 2 -A 2 | head -80Repository: flashinfer-ai/flashinfer
Length of output: 4751
🏁 Script executed:
# Check if there's any explicit minimum version statement anywhere
rg "minimum.*torch|torch.*minimum|requires.*torch" --type md --type py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 529
🏁 Script executed:
# Look at the conftest.py to understand version checks
cat tests/conftest.py | head -100Repository: flashinfer-ai/flashinfer
Length of output: 3644
🏁 Script executed:
# Look for any issue or discussion about minimum PyTorch version
rg "torch.*2\.[0-9]" --type md | grep -E "minimum|support|require|compatible" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 313
🏁 Script executed:
# Check if there are explicit version checks in the codebase that hint at minimum version
rg "TorchVersion|parse.*torch|torch.*version" --type py -A 2 -B 2 | head -50Repository: flashinfer-ai/flashinfer
Length of output: 2407
🏁 Script executed:
# Check when torch.uint64 was introduced - look at git history of sampling.py to see when uint64 was added
git log --oneline -p flashinfer/sampling.py | grep -B 5 -A 5 "torch.uint64" | head -50Repository: flashinfer-ai/flashinfer
Length of output: 1149
🏁 Script executed:
# More directly: check if there's a version file or any explicit minimum version declaration
find . -name "*version*" -type f | grep -v ".git" | grep -v "__pycache__"Repository: flashinfer-ai/flashinfer
Length of output: 340
🏁 Script executed:
# Check what happens if we try to use the code - see if there are any imports or early version checks
head -50 flashinfer/sampling.pyRepository: flashinfer-ai/flashinfer
Length of output: 1558
🏁 Script executed:
# Check if torch.uint64 appears in any version guards or conditional imports
rg "torch.uint64|uint64" --type py -B 5 -A 5 | grep -E "hasattr|version|if|else" | head -40Repository: flashinfer-ai/flashinfer
Length of output: 1265
🏁 Script executed:
# Check if there's a way the code could fail gracefully
cat flashinfer/sampling.py | grep -A 20 -B 5 "torch.uint64"Repository: flashinfer-ai/flashinfer
Length of output: 1518
`torch.uint64` is only available in PyTorch 2.3+; validate minimum version requirements.

The code references `torch.uint64`, which became available only in PyTorch 2.3 as a shell dtype with limited support. It is used unguarded not only in `_validate_and_convert_seed_offset()` but also throughout the codebase (e.g., `flashinfer/comm/dlpack_utils.py`, `flashinfer/jit/utils.py`, `flashinfer/deep_gemm.py`).

Clarify the minimum supported PyTorch version. If FlashInfer requires PyTorch 2.3+, document this constraint. If older versions must be supported, add guards consistently across all modules that use `torch.uint64`.
🧰 Tools

🪛 Ruff (0.14.14): TRY003 (Avoid specifying long messages outside the exception class) reported at lines 623, 632, 634, 636, 638, 641, 643, 645, and 647.
```diff
 __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType* indices, uint32_t d,
-                                         uint64_t philox_seed, uint64_t philox_offset) {
+                                         uint64_t* seed_arr, uint64_t seed_val,
+                                         uint64_t* offset_arr, uint64_t offset_val) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;

+  // Resolve seed/offset from tensor or scalar
+  uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val;
+  uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val;
```
Per-batch seed/offset tensors are ignored (always using element 0).
All kernels resolve philox_seed/offset using seed_arr[0] / offset_arr[0]. Since higher-level validation allows length == batch_size, per-request seeds are silently ignored. Either (a) restrict validation to length 1 only, or (b) index by bx (or row_idx if that’s the intended semantics).
🔧 Example fix (apply to all sampling kernels)

```diff
-  uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val;
-  uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val;
+  uint64_t philox_seed = seed_arr ? seed_arr[bx] : seed_val;
+  uint64_t philox_offset = offset_arr ? offset_arr[bx] : offset_val;
```

Also applies to: 750-759, 803-813, 926-936, 1043-1052, 1129-1140, 1788-1801
@flashinfer-bot run

/bot run

[CANCELING] Pipeline #43126072: canceled

/bot run

@IzzyPutterman is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run
```cuda
// Resolve seed/offset from tensor or scalar
uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val;
uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val;
```
Some blocks need to advance the offset value (like in flashinfer/sampling.py, lines 33 to 49 at 6ae5bfe).
Hmmm, I guess this PR only allows for CUDA graph compliance by passing in the seed and offset. I would argue that it's the user's responsibility to update the seed and offset as they use this function; otherwise there could be a lack of visibility into how the RNG is actually seeded.

No idea if this is a reasonable take, but I'm curious about your thoughts (and whether I missed something super obvious, lol).
Spoke offline, confirmed this behavior is ok with the doc strings added.
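The agreed-upon usage pattern, with the caller advancing the RNG state between replays, can be sketched as follows; `captured_sample_step` and `philox_increment` are illustrative placeholders, not FlashInfer API:

```python
import torch

# Seed/offset live in tensors so their device addresses stay fixed across
# CUDA graph replays; only their contents change between replays.
seed = torch.tensor([1234], dtype=torch.int64)
offset = torch.tensor([0], dtype=torch.int64)

def captured_sample_step(seed: torch.Tensor, offset: torch.Tensor) -> None:
    """Stand-in for the graph-captured sampling call that reads seed/offset."""
    pass

philox_increment = 4  # illustrative: however many Philox draws one step consumes

for _ in range(3):
    captured_sample_step(seed, offset)  # replay reads the current offset value
    offset.add_(philox_increment)       # caller advances the RNG state in place

assert offset.item() == 12
```

The in-place `add_` is the important part: a graph replay reads whatever the tensor holds at replay time, so mutating the tensor (rather than rebinding the Python name) is what keeps successive replays drawing fresh randomness.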
[CANCELING] Pipeline #43221767: canceled

Force-pushed d958a6c to 8248582
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/sampling.py (2)
106-113: ⚠️ Potential issue | 🟡 Minor

The `or` condition silently discards a user-provided tensor when only one of seed/offset is given.

`if seed is None or offset is None` replaces both values with generator-derived scalars even when the user explicitly provided a tensor for one of them. For example, `sampling_from_logits(logits, seed=my_cuda_tensor)` silently drops `my_cuda_tensor`. This is especially surprising for the new tensor path, since creating a CUDA tensor indicates deliberate intent for CUDA graph capture.

Consider using `and` plus an explicit error for the partial case. This same pattern repeats in all internal functions (lines 154, 207, 261, 317, 362, 530).

Suggested fix

```diff
-    if seed is None or offset is None:
+    if seed is None and offset is None:
         seed, offset = get_seed_and_offset(
             batch_size * logits.size(1), generator, device
         )
+    elif seed is None or offset is None:
+        raise ValueError("Both seed and offset must be provided, or neither.")
```
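The hazard can be reproduced with plain values; `fallback_seed_offset` is an illustrative stand-in for the generator-derived fallback (`get_seed_and_offset` in the PR):

```python
def fallback_seed_offset():
    # Stand-in for the generator-derived values the PR falls back to.
    return 111, 222

def resolve_buggy(seed, offset):
    # Original logic: one missing argument discards BOTH user-supplied values.
    if seed is None or offset is None:
        seed, offset = fallback_seed_offset()
    return seed, offset

def resolve_fixed(seed, offset):
    # Suggested logic: fall back only when neither is given; reject partial input.
    if seed is None and offset is None:
        seed, offset = fallback_seed_offset()
    elif seed is None or offset is None:
        raise ValueError("Both seed and offset must be provided, or neither.")
    return seed, offset

assert resolve_buggy(7, None) == (111, 222)  # the user's seed 7 is silently dropped
assert resolve_fixed(7, 3) == (7, 3)
try:
    resolve_fixed(7, None)
    raise AssertionError("expected ValueError")
except ValueError:
    pass
```

Raising loudly on the partial case trades a little convenience for making the RNG configuration unambiguous, which matters when users deliberately pin seeds for reproducibility.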
127-136: ⚠️ Potential issue | 🟡 Minor

Fake op signatures currently mismatch the real ops; prepare before re-enabling `torch.library`.

The real ops (e.g., `sampling_from_logits`) now accept `seed` and `offset` as optional parameters, but the corresponding fake ops (`_fake_sampling_from_logits`, `_fake_sampling_from_probs`, `_fake_top_p_sampling_from_probs`, etc.) do not include these parameters.

This is not currently breaking `torch.compile` because `register_custom_op` and `register_fake_op` are disabled in `flashinfer/utils.py` (the torch.library integration is commented out). However, once this code is re-enabled, the signature mismatch will prevent proper fake op schema validation. Add `seed` and `offset` parameters to all fake ops in preparation for re-enabling torch.library integration.
/bot run

1 similar comment

/bot run

[FAILED] Pipeline #43702970: 8/20 passed

Force-pushed 8248582 to 4ae57e6
yzh119 left a comment:

It should be ready to merge once CI passes. Thanks @IzzyPutterman for working on the fix!
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- I have installed pre-commit by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (e.g., `unittest`, etc.).
unittest, etc.).Reviewer Notes