-
Notifications
You must be signed in to change notification settings - Fork 70
tma pointwise with broadcast #5555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: llu/pt3_auto1
Are you sure you want to change the base?
Conversation
|
Review updated until commit 0b5bd62 Description
|
| Relevant files | |||||||||
|---|---|---|---|---|---|---|---|---|---|
| Enhancement |
| ||||||||
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
TMA Break Point Handling
|
9447394 to
83a809e
Compare
|
!test |
1 similar comment
|
!test |
Greptile OverviewGreptile SummaryThis PR extends the TMA pointwise scheduler to handle broadcast domains. The implementation refactors break point calculation into shared utilities ( Key changes:
Confidence Score: 4/5
Important Files ChangedFile Analysis
Sequence DiagramsequenceDiagram
participant Scheduler as Pointwise Scheduler
participant Utils as pointwise_utils
participant TMA as TMA Scheduler
participant NonTMA as Non-TMA Scheduler
alt TMA Scheduling
Scheduler->>TMA: getPointwiseHeuristics()
TMA->>Utils: getBreakPoint(is_tma=true)
Utils->>Utils: Calculate broadcast multiples
Utils->>Utils: Iterate break points
alt Broadcast detected
Utils->>Utils: Prioritize any transfer size savings
Note over Utils: No 10% threshold for TMA
end
Utils-->>TMA: BreakPointInfo (break_point, flip_grid_binding)
TMA->>TMA: Calculate TMA tile sizes
TMA->>TMA: schedulePointwise()
alt break_point == 0
TMA->>TMA: split(0, tma_domain_inner)
Note over TMA: 1D: [I0*I1] → [Do, Di]
else break_point > 0
Note over TMA: 2D: Keep [I0, I1] separate
TMA->>TMA: NVF_ERROR(n_valid_dims >= 2)
end
TMA->>TMA: Filter inputs by isTvSuitableForTma()
Note over TMA: Reject inputs with broadcast domains
else Non-TMA Scheduling
Scheduler->>NonTMA: getPointwiseHeuristics()
alt Sufficient parallelism
NonTMA->>Utils: getBreakPoint(is_tma=false)
Utils->>Utils: Calculate broadcast multiples
Utils->>Utils: Iterate break points
alt No broadcast or minimal savings
Utils->>Utils: Require 10% transfer size savings
Utils->>Utils: Check parallelization threshold
end
Utils-->>NonTMA: BreakPointInfo
NonTMA->>Utils: getBlockGridConfig()
Utils->>Utils: Calculate bdimx, bdimy, gdim_left, gdim_right
Utils-->>NonTMA: BlockGridConfig
else Insufficient parallelism
NonTMA->>NonTMA: Use 1D scheduling (break_point=0)
end
end
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, 1 comment
| // dimension. outer tile size: don't exceed the outer TMA dimension size Both | ||
| // Both are subject to hardware constraints of 256 elements per dimension. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: comment has formatting issue - appears to be an incomplete edit
| // dimension. outer tile size: don't exceed the outer TMA dimension size Both | |
| // Both are subject to hardware constraints of 256 elements per dimension. | |
| // - Inner tile size: ensure at least 2 tiles in the inner TMA dimension | |
| // - Outer tile size: don't exceed the outer TMA dimension size |
Following #5553
TMA Pointwise Scheduler: Broadcast Domain Handling
(1) TMA Load vs. General Load (ldg/ld.global)
The current TMA pointwise scheduler does not use TMA load for inputs with concretized broadcast domains.
Example:
Given three inputs:
tv0[I1, B2] + tv1[B1, I2] + tv2[I1, I2], whereB1andB2are broadcast domains concretized toI1andI2:tv0andtv1will NOT be loaded with TMANote: This is a performance optimization. These inputs can be loaded with TMA, but only using a one-dimensional tile, as demonstrated in the newly added test.
(2) Break Point Selection
When broadcasts are present, the loop domain of the reference tv is merged to
[lhs, rhs]instead of flattening to a single dimension.Example:
Given
tv0[I1, B2] + tv1[B1, I2] + tv2[I1, I2], the break point is atpos-1, which separates[I1, I2]into[lhs, rhs].Break point selection differs between TMA and non-TMA versions:
Rationale: In the TMA version, we cannot safely merge broadcast and non-broadcast domains when creating 2D TMA domains and schedules, so we always break when broadcasts are present. See restrictions at #5556