
Conversation

@naoyam (Collaborator) commented Nov 18, 2025

No description provided.

github-actions bot commented Nov 18, 2025

Review updated until commit 70e0885

Description

  • Proposes RaggedIterDomain, a new IterDomain subclass for representing ragged/jagged dimensions in nvFuser

  • Enables efficient compilation of PyTorch nested tensors without padding overhead

  • Details architecture, implementation phases, and system integration for ragged tensor support

  • Covers core abstractions including nested domains, offset computation, transformations, and code generation

Changes walkthrough

Relevant files
Documentation
doc/dev/ragged_iter_domain_design_doc.md: RaggedIterDomain RFC Design Document

  • New comprehensive design document (533 lines) proposing
    RaggedIterDomain implementation
  • Defines architecture for ragged dimensions with variable extents
    across batch components
  • Details PyTorch nested tensor semantics, nvFuser integration, and code
    generation strategy
  • Outlines implementation phases from core infrastructure to full
    production support
  • +533/-0 

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Implementation Complexity Underestimation

    The document outlines 4 implementation phases but may underestimate the complexity of IdModel integration (Phase 2). Ragged dimensions will require significant changes to the indexing system, ValGraph handling, and expression types. Consider adding more detailed technical specifications for IdModel extensions and potential risks.

```
### Phase 2: IdModel Integration
- Extend IdModel ValGraph to handle RaggedIterDomain
- Modify TensorIndexer for offset-based indexing
- Add new expression types for ragged operations
- Predicate generation for ragged bounds
- **Goal**: Can compile and execute simple ragged operations
```
    Performance Validation Strategy

    While the document mentions expert parallelism as the goal, it lacks specific performance benchmarks or comparison targets (e.g., roofline model, CUTLASS performance). A clear performance validation strategy with quantitative goals should be established before implementation begins.

    **Scope Note**: This proposal represents a **minimalistic initial version** containing only the capabilities that are absolutely necessary for expert parallelism. The exact requirements for expert parallelism are still being clarified with Jingyue. As those requirements become clear, additional capabilities (such as flatten operations, specific IdModel integrations, or lowering support) may be added to this design. The features described here should be considered the bare minimum starting point.
    
    
    Memory Management Complexity

    The offset tensor management approach (Section 6.7) introduces complexity around kernel launch dependencies and device memory management. The proposed preseg pass solution needs more detailed technical specification to ensure it handles all edge cases correctly, especially for dynamic nested tensor creation scenarios.

    **Key Observation**: The nested domain extents are computed **inside the kernel** and are not known at kernel launch time.
    
    **Implication**: We cannot bundle extent/offset information with the nested tensor itself.
    
This problem can be addressed by managing the offsets as a separate tensor that can be computed dynamically on the GPU and passed between kernels. That effectively means a logical nested tensor consists of two Vals: one tensor for the nested data itself and another for the offsets. More concretely, here's a fusion that creates a nested tensor with `viewAsNested` as an output:
    
    ```cpp
    // User-defined Fusion
    Fusion fusion;
    FusionGuard fg(&fusion);
    
    // User provides data and offsets as separate inputs
    auto tv_data = TensorViewBuilder()
        .ndims(2)
        .shape({-1, 512})  // [total_tokens, hidden]
        .dtype(DataType::Float)
        .build();
    fusion.addInput(tv_data);
    
    auto tv_offsets = TensorViewBuilder()
        .ndims(1)
        .shape({9})  // [num_experts + 1]
        .dtype(DataType::Int)
        .build();
    fusion.addInput(tv_offsets);
    
    // User explicitly creates nested tensor view
    auto tv_nested = viewAsNested(tv_data, tv_offsets, /*ragged_dim=*/0);
    // tv_nested has shape [batch=8, ragged_tokens, hidden=512]
    
    // Operations on the nested tensor
    auto tv_result = some_operation(tv_nested);
    
fusion.addOutput(tv_result);
```

    The output tensor, tv_result, is a nested tensor. The extents of the nested domains are given as a fusion input, but in general, they are not known until the fusion is executed. Thus, if the nested tensor struct were defined like:

```cpp
template <typename DT, int rank>
struct NestedTensor {
  DT* ptr;
  int64_t extents[rank];
  int64_t nested_domain_extents[ragged_dimension_rank];
};
```

The value of `nested_domain_extents` is not available until the kernel completes, which would block the launch of the subsequent kernel.
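To illustrate the blocking, here is a hypothetical host-side sketch (none of these names come from nvFuser): if the extents were fields of the kernel-argument struct, the host would have to read them back from device memory before it could build the next launch's arguments.

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical sketch only; kRaggedRank and the surrounding code are illustrative.
constexpr int kRaggedRank = 8;

void launch_next_kernel(const int64_t* extents_dev, cudaStream_t stream) {
  int64_t extents_host[kRaggedRank];
  // The extents were produced in device memory by the previous kernel...
  cudaMemcpyAsync(extents_host, extents_dev, sizeof(extents_host),
                  cudaMemcpyDeviceToHost, stream);
  // ...so the host must wait for that kernel to finish before it can
  // populate NestedTensor::nested_domain_extents for the next launch.
  cudaStreamSynchronize(stream);
  // Only now can the next kernel be launched: the two kernels end up
  // serialized through the host.
}
```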

    Instead, we would like the fusion to be defined as follows:

```cpp
fusion.addInput(tv_data);      // Original data input (unchanged)
fusion.addInput(tv_offsets);   // Original offset input (unchanged)

auto tv_nested = viewAsNested(tv_data, tv_offsets, /*ragged_dim=*/0);
auto tv_result = some_operation(tv_nested);

auto tv_result_offsets = /* extract/compute offset part of tv_result */;

fusion.addOutput(tv_result);           // Data tensor output
fusion.addOutput(tv_result_offsets);   // Offset tensor output (injected)
```

Here, tv_result would use the same Tensor struct as a normal tensor. The offset tensor would be a 1D tensor whose `ptr` refers to the offset vector in device memory. In this case, nothing blocks the launch of the subsequent kernel, since the offsets remain in device memory.
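As a sketch of what the two Vals could look like at the kernel-argument level (the field layout mirrors the struct style above; this is an assumption, not the actual nvFuser Tensor definition):

```cpp
#include <cstdint>

// Assumed shape of the ordinary Tensor argument struct.
template <typename DT, int rank>
struct Tensor {
  DT* ptr;
  int64_t extents[rank];
};

// A logical nested tensor is then passed as two ordinary tensors:
//   Tensor<float, 2>   data;     // [total_tokens, hidden]
//   Tensor<int64_t, 1> offsets;  // [num_components + 1]; stays in device
//                                // memory, so no host-side read is needed
//                                // between kernels.
```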

    Since it is an implementation detail, the offset tensor should be hidden behind the nested tensor in the user-facing Fusion definition. When a user uses viewAsNested to create a nested tensor, it should still create a single nested tensor Val, as illustrated in the first case above. The translation to the second pattern should be done automatically, e.g., by a new preseg pass.
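A minimal sketch of what such a preseg pass could do, under the assumption that helpers for detecting ragged domains and materializing their offsets exist (neither does today):

```cpp
// Hypothetical translation pass (nvFuser headers assumed).
// hasRaggedDomain and extractRaggedOffsets are assumed helpers.
void injectOffsetOutputs(Fusion* fusion) {
  // Copy the output list first, since addOutput() below mutates it.
  std::vector<Val*> outputs = fusion->outputs();
  for (Val* out : outputs) {
    auto* tv = dynamic_cast<TensorView*>(out);
    if (tv == nullptr || !hasRaggedDomain(tv)) {
      continue;
    }
    // Materialize the ragged dimension's offsets as a plain 1D TensorView
    // and add it as an extra output so it stays in device memory.
    TensorView* offsets = extractRaggedOffsets(tv);
    fusion->addOutput(offsets);
  }
}
```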

    

naoyam changed the title from "RaggedIterDomain for nested tensors" to "[RFC] RaggedIterDomain for nested tensors" Nov 18, 2025
    naoyam and others added 9 commits November 17, 2025 18:44
    - Add select operation section explaining how to extract individual components
    - Fix section numbering (removed gaps in section sequence)
    - Fix broken link syntax for PyTorch documentation
    - Clarify flatten operation as TBD based on expert parallelism requirements
    - Remove references to non-existent appendices
    - Standardize terminology to "nested domains" throughout
    - Clarify DID as "distributed device (multi-GPU) parallelization"
    - Remove metadata header (Status, Author, Date fields)
    - Update PyTorch description to "one dimension" instead of "one or more"
    
    🤖 Generated with [Claude Code](https://claude.com/claude-code)
    
    Co-Authored-By: Claude <[email protected]>
naoyam marked this pull request as ready for review November 20, 2025 01:34
naoyam requested a review from wujingyue November 20, 2025 01:34
greptile-apps bot (Contributor) commented Nov 20, 2025

    Greptile Summary

    • Introduces RaggedIterDomain as a new IterDomain subclass to support PyTorch nested tensors with variable-length dimensions for expert parallelism use cases
    • Proposes offset-based indexing and contiguous memory layout with separate offset tensor management for dynamic kernel computation scenarios

    Confidence Score: 4/5

    • This RFC is safe to approve as a design document with minor documentation improvements needed
• The design is well-structured and comprehensive, with clear architecture, API specifications, and implementation phases. Minor issues include a missing `override` keyword in one method signature and inconsistent code indentation in examples. As an RFC document rather than implementation code, these are documentation quality issues that don't affect functionality.
    • No files require special attention

    Important Files Changed

| Filename | Overview |
| --- | --- |
| doc/dev/ragged_iter_domain_design_doc.md | New RFC document proposing RaggedIterDomain for nested tensors in nvFuser, covering architecture, API design, memory layout, and implementation phases |

    Sequence Diagram

```mermaid
sequenceDiagram
        participant User
        participant Fusion
        participant IrBuilder
        participant RaggedIterDomain
        participant TensorView
        participant Indexer
        participant CodeGen
        
        User->>Fusion: "Create nested tensor fusion"
        Fusion->>IrBuilder: "create IterDomains for components"
        IrBuilder-->>Fusion: "nested_domains[i]"
        Fusion->>IrBuilder: "create<RaggedIterDomain>(nested_domains)"
        IrBuilder->>RaggedIterDomain: "RaggedIterDomain(nested_domains)"
        RaggedIterDomain->>RaggedIterDomain: "validate uniform properties"
        RaggedIterDomain->>RaggedIterDomain: "compute offsets (cumulative sum)"
        RaggedIterDomain-->>IrBuilder: "ragged dimension"
        IrBuilder-->>Fusion: "ragged dimension"
        Fusion->>TensorView: "viewAsNested(data, offsets, ragged_dim)"
        TensorView->>TensorView: "create TensorDomain with RaggedIterDomain"
        TensorView-->>Fusion: "nested tensor view"
        Fusion->>Fusion: "apply transformations (split/merge)"
        Fusion->>Fusion: "parallelize ragged dimension"
        Fusion->>Indexer: "compute indices for ragged iteration"
        Indexer->>RaggedIterDomain: "get offsets for components"
        RaggedIterDomain-->>Indexer: "offset[component_idx]"
        Indexer->>Indexer: "compute global_index = offset + local_index"
        Indexer-->>Fusion: "ragged indices"
        Fusion->>CodeGen: "generate CUDA code"
        CodeGen->>CodeGen: "emit nested loop with offset-based indexing"
        CodeGen-->>User: "executable kernel"
    
```
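To make the diagram's last step concrete, here is a hand-written sketch of the kind of offset-based loop it alludes to (illustrative only, not actual nvFuser codegen output):

```cpp
// One block per ragged component; each component's extent is recovered
// from adjacent offsets, and global_index = offset + local_index.
__global__ void ragged_elementwise(
    const float* in, float* out, const int64_t* offsets, int64_t hidden) {
  int64_t b = blockIdx.x;                   // component (e.g., expert) index
  int64_t begin = offsets[b];
  int64_t extent = offsets[b + 1] - begin;  // variable per component
  for (int64_t i = threadIdx.x; i < extent * hidden; i += blockDim.x) {
    int64_t idx = begin * hidden + i;       // offset-based global index
    out[idx] = in[idx] * 2.0f;
  }
}
// Launched as: ragged_elementwise<<<num_components, 256>>>(in, out, offsets, hidden);
```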

greptile-apps bot left a comment


    1 file reviewed, 2 comments


    Comment on lines +146 to +147
```cpp
// This overrides IterDomain::parallelize and calls nested_domains[i]->parallelize(pt) for all nested domains
void parallelize(ParallelType pt);
```

syntax: The `parallelize()` method signature should be `void parallelize(ParallelType pt) override;` to properly override the base class method

    Comment on lines +415 to +419
```cpp
struct NestedTensor {
  DT* ptr;
  int64_t extents[rank];
  int64_t nested_domain_extents[ragged_dimension_rank];
};
```
    style: Inconsistent indentation - mix of tabs and spaces in the struct definition. Use consistent indentation (spaces preferred in C++ code)

    #### RaggedIterDomain Class

    ```cpp
class RaggedIterDomain : public IterDomain {
```
    Note to myself: there's no way to find the offsets tensor from a RaggedIterDomain. Could that be a problem?

    **Split**: Split a regular IterDomain and merge with a RaggedIterDomain to create a new ragged structure.

    ```cpp
auto split_result = IterDomain::split(ragged, 2);
```
After EP dispatch, the activation tensor has shape [s,h], and s is divided into g groups of different sizes. Then, this g is split evenly across d GPUs, and each GPU gets exactly g/d groups (despite the non-uniform extents).

    Is ^^^ how split is defined here?

    Fundamentally, the extent of a RaggedIterDomain is 2D. It's unclear to me whether a split is applied to the number of groups or the extent of each group.
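To make the two readings concrete (illustrative numbers only, not from the proposal): with g = 4 components of extents {3, 2, 4, 3}, i.e. offsets [0, 3, 5, 9, 12], and d = 2, splitting the component axis gives each GPU two components with unequal extents, whereas splitting each component's extent by a factor tiles every variable-length segment:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t offsets[] = {0, 3, 5, 9, 12};  // illustrative numbers only
  const int64_t g = 4, d = 2, factor = 2;

  // Reading A: split the component axis g across d GPUs.
  for (int64_t r = 0; r < d; ++r) {
    for (int64_t b = r * (g / d); b < (r + 1) * (g / d); ++b) {
      printf("GPU %lld owns component %lld with extent %lld\n", (long long)r,
             (long long)b, (long long)(offsets[b + 1] - offsets[b]));
    }
  }

  // Reading B: split each component's variable extent by `factor`.
  for (int64_t b = 0; b < g; ++b) {
    int64_t extent = offsets[b + 1] - offsets[b];
    printf("component %lld: %lld tiles of size <= %lld\n", (long long)b,
           (long long)((extent + factor - 1) / factor), (long long)factor);
  }
  return 0;
}
```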

    My mental model for this has been

```
 i{s}  tv{offsets}
     \  /
  [RaggedSplit] # a new IterDomain op
     /  \
  i{g}  i{t} # tokens_per_expert
    |
   [Split]
  /       \
iDIDx{d}  i{g/d}
```
    

    without having to introduce a subclass of IterDomain.

    I haven't yet figured out how I can represent this using RaggedIterDomain.
