
Conversation

Contributor

@jungpark-mlir jungpark-mlir commented Oct 29, 2025

Enable Gluon kernels to express and compile warp-pipelined loops—where different warps execute staggered stages (e.g., load, compute, store)—to improve compute–memory overlap and utilization.

This is achieved through a structured, two-phase lowering pipeline:

  1. Frontend (Gluon → TritonGPU):
  • Adds a new API call: gl.amd.split_warp_pipeline(), which marks pipeline stage boundaries inside a Gluon kernel.
  • The new TritonAMDGPUWarpPipeline pass converts loops containing split points into structured scf.execute_region clusters, annotated with total_stages and lead_stages.
  2. Backend (TritonGPU → LLVM):
  • The ConvertWarpPipeline pass lowers each scf.execute_region cluster into predicated execution guarded by conditional barriers (amdgpu.cond_barrier).
  • Inserts scheduler and workgroup barriers (rocdl.sched.barrier, rocdl.s.barrier) to enforce correct cross-stage ordering and prevent instruction reordering.
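The staggered execution this enables can be pictured with a small, self-contained model (purely illustrative Python, not the actual lowering; `staggered_schedule` and its tuple encoding are made up for this sketch). Two warp groups run the same two-stage loop body, with one group delayed by `lead_stages` via a conditional barrier, so at every overlapped time step the groups sit in different stages:

```python
# Illustrative model only: simulate which (group, iteration, stage) executes
# at each time step when group 1 is delayed by lead_stages behind group 0,
# the way the warp-pipeline lowering staggers warp groups at runtime.

def staggered_schedule(num_iters, total_stages=2, lead_stages=1):
    timeline = {}
    for group in (0, 1):
        # Group 0 starts immediately; group 1 is held back by a conditional
        # barrier for lead_stages stage slots before it begins.
        delay = 0 if group == 0 else lead_stages
        seq = [(group, i, s)
               for i in range(num_iters)
               for s in range(total_stages)]
        for t, work in enumerate(seq):
            timeline.setdefault(t + delay, []).append(work)
    return [timeline[t] for t in sorted(timeline)]

sched = staggered_schedule(num_iters=3)
# In every fully overlapped time step the two groups occupy different
# stages, so a memory stage (0) can overlap a compute stage (1).
for step in sched:
    if len(step) == 2:
        (_, _, s0), (_, _, s1) = step
        assert s0 != s1
```

With `total_stages=2` and `lead_stages=1` the stage parities of the two groups always differ, which is exactly the load/compute overlap the pipeline targets.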

Future work

  • Automatic partitioning frontend for Triton kernels: migrate the legacy block-pingpong pass and add an entirely new partitioning pass.

- partitioning of the code into stages
- backup
- can correctly insert fence
- update interfaces per recent changes
- make it work actually
- fix wrongly offset insertion
- refactor loop
- code cleanup
- barrier should be inserted from the warp causing the dependency
- Added builtin split_warp_pipeline(); inserting the builtin splits the code region into two pipeline clusters
- now runs on mi350
- polish conversion code
- found an important fix needed, just commented for now
custom_lds_size = 0
amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
amd.passes.ttgpuir.add_warp_pipeline_conversion(pm)
passes.common.add_canonicalizer(pm)
Collaborator

do we really need another full canonicalization pass here? Might be better to do targeted cleanups

Contributor Author

That's a good point.

Contributor Author

done.

Comment on lines 158 to 162
forOp->setAttr("triton.warp_pipeline.total_stages",
               b.getI32IntegerAttr(totalStages));
forOp->setAttr("triton.warp_pipeline.lead_stages",
               b.getI32IntegerAttr(1)); // TODO: make configurable

Collaborator

what do those attributes control?

Contributor Author

These don't do anything right now. I'll change them to a unit attribute that identifies a pipelined scf.for, and will revisit them once I have a more concrete idea of how to use them.

Contributor Author

done, both removed and replaced with .pipelined_for.

    cluster.push_back(op);
  }
  if (!cluster.empty())
    clusters.push_back(std::move(cluster));
Collaborator

why don't we create the regions directly rather than having a pass post process those?

Contributor Author

This basically comes from considering how to program warp-pipelining in Gluon. First, I considered using a Python function to define a region, as warp-specialization does, but there were some issues: scf.execute_region doesn't have block arguments, and a Gluon user doesn't fully know which values need to be yielded. It might not be impossible to rewrite a Python function into an scf.execute_region, but the required analysis might be even more complicated than just defining clusters by the pipeline borders. Also, the border-based pipelining method prevents users from mistakenly placing operations outside the clusters when pipelining.
This is also helpful when we migrate the existing block-pingpong scheduling, since this pass can be used on non-Gluon paths as well. The new auto-partitioning pass will create regions directly and might be able to replace the others, but I'm not sure yet.
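The border-based clustering being discussed can be sketched in a few lines (illustrative Python only; `BORDER` and the string op names are stand-ins, not the pass's actual data structures): walk the loop body in order and start a new cluster at each user-inserted border marker.

```python
# Illustrative sketch: split a flat list of loop-body ops into pipeline
# clusters at border markers, mirroring how ops between consecutive
# split_warp_pipeline() calls end up grouped into one cluster each.

BORDER = "warp_pipeline_border"  # stand-in for the border marker op

def split_into_clusters(ops):
    clusters, cluster = [], []
    for op in ops:
        if op == BORDER:
            # A border closes the current cluster; the marker itself is
            # consumed and does not appear in any cluster.
            if cluster:
                clusters.append(cluster)
                cluster = []
        else:
            cluster.append(op)
    if cluster:  # trailing ops after the last border form the final cluster
        clusters.append(cluster)
    return clusters

body = ["load_a", "load_b", BORDER, "dot", "add", BORDER, "store"]
assert split_into_clusters(body) == [["load_a", "load_b"],
                                     ["dot", "add"],
                                     ["store"]]
```

Because every op lands in exactly one cluster, nothing can be accidentally left outside a pipeline stage, which is the safety property mentioned above.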

Comment on lines 222 to 237
void runOnOperation() override {
  ModuleOp m = getOperation();
  OpBuilder builder(m);
  ModuleAllocation moduleAllocation(m);

  for (auto funcOp : m.getOps<mlir::triton::FuncOp>()) {
    Allocation *allocation = moduleAllocation.getFuncData(funcOp);
    funcOp.walk([&](scf::ForOp forOp) {
      if (auto totalStages =
              forOp->getAttr("triton.warp_pipeline.total_stages")) {
        Location loc = forOp.getLoc();
        emitPipelinedFor(builder, loc, forOp, allocation);
      }
    });
  }
}
Collaborator

why can't the region be lowered by normal pattern rewrite?

Contributor Author

That could be a better idea.

Contributor Author

done.


@builtin
def split_warp_pipeline(_semantic=None):
return _semantic.builder.create_warp_pipeline_border()
Contributor

I don't think this makes sense as a concept in gluon. How will tensors work if they are executed on only one warp, but the tensor layout is expecting multiple warps?

  1. Would the gl.warp_specialize API work instead?
  2. If not, could we generalize the API in a way that works also for AMD?

Contributor Author

Warp-pipelining does not restrict execution to a single warp, nor does it change the program’s semantics. With or without warp-pipelining, all warps execute the same code and produce the same results. Warp-pipelining simply shifts the timing of when different warps execute different stages of a loop, allowing them to overlap memory and compute phases more effectively. In other words, it controls when warps run each stage, not what they execute.

Historically, we achieved the same effect using the block-pingpong pass, which flattened the IR, manually partitioned the loop, and inserted all synchronization in place. That approach worked but did not scale: every new kernel or architecture required manual scheduling, barrier placement, and tuning. Warp-pipeline replaces that ad-hoc strategy with a structured IR representation, enabling systematic pipeline analysis and automation.

Warp-pipelining is fundamentally different from warp-specialization:

  • Warp-pipelining: all warps execute the same code; no functional divergence; only timing differs.
  • Warp-specialization: different warps run different roles or code paths (e.g., loader warp vs. compute warp), and there is no notion of pipeline stage ordering.

We're also reviewing support for warp-specialize, but that's a separate effort.

Contributor

I see, so IIUC, this is all about emitting the cond_barrier to delay some warps at runtime? Won't you just re-converge as soon as there is a barrier inside the loop? Also what difference do the stages make in that case?

Contributor Author

That's a good point. The key idea behind using cond_barrier is that warp groups diverge in time once they meet a cond_barrier, but they don't need to reconverge at the same program counter. Once one group is delayed and the other runs ahead, they continue to "leap-frog" each other by hitting different barrier sites in the loop.
This is because the HW releases the barrier when all participating threads have reached a barrier, but not necessarily the same barrier instruction. In other words, a barrier is satisfied once every warp has arrived at some barrier, even if those barriers are at different PCs. This allows two (or more) warp groups to keep synchronizing without ever reconverging, which naturally maintains the pipelined execution across iterations. At the end of the loop, we have to place a counter cond_barrier to reconverge the warps.

When I first explored this idea, I also assumed that barriers would only release when all threads reached the exact same barrier instruction. So I initially forced reconvergence in HW using a wrapper llvm.func containing a barrier, ensuring warps could conditionally funnel to one canonical barrier point from different points in the kernel. That version also worked, but it turned out to be unnecessary given the current HW behavior, and the cond_barrier implementation has been simplified to the current one.
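The "satisfied at different call sites" property has an everyday software analogue (this is just a CPU-thread analogy, not GPU code): Python's `threading.Barrier` releases once all parties have arrived at *a* `wait()` call, regardless of where in the program each party calls it. Two threads hitting `wait()` from different lines still synchronize pairwise, which is the same leap-frogging structure described above:

```python
import threading

# Analogy only: like the HW barrier above, threading.Barrier counts
# arrivals, not call sites. The follower is delayed at its first wait()
# until the leader finishes stage 0, then the two leap-frog stage by stage.
bar = threading.Barrier(2)
log = []
lock = threading.Lock()

def record(msg):
    with lock:
        log.append(msg)

def leader():
    record("leader: stage0")
    bar.wait()   # call site A: releases the follower into stage 0
    record("leader: stage1")
    bar.wait()   # call site A': lets the follower proceed to stage 1

def follower():
    bar.wait()   # call site B: blocked until the leader finishes stage 0
    record("follower: stage0")
    bar.wait()   # call site B: blocked until the leader finishes stage 1
    record("follower: stage1")

t1 = threading.Thread(target=leader)
t2 = threading.Thread(target=follower)
t1.start(); t2.start()
t1.join(); t2.join()

# The barrier orders each leader stage before the matching follower stage,
# even though the two threads never wait at the same source line.
assert log.index("leader: stage0") < log.index("follower: stage0")
assert log.index("leader: stage1") < log.index("follower: stage1")
```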

Contributor

So does that mean that for example a reduction or layout conversion inside the loop will fail because the barriers emitted during the lowering will not behave as expected?

Also you didn't answer why there need to be different stages. Could you get the same result by just exposing cond_barrier directly in gluon?

Contributor Author

If the question is "why not just expose cond_barrier and related ops directly," that's essentially what the block-pingpong pass does today, and warp-pipeline is the structured improvement of that approach. Two key improvements are:

  1. cond_barrier can behave differently per warp, which violates the block-based TTG programming model, so it shouldn't be exposed directly.
  2. Automation: the pass currently does the minimum analysis needed to determine dependencies across pipeline stages, which is slightly different from what Membar currently does, and eventually we'd like to use this IR structure and lowering method for auto-partitioning of the warp-pipeline. Across new kernels and hardware, block-pingpong repeatedly showed that dependency analysis and stage partitioning are both performance-critical and a common source of bugs.

I think this is a totally fair question. Since this PR only contains support for Gluon, it may look like extra structure. The motivation, though, is to support auto-partitioning in the future and ensure a consistent IR design for both manual and automated pipelines.

And you're right, warp-pipeline cannot simply work for ops like reductions or layout conversions that inherently require warp-wide synchronization. It can only work if the sync points align with stage boundaries. In those cases, you structure the pipeline so the required barrier happens at the end of a stage. If the required synchronization cannot align with a stage boundary, then warp-pipeline simply isn't appropriate for that loop. Users (and eventually auto-partitioning) decide when this optimization is effective.

Contributor
@peterbell10 peterbell10 Nov 6, 2025

Okay, I think I misunderstood something you said earlier. I was thinking that only the conditioned threads participated in cond_barrier, but IIUC all threads actually participate in the barrier, just at different points in the program. So for the first barrier, the upper threads are blocked while the lower threads execute the first stage. Then the lower half waits at the end of the second stage for the upper half to finish the first stage, and so on. I can see why that would be problematic.

To be transparent though, I'm not sure I like the direction of having functions that are really annotations changing what happens in the surrounding context. I'm fine with having new language constructs in general; I just want to make sure it fits in with a reasonable programming model. I especially don't like that the legality of using a function would depend on what other functions are called in the same loop...

Would you be open to implementing a different syntax? Something like:

for i in range(...):
    with amd.warp_pipeline_stage():
        x = i + one
    with amd.warp_pipeline_stage():
        y = x * one
        x = y + one

IMO this makes it clearer that the operations are happening within a warp-pipelined context.

I also think you should raise an error if any of the ops require a barrier, as silent correctness issues that depend on implementation details of an op's lowering don't sound like a great user experience.
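For what it's worth, the proposed `with`-statement syntax could desugar to the same border markers. Here is a minimal sketch (all names are hypothetical stand-ins, not Triton/Gluon APIs; the "op stream" is just a list) showing a context manager that brackets each stage body with boundary markers:

```python
import contextlib

# Hypothetical sketch: a stage context manager that records boundary
# markers into an op stream, roughly how a real implementation might
# emit warp-pipeline borders around each stage body.
ops = []  # stand-in for the IR op stream being built

@contextlib.contextmanager
def warp_pipeline_stage():
    ops.append("stage_begin")    # could map to a border op on entry
    try:
        yield
    finally:
        ops.append("stage_end")  # and a matching border on exit

with warp_pipeline_stage():
    ops.append("x = i + 1")      # stage 0 body
with warp_pipeline_stage():
    ops.append("y = x * 1")      # stage 1 body

assert ops == ["stage_begin", "x = i + 1", "stage_end",
               "stage_begin", "y = x * 1", "stage_end"]
```

One nice property of this form is that every op is lexically inside exactly one stage, so "ops outside any cluster" becomes unrepresentable at the syntax level.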

Contributor Author

Would you be open to implementing a different syntax? Something like:

That sounds a great idea to explore. Let me try.

I also think you should raise

Sure, that's definitely in the plan. It's not always impossible to use warp-pipeline with those algorithms; it depends on the dependency between the operation and its users, e.g., it's fine if synchronization can be deferred to the cluster boundary. The current analysis only looks for dependencies that require a local memory fence, but a check for illegal dependencies will be added.

- Simplify discardable attr for marking pipeline
- Change to use pattern match to convert ops
- region is now inlined in the pass and no longer needed
@antiagainst antiagainst marked this pull request as ready for review November 1, 2025 00:09