fix: Sampling: CUDA Graph fix#2432

Merged
yzh119 merged 2 commits into flashinfer-ai:main from IzzyPutterman:iputterman/sampling-cuda-graphs
Feb 13, 2026

Conversation

@IzzyPutterman (Contributor) commented Jan 29, 2026

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor
    • Sampling APIs now accept seed and offset as either scalars or 1D tensors; tensor-based per-call seeds/offsets are threaded through native sampling paths with scalar fallbacks preserved.
  • Quality
    • Centralized validation/conversion enforces dtype, device, shape/length and batch-size semantics; validation applied before processing.
  • Documentation
    • Docstrings updated with union-type signatures, usage guidance and CUDA-graph compatibility notes.

coderabbitai bot (Contributor) commented Jan 29, 2026

📝 Walkthrough


Replaced scalar philox_seed/philox_offset with optional (tensor, scalar) seed/offset tuples across Python API, C++ bindings/impl, and CUDA kernels; added validation and forwarding so kernels receive either a device pointer (tensor) or nullptr plus scalar fallback values.

Changes

Cohort / File(s) → Summary

  • Python API — flashinfer/sampling.py: Seed/offset parameters now accept Optional[Union[int, torch.Tensor]]; added _validate_and_convert_seed_offset() and normalization helpers; inputs are converted to (maybe_seed_arr, seed_val, maybe_offset_arr, offset_val) for backend calls, and docstrings updated.
  • C++ Bindings — csrc/flashinfer_sampling_binding.cu: Exported function signatures updated to replace philox_seed/philox_offset with Optional<TensorView> maybe_seed_arr, uint64_t seed_val, Optional<TensorView> maybe_offset_arr, uint64_t offset_val; export declarations adjusted.
  • C++ Implementation — csrc/sampling.cu: Added validate_seed_offset_tensors() helper; entry points updated to accept optional seed/offset tensors, validate them, and forward either a device pointer or nullptr plus scalar fallback to CUDA launches; kernel launch args updated and kernel status checks preserved.
  • CUDA Headers / Kernels — include/flashinfer/sampling.cuh: Kernel and host-wrapper signatures changed to accept (uint64_t* seed_arr, uint64_t seed_val, uint64_t* offset_arr, uint64_t offset_val); in-kernel resolution derives philox values from the provided pointers or scalar fallbacks; launch argument arrays updated throughout.
  • Overall Sampling Surface — .../: Consistent signature changes for sampling functions (sampling_from_logits/probs, top_k/top_p/min_p, top_k_top_p, chain_speculative_sampling) propagated Python → C++ bindings → C++ impl → CUDA; added validation/conversion and forwarding logic.

Sequence Diagram(s)

mermaid
    sequenceDiagram
        participant Py as Python API
        participant Bind as C++ Binding
        participant Impl as C++ Impl
        participant CUDA as CUDA Kernel
        Py->>Bind: call sampling(..., seed=int|tensor, offset=int|tensor)
        Bind->>Impl: forward (maybe_seed_arr, seed_val, maybe_offset_arr, offset_val)
        Impl->>Impl: validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr)
        Impl->>CUDA: launch kernel with (seed_ptr or nullptr, seed_val, offset_ptr or nullptr, offset_val)
        CUDA->>CUDA: resolve philox_seed/philox_offset from ptrs or fallbacks
        CUDA-->>Impl: return results
        Impl-->>Bind: return
        Bind-->>Py: return
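The forwarding above can be sketched in plain Python (a simulation only — plain ints and lists stand in for scalars and device tensors, and `to_tensor_scalar_tuple` is a stand-in for the PR's `_to_tensor_scalar_tuple` helper):

```python
from typing import List, Optional, Tuple, Union


def to_tensor_scalar_tuple(x: Union[int, List[int]]) -> Tuple[Optional[List[int]], int]:
    # Mirrors the Python-side normalization: a tensor input is forwarded as
    # an array with a dummy scalar fallback; a scalar input is forwarded
    # with no array (the kernel then sees nullptr plus the scalar value).
    if isinstance(x, list):
        return x, 0
    return None, x


# Scalar path: (None, 42) -> kernel receives nullptr + fallback 42.
assert to_tensor_scalar_tuple(42) == (None, 42)

# Tensor path: ([1, 2, 3], 0) -> kernel receives a device pointer and
# ignores the scalar fallback.
maybe_seed_arr, seed_val = to_tensor_scalar_tuple([1, 2, 3])
assert maybe_seed_arr == [1, 2, 3] and seed_val == 0
```

The tensor path is what makes CUDA-graph capture work: the graph records a fixed device pointer, and the RNG state can be updated in place between replays without re-capturing.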

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • cyx-6
  • nvmbreughe
  • djmmoss
  • bkryu
  • kahyunnam
  • jiahanc


🚥 Pre-merge checks | ❌ 3 failed (2 warnings, 1 inconclusive)

  • Description check — ⚠️ Warning: The PR description contains only the empty template with checkboxes marked complete, and lacks substantive details about the code changes, rationale, and modified files. Resolution: fill in the Description and Related Issues sections with details about replacing scalar seed/offset parameters with optional tensor inputs, explain why this is needed for CUDA graphs, and link the relevant issues.
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 39.29%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Title check — ❓ Inconclusive: The title "fix: Sampling: CUDA Graph fix" is vague and does not clearly describe the changes. Resolution: revise it to something specific, such as "fix: Support per-call seed/offset tensors in sampling for CUDA graph compatibility".


No actionable comments were generated in the recent review. 🎉


@gemini-code-assist (Contributor)

Summary of Changes

Hello @IzzyPutterman, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances FlashInfer's sampling functionalities by introducing a more flexible and robust mechanism for handling random number generator (RNG) seeds and offsets. Previously, only scalar seed and offset values were supported, limiting the ability to manage randomness on a per-batch basis. The updated implementation now allows users to provide optional torch.Tensor inputs for seeds and offsets, which is critical for enabling CUDA Graph compatibility and facilitating more advanced sampling strategies where dynamic, device-side RNG state is required. This change ensures that sampling operations can be efficiently captured and executed within CUDA Graphs while maintaining backward compatibility for existing scalar inputs.

Highlights

  • Flexible RNG State Input: Sampling functions now accept both scalar uint64_t and optional torch.Tensor inputs for random number generator (RNG) seeds and offsets, allowing for per-batch control of randomness.
  • CUDA Graph Compatibility: This change enables FlashInfer's sampling operations to be compatible with CUDA Graphs by allowing RNG state to be managed on the device, which is a requirement for graph capture.
  • Input Validation: A new helper function _validate_and_convert_seed_offset (Python) and validate_seed_offset_tensors (C++) has been introduced to ensure the correctness and consistency of seed and offset inputs, whether they are scalars or tensors.
  • Unified Kernel Interface: The underlying CUDA kernels have been updated to dynamically resolve the RNG seed and offset from either the provided tensor (for per-batch randomness) or the scalar fallback (for global randomness).


gemini-code-assist bot left a comment

Code Review

This pull request introduces a valuable change by allowing seed and offset to be passed as tensors to sampling functions, which is crucial for CUDA graph compatibility. The overall approach is sound. However, I've identified a critical bug in the CUDA kernels where per-request seeds/offsets from a tensor are not correctly indexed, always using the first element. I've provided a suggestion to fix this in the CUDA kernels. Additionally, I've suggested an improvement in the Python validation logic to robustly handle both broadcast (size-1) and per-request (size-batch_size) tensors for seeds and offsets, which complements the kernel-side fix.

Comment on lines +709 to +710
uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val;
uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val;

critical

This implementation incorrectly uses seed_arr[0] and offset_arr[0] for all items in the batch. To support per-request seeds/offsets when a tensor is provided, you should use the block index bx to access the corresponding element for each batch item. This would be seed_arr[bx] and offset_arr[bx].

This same issue exists in all the updated sampling kernels in this file (SamplingFromProbKernel, TopKSamplingFromProbKernel, etc.) and should be fixed in all of them.

  uint64_t philox_seed = seed_arr ? seed_arr[bx] : seed_val;
  uint64_t philox_offset = offset_arr ? offset_arr[bx] : offset_val;
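The effect of the suggested fix can be seen in a small Python stand-in for the kernel's per-block resolution (hypothetical names; `bx` plays the role of the CUDA block index and a list stands in for the seed tensor):

```python
from typing import List, Optional


def resolve_seeds(seed_arr: Optional[List[int]], seed_val: int,
                  batch_size: int, per_request: bool) -> List[int]:
    # Stand-in for the resolution done once per block in the sampling kernels.
    seeds = []
    for bx in range(batch_size):
        if seed_arr is not None:
            # The buggy variant indexes element 0 for every block; the
            # fixed variant indexes by block (= batch row).
            seeds.append(seed_arr[bx] if per_request else seed_arr[0])
        else:
            seeds.append(seed_val)  # scalar fallback when no tensor is given
    return seeds


seed_tensor = [100, 200, 300]
# With seed_arr[0], all rows silently share the first seed.
assert resolve_seeds(seed_tensor, 0, 3, per_request=False) == [100, 100, 100]
# With seed_arr[bx], each row draws from its own seed.
assert resolve_seeds(seed_tensor, 0, 3, per_request=True) == [100, 200, 300]
# With no tensor, every row uses the scalar fallback.
assert resolve_seeds(None, 42, 3, per_request=True) == [42, 42, 42]
```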

Comment on lines +630 to +647
    if maybe_seed_arr is not None:
        if maybe_seed_arr.device != device:
            raise ValueError(f"seed tensor must be on {device}")
        if maybe_seed_arr.dtype not in [torch.int64, torch.uint64]:
            raise ValueError("seed tensor must be int64/uint64")
        if maybe_seed_arr.ndim != 1:
            raise ValueError("seed tensor must be 1D")
        if maybe_seed_arr.size(0) not in [1, batch_size]:
            raise ValueError(f"seed tensor length must be 1 or {batch_size}")
    if maybe_offset_arr is not None:
        if maybe_offset_arr.device != device:
            raise ValueError(f"offset tensor must be on {device}")
        if maybe_offset_arr.dtype not in [torch.int64, torch.uint64]:
            raise ValueError("offset tensor must be int64/uint64")
        if maybe_offset_arr.ndim != 1:
            raise ValueError("offset tensor must be 1D")
        if maybe_offset_arr.size(0) not in [1, batch_size]:
            raise ValueError(f"offset tensor length must be 1 or {batch_size}")

high

To robustly support both broadcast (size-1) and per-request (size-batch_size) tensors for seed and offset, you can expand the size-1 tensors to match the batch size. This creates a view with a stride of 0, so when the CUDA kernel accesses seed_arr[bx], it will correctly get the same value for all batch items. This makes the Python-side behavior consistent with the proposed change in the CUDA kernels to use [bx] for per-request access.

    if maybe_seed_arr is not None:
        if maybe_seed_arr.device != device:
            raise ValueError(f"seed tensor must be on {device}")
        if maybe_seed_arr.dtype not in [torch.int64, torch.uint64]:
            raise ValueError("seed tensor must be int64/uint64")
        if maybe_seed_arr.ndim != 1:
            raise ValueError("seed tensor must be 1D")
        if maybe_seed_arr.size(0) == 1:
            maybe_seed_arr = maybe_seed_arr.expand(batch_size)
        elif maybe_seed_arr.size(0) != batch_size:
            raise ValueError(f"seed tensor length must be 1 or {batch_size}")
    if maybe_offset_arr is not None:
        if maybe_offset_arr.device != device:
            raise ValueError(f"offset tensor must be on {device}")
        if maybe_offset_arr.dtype not in [torch.int64, torch.uint64]:
            raise ValueError("offset tensor must be int64/uint64")
        if maybe_offset_arr.ndim != 1:
            raise ValueError("offset tensor must be 1D")
        if maybe_offset_arr.size(0) == 1:
            maybe_offset_arr = maybe_offset_arr.expand(batch_size)
        elif maybe_offset_arr.size(0) != batch_size:
            raise ValueError(f"offset tensor length must be 1 or {batch_size}")
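Why the size-1 `expand` keeps the kernel-side `[bx]` indexing safe can be illustrated with a minimal stand-in for a zero-stride view (a simulation; `torch.Tensor.expand` creates exactly this kind of view, without copying data):

```python
class ZeroStrideView:
    """Stand-in for a size-1 tensor expanded to `length`: every index
    resolves to the same underlying element, as a stride-0 view would."""

    def __init__(self, value: int, length: int):
        self.value = value
        self.length = length

    def __getitem__(self, bx: int) -> int:
        if not 0 <= bx < self.length:
            raise IndexError(bx)
        return self.value  # stride 0: one element backs every index


# A broadcast seed: kernel-style per-row reads all see the same value,
# so seed_arr[bx] is valid for both size-1 and size-batch_size tensors.
seed_arr = ZeroStrideView(42, length=4)
assert [seed_arr[bx] for bx in range(4)] == [42, 42, 42, 42]
```

One caveat with the real `expand`: the result is non-contiguous, so the C++ side must either tolerate stride-0 views or the Python side should pass the stride along rather than assume dense storage.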

coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/sampling.py (2)

91-124: Update fake-op signatures to accept seed/offset.

The custom ops now take seed/offset, but the registered fake ops still use the old parameter list. In FakeTensor / torch.compile paths this will raise due to extra args. Please add optional seed/offset parameters (default None) to each _fake_* sampling op to keep schemas in sync.

🔧 Example fix (apply similarly to all _fake_* sampling ops)
 @register_fake_op("flashinfer::sampling_from_logits")
 def _fake_sampling_from_logits(
     logits: torch.Tensor,
     indices: Optional[torch.Tensor],
     deterministic: bool,
-    generator: Optional[torch.Generator],
+    generator: Optional[torch.Generator],
+    seed: Optional[Union[int, torch.Tensor]] = None,
+    offset: Optional[Union[int, torch.Tensor]] = None,
 ) -> torch.Tensor:
     batch_size = indices.size(0) if indices is not None else logits.size(0)
     out_dtype = indices.dtype if indices is not None else torch.int32
     return torch.empty(batch_size, dtype=out_dtype, device=logits.device)

Also applies to: 141-171, 188-227, 244-281, 298-336, 340-384, 503-552


710-746: Docstrings still document seed/offset as Optional[int].

The signatures now accept Union[int, torch.Tensor], but the parameter docs still say Optional[int]. Please update the docs (and include expected tensor shape/dtype) for each affected API to avoid user confusion.

✏️ Example doc update (apply across affected APIs)
-    seed: Optional[int]
-        seed value to use for the rng during the sampling operation.
-    offset: Optional[int]
-        offset value to use for the rng during the sampling operation.
+    seed: Optional[Union[int, torch.Tensor]]
+        Seed value or 1D int64/uint64 tensor (length 1 or batch_size).
+    offset: Optional[Union[int, torch.Tensor]]
+        Offset value or 1D int64/uint64 tensor (length 1 or batch_size).

Also applies to: 777-813, 850-896, 947-993, 1044-1091, 1137-1193, 1270-1327, 1611-1660

🤖 Fix all issues with AI agents
In `@csrc/sampling.cu`:
- Around line 25-45: validate_seed_offset_tensors currently checks
dtype/ndim/device but not tensor length, allowing zero-length or mismatched
batch-length tensors which can cause OOB reads; update
validate_seed_offset_tensors to also check that when maybe_seed_arr/
maybe_offset_arr is present their size(0) is either 1 or equals the expected
batch size (pass the expected batch size or the output tensor and compare
against output.size(0) when indices are supplied), and update call sites that
invoke validate_seed_offset_tensors to supply the output tensor or explicit
batch size so the function can enforce length == 1 || length ==
expected_batch_size.

In `@flashinfer/sampling.py`:
- Around line 591-649: The code uses torch.uint64 (e.g., in
_validate_and_convert_seed_offset) without guarding for PyTorch versions prior
to 2.3; either declare PyTorch>=2.3 in project docs/requirements or add runtime
guards: check torch.__version__ or use getattr(torch, "uint64", None) and only
include torch.uint64 in dtype checks when present, otherwise fall back to
validating against torch.int64 only; apply the same pattern where torch.uint64
is referenced (flashinfer/sampling.py::_validate_and_convert_seed_offset,
flashinfer/comm/dlpack_utils.py, flashinfer/jit/utils.py,
flashinfer/deep_gemm.py) and update validation error messages to reflect the
supported dtype set dynamically.

In `@include/flashinfer/sampling.cuh`:
- Around line 703-711: The kernel SamplingFromLogitsKernel (and the other
sampling kernels listed) currently always reads philox_seed/philox_offset from
seed_arr[0]/offset_arr[0], which ignores per-batch seeds; change the resolution
logic to use the batch index (e.g., philox_seed = seed_arr ? seed_arr[bx] :
seed_val and philox_offset = offset_arr ? offset_arr[bx] : offset_val) or, if
the intended semantics are per-row, use the appropriate row index (row_idx)
instead of 0; apply the same fix to all sampling kernels referenced (lines
around 750-759, 803-813, 926-936, 1043-1052, 1129-1140, 1788-1801) so
per-request seeds/offsets are honored.

Comment on lines +25 to +45
// Helper function to validate seed/offset tensors for sampling operations
inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
                                         const Optional<TensorView>& maybe_offset_arr,
                                         const TensorView& reference_tensor) {
  if (maybe_seed_arr.has_value()) {
    CHECK_INPUT(maybe_seed_arr.value());
    CHECK_DIM(1, maybe_seed_arr.value());
    TVM_FFI_ICHECK(maybe_seed_arr.value().dtype() == dl_int64 ||
                   maybe_seed_arr.value().dtype() == dl_uint64)
        << "seed tensor must be int64 or uint64";
    CHECK_DEVICE(maybe_seed_arr.value(), reference_tensor);
  }
  if (maybe_offset_arr.has_value()) {
    CHECK_INPUT(maybe_offset_arr.value());
    CHECK_DIM(1, maybe_offset_arr.value());
    TVM_FFI_ICHECK(maybe_offset_arr.value().dtype() == dl_int64 ||
                   maybe_offset_arr.value().dtype() == dl_uint64)
        << "offset tensor must be int64 or uint64";
    CHECK_DEVICE(maybe_offset_arr.value(), reference_tensor);
  }
}

⚠️ Potential issue | 🟠 Major

Validate seed/offset tensor lengths to prevent OOB reads.

The helper checks dtype/device/ndim but not length. A zero-length tensor or mismatched batch length will pass and kernels read element 0 (or bx), causing OOB. Please enforce length == 1 or == output batch size (use output.size(0) when indices is supplied).

🔧 Suggested validation (and call-site update)
-inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
-                                         const Optional<TensorView>& maybe_offset_arr,
-                                         const TensorView& reference_tensor) {
+inline void validate_seed_offset_tensors(const Optional<TensorView>& maybe_seed_arr,
+                                         const Optional<TensorView>& maybe_offset_arr,
+                                         const TensorView& reference_tensor,
+                                         int64_t batch_size) {
   if (maybe_seed_arr.has_value()) {
     CHECK_INPUT(maybe_seed_arr.value());
     CHECK_DIM(1, maybe_seed_arr.value());
+    TVM_FFI_ICHECK(maybe_seed_arr.value().size(0) == 1 ||
+                   maybe_seed_arr.value().size(0) == batch_size)
+        << "seed tensor length must be 1 or batch_size";
     TVM_FFI_ICHECK(maybe_seed_arr.value().dtype() == dl_int64 ||
                    maybe_seed_arr.value().dtype() == dl_uint64)
         << "seed tensor must be int64 or uint64";
     CHECK_DEVICE(maybe_seed_arr.value(), reference_tensor);
   }
   if (maybe_offset_arr.has_value()) {
     CHECK_INPUT(maybe_offset_arr.value());
     CHECK_DIM(1, maybe_offset_arr.value());
+    TVM_FFI_ICHECK(maybe_offset_arr.value().size(0) == 1 ||
+                   maybe_offset_arr.value().size(0) == batch_size)
+        << "offset tensor length must be 1 or batch_size";
     TVM_FFI_ICHECK(maybe_offset_arr.value().dtype() == dl_int64 ||
                    maybe_offset_arr.value().dtype() == dl_uint64)
         << "offset tensor must be int64 or uint64";
     CHECK_DEVICE(maybe_offset_arr.value(), reference_tensor);
   }
 }
-  validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, logits);
+  validate_seed_offset_tensors(maybe_seed_arr, maybe_offset_arr, logits, output.size(0));
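The guard the reviewer proposes amounts to the following check (a Python stand-in for the suggested TVM_FFI_ICHECK; without it, a zero-length or wrong-length tensor passes validation and the kernel reads out of bounds):

```python
from typing import List, Optional


def validate_len(arr: Optional[List[int]], batch_size: int) -> None:
    # Stand-in for the length check: a seed/offset tensor, when present,
    # must have length 1 (broadcast) or batch_size (per-request).
    if arr is not None and len(arr) not in (1, batch_size):
        raise ValueError(f"seed/offset tensor length must be 1 or {batch_size}")


validate_len(None, 4)           # scalar path: nothing to check
validate_len([7], 4)            # broadcast: ok
validate_len([1, 2, 3, 4], 4)   # per-request: ok
for bad in ([], [1, 2]):        # zero-length or mismatched: rejected
    try:
        validate_len(bad, 4)
        raise AssertionError("expected ValueError")
    except ValueError:
        pass
```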

Comment on lines +591 to +649
def _validate_and_convert_seed_offset(
    seed: Union[int, torch.Tensor],
    offset: Union[int, torch.Tensor],
    device: torch.device,
    batch_size: int,
) -> Tuple[Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
    """Validate and convert seed/offset to tensor/scalar tuples for sampling kernels.

    Parameters
    ----------
    seed : Union[int, torch.Tensor]
        Seed value or tensor.
    offset : Union[int, torch.Tensor]
        Offset value or tensor.
    device : torch.device
        Expected device for tensor inputs.
    batch_size : int
        Expected batch size for tensor length validation.

    Returns
    -------
    Tuple[Optional[torch.Tensor], int, Optional[torch.Tensor], int]
        (maybe_seed_arr, seed_val, maybe_offset_arr, offset_val)

    Raises
    ------
    ValueError
        If seed and offset are not both tensors or both scalars, or if tensor
        properties (device, dtype, ndim, size) are invalid.
    """
    # Validate tensor/scalar consistency
    if isinstance(seed, torch.Tensor) != isinstance(offset, torch.Tensor):
        raise ValueError("seed and offset must both be tensors or both be scalars")

    # Convert to tensor/scalar tuple
    maybe_seed_arr, seed_val = _to_tensor_scalar_tuple(seed)
    maybe_offset_arr, offset_val = _to_tensor_scalar_tuple(offset)

    # Validate tensor properties
    if maybe_seed_arr is not None:
        if maybe_seed_arr.device != device:
            raise ValueError(f"seed tensor must be on {device}")
        if maybe_seed_arr.dtype not in [torch.int64, torch.uint64]:
            raise ValueError("seed tensor must be int64/uint64")
        if maybe_seed_arr.ndim != 1:
            raise ValueError("seed tensor must be 1D")
        if maybe_seed_arr.size(0) not in [1, batch_size]:
            raise ValueError(f"seed tensor length must be 1 or {batch_size}")
    if maybe_offset_arr is not None:
        if maybe_offset_arr.device != device:
            raise ValueError(f"offset tensor must be on {device}")
        if maybe_offset_arr.dtype not in [torch.int64, torch.uint64]:
            raise ValueError("offset tensor must be int64/uint64")
        if maybe_offset_arr.ndim != 1:
            raise ValueError("offset tensor must be 1D")
        if maybe_offset_arr.size(0) not in [1, batch_size]:
            raise ValueError(f"offset tensor length must be 1 or {batch_size}")

    return maybe_seed_arr, seed_val, maybe_offset_arr, offset_val

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

torch.uint64 PyTorch dtype availability version

💡 Result:

torch.uint64 (along with torch.uint16 and torch.uint32) became available in PyTorch 2.3 as an unsigned “shell dtype”—i.e., it exists as a dtype value, but has limited operator/backend support (mainly tensor creation and shape-only/view-like ops; many elementwise/math ops may be unsupported or only partially supported). [1], [2]

So:

  • Availability: torch.uint64 is available starting PyTorch 2.3. [1]
  • Status: It’s a shell dtype with limited eager-mode support, primarily to aid torch.compile/interoperability rather than full eager kernel coverage. [2]

Sources: [1] GitHub issue discussing the PyTorch 2.3 introduction, [2] PyTorch “Tensor Attributes” docs listing torch.uint64 as a shell dtype.


🏁 Scripts executed: repository checks across pyproject.toml, requirements.txt, the CI workflows, README, and CHANGELOG, plus rg searches for torch.uint64 and version guards, found no minimum PyTorch version pin and confirmed that torch.uint64 is referenced unguarded in flashinfer/sampling.py, flashinfer/comm/dlpack_utils.py, flashinfer/jit/utils.py, and flashinfer/deep_gemm.py.


torch.uint64 is only available in PyTorch 2.3+; validate minimum version requirements.

The code references torch.uint64, which became available only in PyTorch 2.3 as a shell dtype with limited support. However, this is used unguarded not only in _validate_and_convert_seed_offset() but also throughout the codebase (e.g., flashinfer/comm/dlpack_utils.py, flashinfer/jit/utils.py, flashinfer/deep_gemm.py).

Clarify the minimum supported PyTorch version. If FlashInfer requires PyTorch 2.3+, document this constraint. If older versions must be supported, add guards consistently across all modules that use torch.uint64.
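One consistent guard, sketched here as a hypothetical helper (`allowed_seed_dtypes` is not in the codebase), would derive the accepted dtype set from whatever the installed torch build actually provides, so validation and its error messages stay correct on pre-2.3 builds:

```python
from types import SimpleNamespace


def allowed_seed_dtypes(torch_module):
    """Return the dtypes accepted for seed/offset tensors.

    torch.uint64 only exists on PyTorch >= 2.3, so include it
    conditionally; older builds fall back to int64 alone.
    """
    uint64 = getattr(torch_module, "uint64", None)
    return tuple(d for d in (torch_module.int64, uint64) if d is not None)


# Stand-ins for an old and a new torch build (illustration only):
pre_23 = SimpleNamespace(int64="int64")
post_23 = SimpleNamespace(int64="int64", uint64="uint64")
```

Validation error messages can then be built from this tuple so they always reflect the dtypes the running build supports.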

🧰 Tools
🪛 Ruff (0.14.14)

623, 632, 634, 636, 638, 641, 643, 645, 647: Avoid specifying long messages outside the exception class (TRY003)

🤖 Prompt for AI Agents
In `@flashinfer/sampling.py` around lines 591 - 649, The code uses torch.uint64
(e.g., in _validate_and_convert_seed_offset) without guarding for PyTorch
versions prior to 2.3; either declare PyTorch>=2.3 in project docs/requirements
or add runtime guards: check torch.__version__ or use getattr(torch, "uint64",
None) and only include torch.uint64 in dtype checks when present, otherwise fall
back to validating against torch.int64 only; apply the same pattern where
torch.uint64 is referenced
(flashinfer/sampling.py::_validate_and_convert_seed_offset,
flashinfer/comm/dlpack_utils.py, flashinfer/jit/utils.py,
flashinfer/deep_gemm.py) and update validation error messages to reflect the
supported dtype set dynamically.

Comment on lines 703 to +711
 __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType* indices, uint32_t d,
-                                         uint64_t philox_seed, uint64_t philox_offset) {
+                                         uint64_t* seed_arr, uint64_t seed_val,
+                                         uint64_t* offset_arr, uint64_t offset_val) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+
+  // Resolve seed/offset from tensor or scalar
+  uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val;
+  uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val;

Contributor

⚠️ Potential issue | 🟠 Major

Per-batch seed/offset tensors are ignored (always using element 0).

All kernels resolve philox_seed/offset using seed_arr[0] / offset_arr[0]. Since higher-level validation allows length == batch_size, per-request seeds are silently ignored. Either (a) restrict validation to length 1 only, or (b) index by bx (or row_idx if that’s the intended semantics).

🔧 Example fix (apply to all sampling kernels)
-  uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val;
-  uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val;
+  uint64_t philox_seed = seed_arr ? seed_arr[bx] : seed_val;
+  uint64_t philox_offset = offset_arr ? offset_arr[bx] : offset_val;

Also applies to: 750-759, 803-813, 926-936, 1043-1052, 1129-1140, 1788-1801

🤖 Prompt for AI Agents
In `@include/flashinfer/sampling.cuh` around lines 703 - 711, The kernel
SamplingFromLogitsKernel (and the other sampling kernels listed) currently
always reads philox_seed/philox_offset from seed_arr[0]/offset_arr[0], which
ignores per-batch seeds; change the resolution logic to use the batch index
(e.g., philox_seed = seed_arr ? seed_arr[bx] : seed_val and philox_offset =
offset_arr ? offset_arr[bx] : offset_val) or, if the intended semantics are
per-row, use the appropriate row index (row_idx) instead of 0; apply the same
fix to all sampling kernels referenced (lines around 750-759, 803-813, 926-936,
1043-1052, 1129-1140, 1788-1801) so per-request seeds/offsets are honored.

@yzh119
Collaborator

yzh119 commented Feb 2, 2026

@flashinfer-bot run

@yzh119
Collaborator

yzh119 commented Feb 2, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !289 has been created, and the CI pipeline #43126072 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #43126072: canceled

@yzh119 yzh119 added the run-ci label Feb 3, 2026
@IzzyPutterman
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@IzzyPutterman is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@yongwww
Member

yongwww commented Feb 3, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !289 has been created, and the CI pipeline #43221767 is currently running. I'll report back once the pipeline job completes.


// Resolve seed/offset from tensor or scalar
uint64_t philox_seed = seed_arr ? seed_arr[0] : seed_val;
uint64_t philox_offset = offset_arr ? offset_arr[0] : offset_val;
Collaborator

Some blocks need to advance the offset value (as in

def get_seed_and_offset(
    increment: int,
    generator: Optional[torch.Generator] = None,
    device: Optional[torch.device] = None,
) -> Tuple[int, int]:
    if generator is None:
        generator = get_default_generators(device)
    # add mutex if multi-threading needed
    state = generator.get_state()
    seed, offset = state.view(torch.int64)
    offset += (increment + 3) // 4 * 4
    generator.set_state(
        torch.tensor(
            [seed, offset], dtype=torch.int64, device=torch.device("cpu")
        ).view(torch.uint8)
    )
    return int(seed), int(offset)

), otherwise we will keep using the same offset across multiple runs. Or we can advance the offset tensor after calling this operator.

Contributor Author

Hmmm, I guess this PR only allows for CUDA graph compliance by passing in the seed and offset. I would argue that it's the user's responsibility to update the seed and offset as they use this function; otherwise there could be a lack of visibility into how the RNG is actually seeded.
No idea if this is a reasonable take, but I'm curious about your thoughts (and whether I missed something super obvious lol).

Contributor Author

Spoke offline, confirmed this behavior is ok with the doc strings added.
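The resulting caller-side contract (allocate the one-element seed/offset buffers once, capture, then bump the offset in place between replays) can be sketched with plain Python lists standing in for CUDA tensors; `ToySampler` below is purely illustrative:

```python
class ToySampler:
    """Stand-in for a graph-captured sampling op: it reads seed/offset
    through fixed buffers whose addresses were captured once."""

    def __init__(self, seed_buf, offset_buf):
        self.seed_buf = seed_buf
        self.offset_buf = offset_buf

    def replay(self):
        return self.seed_buf[0], self.offset_buf[0]


seed_buf, offset_buf = [1234], [0]
sampler = ToySampler(seed_buf, offset_buf)

draws = []
for _ in range(3):
    draws.append(sampler.replay())
    offset_buf[0] += 4  # caller advances the offset in place between replays
```

Because the update is in place, each replay observes a fresh offset without re-capturing, which is why advancement can be left to the caller.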

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #43221767: canceled

@IzzyPutterman IzzyPutterman force-pushed the iputterman/sampling-cuda-graphs branch from d958a6c to 8248582 Compare February 6, 2026 00:48
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/sampling.py (2)

106-113: ⚠️ Potential issue | 🟡 Minor

or condition silently discards user-provided tensor when only one of seed/offset is given.

if seed is None or offset is None replaces both values with generator-derived scalars even when the user explicitly provided a tensor for one of them. For example, sampling_from_logits(logits, seed=my_cuda_tensor) silently drops my_cuda_tensor. This is especially surprising for the new tensor path since creating a CUDA tensor indicates deliberate intent for CUDA graph capture.

Consider using and + an explicit error for the partial case. This same pattern repeats in all internal functions (Lines 154, 207, 261, 317, 362, 530).

Suggested fix
-        if seed is None or offset is None:
+        if seed is None and offset is None:
             seed, offset = get_seed_and_offset(
                 batch_size * logits.size(1), generator, device
             )
+        elif seed is None or offset is None:
+            raise ValueError("Both seed and offset must be provided, or neither.")
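The suggested semantics (both, neither, or an error) can be isolated in a small helper; `resolve_seed_offset` here is a sketch, not the PR's actual code:

```python
def resolve_seed_offset(seed, offset, fallback):
    """Use the pair when both are given, derive both from the generator
    when neither is, and refuse the ambiguous partial case instead of
    silently discarding a user-provided tensor."""
    if seed is None and offset is None:
        return fallback()
    if seed is None or offset is None:
        raise ValueError("Both seed and offset must be provided, or neither.")
    return seed, offset
```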

127-136: ⚠️ Potential issue | 🟡 Minor

Fake op signatures currently mismatch the real ops; prepare before re-enabling torch.library.

The real ops (e.g., sampling_from_logits) now accept seed and offset as optional parameters, but the corresponding fake ops (_fake_sampling_from_logits, _fake_sampling_from_probs, _fake_top_p_sampling_from_probs, etc.) do not include these parameters.

This is not currently breaking torch.compile because register_custom_op and register_fake_op are disabled in flashinfer/utils.py (the torch.library integration is commented out). However, once this code is re-enabled, the signature mismatch will prevent proper fake op schema validation. Add seed and offset parameters to all fake ops in preparation for re-enabling torch.library integration.
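One cheap way to keep the pairs from drifting until torch.library is re-enabled is a signature parity check in the test suite; the stubs and parameter names below are illustrative, not the real ops:

```python
import inspect


def sampling_from_logits(logits, indices=None, deterministic=True,
                         generator=None, seed=None, offset=None):
    """Stub standing in for the real op."""


def _fake_sampling_from_logits(logits, indices=None, deterministic=True,
                               generator=None, seed=None, offset=None):
    """Fake op: must mirror the real op's schema exactly."""


real = inspect.signature(sampling_from_logits)
fake = inspect.signature(_fake_sampling_from_logits)
# Schema validation requires the parameter lists to match name-for-name.
matches = list(real.parameters) == list(fake.parameters)
```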

@dierksen
Contributor

/bot run

1 similar comment
@nvmbreughe
Contributor

/bot run

@flashinfer-bot
Collaborator

GitLab MR !289 has been updated with latest changes, and the CI pipeline #43702957 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

GitLab MR !289 has been created, and the CI pipeline #43702970 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43702970: 8/20 passed

@IzzyPutterman IzzyPutterman force-pushed the iputterman/sampling-cuda-graphs branch from 8248582 to 4ae57e6 Compare February 12, 2026 17:15
Collaborator

@yzh119 yzh119 left a comment

It should be ready to merge once CI passed, thanks @IzzyPutterman for working on the fix!

@yzh119 yzh119 merged commit c5b8a2e into flashinfer-ai:main Feb 13, 2026
30 of 37 checks passed