[RFC] RaggedIterDomain for nested tensors #5550
Review updated until commit 70e0885

Relevant files: Documentation
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Implementation Complexity Underestimation
Description

- Add select operation section explaining how to extract individual components
- Fix section numbering (removed gaps in section sequence)
- Fix broken link syntax for PyTorch documentation
- Clarify flatten operation as TBD based on expert parallelism requirements
- Remove references to non-existent appendices
- Standardize terminology to "nested domains" throughout
- Clarify DID as "distributed device (multi-GPU) parallelization"
- Remove metadata header (Status, Author, Date fields)
- Update PyTorch description to "one dimension" instead of "one or more"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Greptile Summary
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant Fusion
participant IrBuilder
participant RaggedIterDomain
participant TensorView
participant Indexer
participant CodeGen
User->>Fusion: "Create nested tensor fusion"
Fusion->>IrBuilder: "create IterDomains for components"
IrBuilder-->>Fusion: "nested_domains[i]"
Fusion->>IrBuilder: "create<RaggedIterDomain>(nested_domains)"
IrBuilder->>RaggedIterDomain: "RaggedIterDomain(nested_domains)"
RaggedIterDomain->>RaggedIterDomain: "validate uniform properties"
RaggedIterDomain->>RaggedIterDomain: "compute offsets (cumulative sum)"
RaggedIterDomain-->>IrBuilder: "ragged dimension"
IrBuilder-->>Fusion: "ragged dimension"
Fusion->>TensorView: "viewAsNested(data, offsets, ragged_dim)"
TensorView->>TensorView: "create TensorDomain with RaggedIterDomain"
TensorView-->>Fusion: "nested tensor view"
Fusion->>Fusion: "apply transformations (split/merge)"
Fusion->>Fusion: "parallelize ragged dimension"
Fusion->>Indexer: "compute indices for ragged iteration"
Indexer->>RaggedIterDomain: "get offsets for components"
RaggedIterDomain-->>Indexer: "offset[component_idx]"
Indexer->>Indexer: "compute global_index = offset + local_index"
Indexer-->>Fusion: "ragged indices"
Fusion->>CodeGen: "generate CUDA code"
CodeGen->>CodeGen: "emit nested loop with offset-based indexing"
CodeGen-->>User: "executable kernel"
```
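The offset bookkeeping described in the diagram ("compute offsets (cumulative sum)" and "global_index = offset + local_index") can be illustrated with a small standalone sketch; the component extents below are made-up values, and the nvFuser pieces (viewAsNested, RaggedIterDomain) are intentionally not reproduced here:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Per-component extents of the ragged dimension (made-up example values).
  std::vector<int64_t> extents = {3, 5, 2};

  // "compute offsets (cumulative sum)": offsets[i] is where component i
  // starts in the packed data; the final entry is the total extent.
  std::vector<int64_t> offsets(extents.size() + 1, 0);
  for (size_t i = 0; i < extents.size(); ++i) {
    offsets[i + 1] = offsets[i] + extents[i];
  }

  // "compute global_index = offset + local_index": the nested loop the
  // generated kernel would emit, done here on the host for illustration.
  for (size_t c = 0; c < extents.size(); ++c) {
    for (int64_t local = 0; local < extents[c]; ++local) {
      int64_t global_index = offsets[c] + local;
      std::printf("component %zu, local %lld -> global %lld\n",
                  c, static_cast<long long>(local), static_cast<long long>(global_index));
    }
  }
  return 0;
}
```

In the generated kernel, per the diagram, the offsets would be read from the offsets tensor rather than computed on the host.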
1 file reviewed, 2 comments
```cpp
// This overrides IterDomain::parallelize and calls nested_domains[i]->parallelize(pt) for all nested domains
void parallelize(ParallelType pt);
```
syntax: The parallelize() method signature should be `void parallelize(ParallelType pt) override;` to properly override the base class method.
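A sketch of the suggested signature in context, assuming the component domains live in a nested_domains_ member (name assumed) and that IterDomain::parallelize is virtual; whether the base implementation should also be invoked is left as an open question here:

```cpp
#include <vector>

class RaggedIterDomain : public IterDomain {
 public:
  // `override` lets the compiler verify this actually overrides a virtual
  // IterDomain::parallelize(ParallelType), as the review comment suggests.
  void parallelize(ParallelType pt) override {
    // Per the RFC snippet: forward the parallel type to every nested domain.
    for (IterDomain* id : nested_domains_) {
      id->parallelize(pt);
    }
    IterDomain::parallelize(pt);  // assumption: the ragged domain itself is parallelized too
  }

 private:
  std::vector<IterDomain*> nested_domains_;  // assumed member holding the component domains
};
```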
```cpp
struct NestedTensor {
  DT* ptr;
  int64_t extents[rank];
  int64_t nested_domain_extents[ragged_dimension_rank];
};
```
style: Inconsistent indentation - mix of tabs and spaces in the struct definition. Use consistent indentation (spaces preferred in C++ code)
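For reference, a consistently indented (spaces-only) version of the struct that compiles standalone; the element type and the two rank constants are placeholder assumptions, since the RFC leaves them to code generation:

```cpp
#include <cstdint>

// Placeholder values for illustration only; in practice these would be
// determined per kernel by code generation.
using DT = float;
constexpr int64_t rank = 2;
constexpr int64_t ragged_dimension_rank = 1;

struct NestedTensor {
  DT* ptr;                                               // pointer to the packed data
  int64_t extents[rank];                                 // extents of the tensor dimensions
  int64_t nested_domain_extents[ragged_dimension_rank];  // extents describing the nested (ragged) domains
};
```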
#### RaggedIterDomain Class

```cpp
class RaggedIterDomain : public IterDomain {
```
Note to myself: there's no way to find the offsets tensor from a RaggedIterDomain. Could that be a problem?
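For discussion, here is one way the class body could make that lookup possible; the members nested_domains_ and offsets_tv_ and the accessors are assumptions for illustration, not part of the RFC:

```cpp
#include <vector>

class RaggedIterDomain : public IterDomain {
 public:
  // Component IterDomains, one per nested component.
  const std::vector<IterDomain*>& nestedDomains() const {
    return nested_domains_;
  }

  // Hypothetical accessor addressing the note above: if the ragged domain
  // kept a handle to its offsets tensor, indexing could recover it from the
  // domain itself instead of having it threaded through separately.
  TensorView* offsets() const {
    return offsets_tv_;
  }

 private:
  std::vector<IterDomain*> nested_domains_;
  TensorView* offsets_tv_ = nullptr;  // assumed member; see note above
};
```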
**Split**: Split a regular IterDomain and merge with a RaggedIterDomain to create a new ragged structure.

```cpp
auto split_result = IterDomain::split(ragged, 2);
```
After EP dispatch, the activation tensor has shape [s, h] and s is divided into g groups of different sizes. Then g is split evenly across d GPUs and each GPU gets exactly g/d groups (despite the non-uniform extents).

Is that how split is defined here?

Fundamentally, the extent of a RaggedIterDomain is 2D. It's unclear to me whether a split is applied to the number of groups or to the extent of each group.
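To make the ambiguity concrete, a small standalone sketch with made-up extents contrasting the two possible readings of a factor-2 split on a ragged dimension (neither is claimed to be what the RFC intends):

```cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> extents = {3, 5, 2, 6};  // per-group extents (made-up), g = 4
  int factor = 2;                           // split factor, e.g. d GPUs
  int g = static_cast<int>(extents.size());

  // Reading (a): split the group axis. Each partition receives g/factor whole
  // groups, regardless of their non-uniform extents.
  for (int p = 0; p < factor; ++p) {
    std::printf("partition %d gets groups [%d, %d)\n", p, p * g / factor, (p + 1) * g / factor);
  }

  // Reading (b): split within each group. The number of groups stays g, but
  // every group's extent is divided by the factor (ceil division here).
  for (int c = 0; c < g; ++c) {
    std::printf("group %d: extent %d -> %d x %d\n",
                c, extents[c], (extents[c] + factor - 1) / factor, factor);
  }
  return 0;
}
```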
My mental model for this has been

```
 i{s}   tv{offsets}
     \    /
  [RaggedSplit]      # a new IterDomain op
     /    \
  i{g}    i{t}       # tokens_per_expert
    |
 [Split]
   /    \
iDIDx{d}  i{g/d}
```

without having to introduce a subclass of IterDomain.
I haven't yet figured out how I can represent this using RaggedIterDomain.
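For comparison, a rough sketch of that alternative: RaggedSplit as an IterDomain expression rather than an IterDomain subclass. The class name, members, and accessors are assumptions for discussion and omit the usual nvFuser Expr boilerplate:

```cpp
// Hypothetical op: consumes a flat domain i{s} plus the offsets tensor
// tv{offsets}, and produces a group axis i{g} and a ragged per-group axis
// i{t}; a regular Split on i{g} would then yield iDIDx{d} and i{g/d}.
class RaggedSplit {
 public:
  RaggedSplit(IterDomain* in, TensorView* offsets, IterDomain* out_groups, IterDomain* out_tokens)
      : in_(in), offsets_(offsets), out_groups_(out_groups), out_tokens_(out_tokens) {}

  IterDomain* in() const { return in_; }                 // i{s}
  TensorView* offsets() const { return offsets_; }       // tv{offsets}
  IterDomain* outGroups() const { return out_groups_; }  // i{g}
  IterDomain* outTokens() const { return out_tokens_; }  // i{t}, extent varies per group

 private:
  IterDomain* in_ = nullptr;
  TensorView* offsets_ = nullptr;
  IterDomain* out_groups_ = nullptr;
  IterDomain* out_tokens_ = nullptr;
};
```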