[AMD] Add Warp-Pipeline Support, Gluon and LLVM lowering. #8586
Open
jungpark-mlir wants to merge 44 commits into triton-lang:main from jungpark-mlir:newpp
+806
−4
Changes from all commits (44 commits):
ce47bac PoC impl of warp pipeline part #1
b97a7ac Merge branch 'triton-lang:main' into newpp
da9f34e Add lowering warp pipeline
8999aba Merge branch 'triton-lang:main' into newpp
c81dd97 Merge branch 'triton-lang:main' into newpp
e370a53 update
3ab6436 milestone 1
12a521b Merge branch 'triton-lang:main' into newpp
84d44be actually insert fence/barrier
424c0e4 backup
8d9e559 Merge branch 'triton-lang:main' into newpp
596f46c Refactorize code
7bedeab Fix inserting barrier
58051cb Add support for gluon
50f41eb Merge branch 'newpp2' into newpp
2c695de Merge pull request #5 from jungpark-mlir/newpp
a17b120 Fix compilation
099bfe8 Merge branch 'triton-lang:main' into newpp2
7f1ddf0 Merge branch 'newpp' into newpp2
4697a26 Merge pull request #6 from jungpark-mlir/newpp2
9fafe4b Merge branch 'newpp3' into newpp
ac3c75d Merge pull request #7 from jungpark-mlir/newpp
acfd61c improve code
0d7fcd2 Merge branch 'triton-lang:main' into newpp3
89ca6c4 Merge branch 'newpp2' into newpp3
f298346 Merge pull request #8 from jungpark-mlir/newpp3
4e69362 Add test and last clean up
012678f Merge branch 'triton-lang:main' into newpp2
2715698 Fix leftovers
ef3f875 revert whitespace removal
fde1c36 Merge branch 'newpp' into newpp2
75154e0 Merge pull request #9 from jungpark-mlir/newpp2
90a2a7f Improve implementation per review
c5315f4 Merge branch 'triton-lang:main' into newpp2
c1087af Merge pull request #10 from jungpark-mlir/newpp2
3e45471 Remove extra canonicalization.
975a15c Merge branch 'main' into newpp
f884827 Fix accidental mistype.
885ee54 Remove unused option.
f90e809 Change to use `with` to define pipeline.
85f4b21 Merge branch 'triton-lang:main' into newpp
b3fa756 Merge branch 'triton-lang:main' into newpp
6bb284e Fix test
96a1e28 Format
```diff
@@ -1,6 +1,13 @@
 from .._core import builtin
 from ._layouts import AMDMFMALayout, AMDWMMALayout
 from . import cdna3, cdna4
 from . import rdna3, rdna4
 from . import gfx1250
+from .warp_pipeline import warp_pipeline_stage

-__all__ = ["AMDMFMALayout", "AMDWMMALayout", "cdna3", "cdna4", "rdna3", "rdna4", "gfx1250"]
+__all__ = ["AMDMFMALayout", "AMDWMMALayout", "cdna3", "cdna4", "rdna3", "rdna4", "gfx1250", "warp_pipeline_stage"]
+
+
+@builtin
+def split_warp_pipeline(_semantic=None):
+    return _semantic.builder.create_warp_pipeline_border()
```
python/triton/experimental/gluon/language/amd/warp_pipeline.py (new file, 26 additions):
```python
from __future__ import annotations
from typing import Optional


class warp_pipeline_stage:
    __slots__ = ("label", "_semantic")

    def __init__(self, label: Optional[str] = None, **_internal_kwargs):
        self.label = label
        self._semantic = _internal_kwargs.pop("_semantic", None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        if exc_type is not None:
            return False
        try:
            from . import split_warp_pipeline
            try:
                split_warp_pipeline(_semantic=self._semantic)
            except TypeError:
                split_warp_pipeline()
        except Exception:
            pass
        return False
```
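Taken standalone, the context-manager protocol above can be exercised without the Gluon runtime. The sketch below substitutes a stub builder for the real semantic builder (`FakeBuilder` and `FakeSemantic` are illustrative names, not part of the PR) to show when a pipeline border gets emitted:

```python
# Minimal standalone sketch of the warp_pipeline_stage protocol.
# FakeBuilder stands in for the real Gluon semantic builder; it only
# counts how many pipeline borders were requested.
from typing import Optional


class FakeBuilder:
    def __init__(self):
        self.borders = 0

    def create_warp_pipeline_border(self):
        self.borders += 1


class FakeSemantic:
    def __init__(self):
        self.builder = FakeBuilder()


class warp_pipeline_stage:
    __slots__ = ("label", "_semantic")

    def __init__(self, label: Optional[str] = None, **_internal_kwargs):
        self.label = label
        self._semantic = _internal_kwargs.pop("_semantic", None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Emit a border only when the stage body exits cleanly.
        if exc_type is None and self._semantic is not None:
            self._semantic.builder.create_warp_pipeline_border()
        return False  # never swallow exceptions


sem = FakeSemantic()
with warp_pipeline_stage("load", _semantic=sem):
    pass  # stage 0 body
with warp_pipeline_stage("compute", _semantic=sem):
    pass  # stage 1 body
print(sem.builder.borders)  # one border per completed stage
```

Note that on an exception the real `__exit__` also returns `False`, so errors inside a stage propagate instead of silently emitting a border.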
```mlir
// RUN: triton-opt %s -convert-warp-pipeline | FileCheck %s

// ---- 2-stage pipeline (basic) ----

tt.func @two_stage_backend(%n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  // Frontend has already annotated total stages.
  scf.for %i = %c0 to %n step %c1 {
    // Stage 0 cluster
    scf.execute_region {
      %a0 = arith.addi %i, %c1 : index
      %x0 = arith.addi %a0, %c1 : index
      scf.yield
    } {triton.warp_pipeline.stage}

    // Stage 1 cluster
    scf.execute_region {
      %a1 = arith.addi %i, %c1 : index
      %x1 = arith.muli %a1, %c1 : index
      scf.yield
    } {triton.warp_pipeline.stage}

    scf.yield
  } {triton.warp_pipeline.pipelined_for}

  tt.return
}

// CHECK-LABEL: tt.func @two_stage_backend(
// CHECK: %c0 = arith.constant 0 : index
// CHECK: %c1 = arith.constant 1 : index
// CHECK-NOT: no_inline

// === Pre-loop sync + role setup ===
// CHECK: gpu.barrier
// CHECK: arith.divsi
// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq
// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne
// CHECK: amdg.cond_barrier %[[WARPHIGH]]

// CHECK: scf.for
// CHECK-NOT: scf.execute_region
// CHECK: rocdl.sched.barrier
// CHECK: rocdl.s.barrier
// CHECK: rocdl.sched.barrier
// CHECK-NOT: scf.execute_region

// CHECK: amdg.cond_barrier %[[WARPLOW]]
// CHECK: tt.return

// ---- 3-stage pipeline (ensures multiple clusters handled) ----

tt.func @three_stage_backend(%n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  scf.for %i = %c0 to %n step %c1 {
    // Stage 0
    scf.execute_region {
      %x0 = arith.addi %i, %c1 : index
      scf.yield
    } {triton.warp_pipeline.stage}
    // Stage 1
    scf.execute_region {
      %x1 = arith.muli %i, %c1 : index
      scf.yield
    } {triton.warp_pipeline.stage}
    // Stage 2
    scf.execute_region {
      %x2 = arith.addi %i, %c1 : index
      scf.yield
    } {triton.warp_pipeline.stage}

    scf.yield
  } {triton.warp_pipeline.pipelined_for}

  tt.return
}

// CHECK-LABEL: tt.func @three_stage_backend(
// CHECK-NOT: no_inline
// CHECK: gpu.barrier
// CHECK: amdg.cond_barrier
// CHECK: scf.for
// CHECK-NOT: scf.execute_region
// CHECK: rocdl.sched.barrier
// CHECK: rocdl.s.barrier
// CHECK: rocdl.sched.barrier
// CHECK: amdg.cond_barrier
// CHECK: tt.return

// ---- Negative: no total_stages → pass should not touch the loop ----

tt.func @no_total_stages(%n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.for %i = %c0 to %n step %c1 {
    scf.execute_region {
      %x = arith.addi %i, %c1 : index
      scf.yield
    }
    scf.yield
  }
  tt.return
}

// CHECK-LABEL: tt.func @no_total_stages(
// CHECK-NOT: gpu.barrier
// CHECK-NOT: amdg.cond_barrier
// CHECK: scf.for
// CHECK: scf.execute_region
// CHECK: tt.return
```
```mlir
// RUN: triton-opt %s -tritonamdgpu-warp-pipeline | FileCheck %s

// ---- 3-stage example (two borders) ----

tt.func @three_stage_example(%n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  scf.for %i = %c0 to %n step %c1 {
    // Stage 0 (before first border)
    %a = arith.addi %i, %c1 : index
    %a2 = arith.muli %a, %c1 : index

    // explicit split point → next stage begins
    rocdl.sched.barrier 0 {triton.warp_pipeline.border}

    // Stage 1
    %b = arith.addi %a2, %i : index

    // explicit split point → next stage begins
    rocdl.sched.barrier 0 {triton.warp_pipeline.border}

    // Stage 2
    %c = arith.addi %b, %a : index
    %d = arith.muli %c, %c1 : index

    scf.yield
  }

  tt.return
}

// CHECK-LABEL: tt.func @three_stage_example(
// CHECK: scf.for
//
// Inside the loop we expect exactly three execute_region clusters:
// CHECK: scf.execute_region
// CHECK: scf.execute_region
// CHECK: scf.execute_region
// CHECK: triton.warp_pipeline.pipelined_for
//
// And the split markers must be gone:
// CHECK-NOT: rocdl.sched.barrier
// CHECK: tt.return

// ---- 2-stage example (one border) ----

tt.func @two_stage_example(%n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  scf.for %i = %c0 to %n step %c1 {
    // Stage 0
    %x = arith.addi %i, %c1 : index

    // split to Stage 1
    rocdl.sched.barrier 0 {triton.warp_pipeline.border}

    // Stage 1
    %y = arith.muli %x, %c1 : index

    scf.yield
  }

  tt.return
}

// CHECK-LABEL: tt.func @two_stage_example(
// CHECK: scf.for
// CHECK: scf.execute_region
// CHECK: scf.execute_region
// CHECK: triton.warp_pipeline.pipelined_for
// CHECK-NOT: rocdl.sched.barrier
// CHECK: tt.return

// ---- Negative: no border → no structuring ----

tt.func @no_split_example(%n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index

  scf.for %i = %c0 to %n step %c1 {
    %x = arith.addi %i, %c1 : index
    %y = arith.muli %x, %c1 : index
    scf.yield
  }

  tt.return
}

// CHECK-LABEL: tt.func @no_split_example(
// CHECK: scf.for
// CHECK-NOT: scf.execute_region
// CHECK-NOT: pipelined_for
// CHECK: tt.return
```
I don't think this makes sense as a concept in gluon. How will tensors work if they are executed on only one warp, but the tensor layout is expecting multiple warps?
Would the `gl.warp_specialize` API work instead?
Warp-pipelining does not restrict execution to a single warp, nor does it change the program’s semantics. With or without warp-pipelining, all warps execute the same code and produce the same results. Warp-pipelining simply shifts the timing of when different warps execute different stages of a loop, allowing them to overlap memory and compute phases more effectively. In other words, it controls when warps run each stage, not what they execute.
Historically, we achieved the same effect with the block-pingpong pass, which flattened the IR, manually partitioned the loop, and inserted all synchronization in place. That approach worked but did not scale: every new kernel or architecture required manual scheduling, barrier placement, and tuning. warp-pipeline replaces that ad-hoc strategy with a structured IR representation, enabling systematic pipeline analysis and automation.
Warp-pipelining is fundamentally different from warp-specialization. We're also reviewing support for `warp-specialize`, but that's a separate effort.
I see, so IIUC, this is all about emitting the `cond_barrier` to delay some warps at runtime? Won't you just re-converge as soon as there is a barrier inside the loop? Also, what difference do the stages make in that case?
That's a good point. The key idea behind using `cond_barrier` is that warp groups diverge in time once they meet a `cond_barrier`, but they don't need to reconverge at the same program counter. Once one group is delayed and the other runs ahead, they continue to "leap-frog" each other by hitting different barrier sites in the loop. This works because the HW releases a barrier when all participating threads have reached a barrier, but not necessarily the same barrier instruction. In other words, barriers are satisfied once every warp has arrived at some barrier, even if those barriers are at different PCs. This allows two (or more) warp groups to keep synchronizing without ever reconverging, which naturally maintains pipelined execution across iterations. At the end of the loop, we place a countering `cond_barrier` to reconverge the warps.

When I first explored this idea, I also assumed that barriers would only release when all threads reached the exact same barrier instruction. So I initially forced reconvergence in HW using a wrapper `llvm.func` containing a barrier, ensuring warps could conditionally funnel to one canonical barrier point from different points in the kernel. That version also worked, but it turned out to be unnecessary given the actual HW behavior, and the `cond_barrier` implementation has been simplified to the current one.
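The leap-frog timing described here can be sketched as a toy schedule in plain Python. Everything below is illustrative (it models the schedule only, not the hardware): two warp groups run a 2-stage loop, the "high" group is delayed one stage by the pre-loop `cond_barrier`, and each barrier step pairs the groups at different stages:

```python
# Toy model of two warp groups leap-frogging through a 2-stage loop.
# A barrier releases once every group has arrived at a barrier site,
# even when those sites are at different program counters.
N_ITERS = 4

# Each group's timeline is a list of (iteration, stage) steps.
low = [(i, s) for i in range(N_ITERS) for s in (0, 1)]
# The pre-loop cond_barrier delays the "high" group by one stage.
high = [None] + low[:-1]

# Pair up what each group executes between consecutive barriers.
schedule = list(zip(low, high))

# In steady state the groups always occupy different stages, so the
# memory phase (stage 0) and compute phase (stage 1) overlap across groups.
steady = [pair for pair in schedule if pair[1] is not None]
for (li, ls), (hi, hs) in steady:
    assert ls != hs  # the two groups never run the same stage concurrently
print(len(steady))
```

The final countering `cond_barrier` after the loop corresponds to letting the delayed group catch up on its last pending stage.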
So does that mean that, for example, a reduction or layout conversion inside the loop will fail, because the barriers emitted during lowering will not behave as expected?

Also, you didn't answer why there need to be different stages. Could you get the same result by just exposing `cond_barrier` directly in gluon?
If the question is "why not just expose cond_barrier and related ops directly," that's essentially what the `block-pingpong` pass does today, and `warp-pipeline` is the structured improvement of that approach. Two key improvements:
- A directly exposed `cond_barrier` could behave differently for each warp, which violates the block-based TTG programming model; `warp-pipeline` keeps the divergence structured at the stage level.
- Across new kernels and hardware, `block-pingpong` repeatedly showed that dependency analysis and stage partitioning are both performance-critical and a common source of bugs; a structured IR makes them amenable to systematic analysis.

I think this is a totally fair question. Since this PR only contains the Gluon support, this may look like extra structure. The motivation, though, is to support auto-partitioning in the future and ensure a consistent IR design for both manual and automated pipelines.
And you're right, `warp-pipeline` cannot simply work for ops like reductions or layout conversions that inherently require warp-wide synchronization. It can only work if the sync points align with stage boundaries. In those cases, you structure the pipeline so the required barrier happens at the end of a stage. If the required synchronization cannot align with a stage boundary, then warp-pipeline simply isn't appropriate for that loop. Users (and eventually auto-partitioning) decide when this optimization works effectively.
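The legality condition stated here, that a warp-wide sync point must coincide with a stage boundary, can be expressed as a small check over a stage partition. The sketch below is illustrative only (the op names and the `needs_barrier` flag are hypothetical stand-ins for the real dependency analysis):

```python
# Toy legality check: an op that needs a warp-wide barrier is only
# allowed as the last op of its stage, where the pipeline already
# synchronizes the warp groups.
def pipeline_is_legal(stages):
    """stages: list of stages; each stage is a list of (op_name, needs_barrier)."""
    for ops in stages:
        for idx, (name, needs_barrier) in enumerate(ops):
            if needs_barrier and idx != len(ops) - 1:
                return False  # sync point not aligned with a stage boundary
    return True


ok = pipeline_is_legal([
    [("load_a", False), ("load_b", False)],
    [("mma", False), ("reduce", True)],   # barrier at stage end: fine
])
bad = pipeline_is_legal([
    [("reduce", True), ("mma", False)],   # barrier mid-stage: illegal
])
print(ok, bad)
```

A real implementation would derive `needs_barrier` from the op's lowering (e.g. reductions and layout conversions through shared memory) rather than from a hand-written flag.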
Okay, I think I misunderstood something you said earlier. I was thinking `cond_barrier` was only participated in by the conditioned threads, but IIUC, actually all threads participate in the barrier, just at different points in the program. So for the first barrier, the upper threads are blocked while the lower execute the first stage. Then the lower half waits at the end of the second stage for the upper half to finish the first stage, and so on. I can see why that would be problematic.

To be transparent, though, I'm not sure I like the direction of having functions that are really annotations that change what's happening in the surrounding context. I'm fine with having new language constructs in general, I just want to make sure they fit in with a reasonable programming model. I especially don't like that the legality of using a function would depend on what other functions are called in the same loop...

Would you be open to implementing a different syntax? Something like:

IMO this makes it clearer that the operations are happening within a warp-pipelined context.

I also think you should raise an error if any of the ops require a barrier, as silent correctness issues that depend on implementation details of an op's lowering don't sound like a great user experience.
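The reviewer's code snippet did not survive the page scrape, but the commit "Change to use `with` to define pipeline" suggests the direction. A plausible shape for such a with-based syntax, sketched with a toy recorder (the `@contextmanager` stub and the `trace` list are hypothetical stand-ins, not the PR's implementation, which would emit IR region boundaries instead):

```python
# Hypothetical sketch of a with-based warp-pipeline syntax.
# "trace" only records which stage bodies ran; a real implementation
# would emit pipeline-stage boundaries into the IR instead.
from contextlib import contextmanager

trace = []


@contextmanager
def warp_pipeline_stage(label):
    trace.append(("enter", label))
    yield
    trace.append(("exit", label))


# Each stage is a lexical block, so the warp-pipelined context is
# explicit and the stage boundaries are visible in the source.
def kernel_loop(n):
    for i in range(n):
        with warp_pipeline_stage("load"):
            pass  # e.g. global -> shared memory copies
        with warp_pipeline_stage("compute"):
            pass  # e.g. mma on the previously loaded tile


kernel_loop(1)
print(trace)
```

Compared with a free-standing split call, the block form makes it syntactically impossible to open a stage without closing it.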
That sounds like a great idea to explore. Let me try.

Sure, that's definitely in the plan. It's not always impossible to use `warp-pipeline` with those algorithms; it depends on the dependency between the operation and its users, e.g., it's fine if the synchronization can be deferred to the cluster boundary. The current analysis only looks for dependencies that require a local memory fence, but a check for illegal dependencies will be added.