Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUInThreadTranspose();
mlir::registerTritonAMDGPUCoalesceAsyncCopy();
mlir::registerTritonAMDGPUUpdateAsyncWaitCount();
mlir::registerTritonAMDGPUWarpPipeline();
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
mlir::registerTritonAMDFoldTrueCmpI();
Expand Down
8 changes: 7 additions & 1 deletion python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "pybind11/pybind11.h"
#include <pybind11/stl.h>

#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
Expand Down Expand Up @@ -798,7 +799,12 @@ void init_gluon_ir(py::module &&m) {
.def("create_async_tdm_wait", [](GluonOpBuilder &self, int num) {
ValueRange tokens;
self.create<ttag::AsyncTDMWait>(tokens, num);
});
})
.def("create_warp_pipeline_border",
[](GluonOpBuilder &self) {
auto border = self.create<ROCDL::SchedBarrier>(0);
border->setAttr("pipeline_border", self.getBuilder().getUnitAttr());
});;

py::class_<ttg::WarpSpecializeOp, OpState>(m, "WarpSpecializeOp",
py::module_local())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,7 @@ def buffer_atomic_xchg(ptr, offsets, value, mask=None, sem=None, scope=None, _se

return _buffer_atomic_rmw_impl('xchg', ptr, offsets, value, "cdna3", mask=mask, sem=sem, scope=scope,
_semantic=_semantic)

@builtin
def split_warp_pipeline(_semantic: GluonSemantic = None):
return _semantic.builder.create_warp_pipeline_border()
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,7 @@ def buffer_atomic_xchg(ptr, offsets, value, mask=None, sem=None, scope=None, _se

return _buffer_atomic_rmw_impl('xchg', ptr, offsets, value, "cdna4", mask=mask, sem=sem, scope=scope,
_semantic=_semantic)

@builtin
def split_warp_pipeline(_semantic: GluonSemantic = None):
return _semantic.builder.create_warp_pipeline_border()
9 changes: 7 additions & 2 deletions third_party/amd/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,15 @@ def make_ttgir(mod, metadata, options):
amd.passes.ttgpuir.add_in_thread_transpose(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
amd.passes.ttgpuir.add_reorder_instructions(pm)
if use_block_pingpong and options.num_stages > 1:
amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages)

if knobs.amd.use_buffer_ops:
amd.passes.ttgpuir.add_canonicalize_pointers(pm)
passes.common.add_canonicalizer(pm)
amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch, knobs.amd.use_buffer_atomics)

if use_block_pingpong and options.num_stages > 1:
amd.passes.ttgpuir.add_warp_pipeline(pm, options.num_stages)

amd.passes.ttgpuir.add_fold_true_cmpi(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
Expand Down Expand Up @@ -276,13 +277,17 @@ def make_llir(src, metadata, options):
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
if use_block_pingpong:
amd.passes.ttgpuir.add_warp_pipeline(pm, options.num_stages)
# custom_lds_size is an experimental parameter that defines amount of LDS available
# for one thread block. Measured in bytes.
#
# If custom_lds_size = 0, pass will consider all LDS is available for one threads block,
# LDS size is determined by provided arch name.
custom_lds_size = 0
amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
amd.passes.ttgpuir.add_warp_pipeline_conversion(pm)
passes.common.add_canonicalizer(pm)
passes.convert.add_scf_to_cf(pm)
passes.gluon.add_inliner(pm)
passes.convert.add_index_to_llvmir(pm)
Expand Down
1 change: 1 addition & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace mlir::triton::AMD {
/// @return created pass
std::unique_ptr<OperationPass<ModuleOp>>
createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
std::unique_ptr<OperationPass<ModuleOp>> createConvertWarpPipelinePass();

void runScalarizePackedFOpsPass(llvm::Function &F);

Expand Down
10 changes: 10 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,14 @@ def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-in
];
}

def ConvertWarpPipeline : Pass<"convert-warp-pipeline", "mlir::ModuleOp"> {
let summary = "Emit conditional barrier and inlines scf.execute_region for warp-pipeline";
let constructor = "mlir::triton::AMD::createConvertWarpPipelinePass()";

let dependentDialects = ["mlir::LLVM::LLVMDialect",
"mlir::ROCDL::ROCDLDialect",
"mlir::triton::amdgpu::TritonAMDGPUDialect"];

}

#endif
16 changes: 16 additions & 0 deletions third_party/amd/include/TritonAMDGPUTransforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,20 @@ def TritonAMDGPUOptimizeDotOperands : Pass<"tritonamdgpu-optimize-dot-operands",
];
}

def TritonAMDGPUWarpPipeline: Pass<"tritonamdgpu-warp-pipeline", "mlir::ModuleOp"> {
let summary = "partition and pipeline";

let description = [{
This pass reorder instructions to interleave instructions from two warps on the same SIMD unit.
}];

let dependentDialects = ["mlir::ROCDL::ROCDLDialect, mlir::triton::amdgpu::TritonAMDGPUDialect"];

let options = [
Option<"numStages", "num-stages",
"int32_t", /*default*/"2",
"Number of Pipeline stages">,
];
}

#endif
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ add_triton_library(TritonAMDGPUToLLVM
BufferOpsEmitter.cpp
TensorPtrOpsToLLVM.cpp
ConvertLayoutOpToLLVM.cpp
ConvertWarpPipeline.cpp
MemoryOpToLLVM.cpp
MaskedOpsToLLVM.cpp
DotOpToLLVM/FMA.cpp
Expand Down
246 changes: 246 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
/*
* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "TargetInfo.h"
#include "TritonAMDGPUToLLVM/Passes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
#include "triton/Analysis/Membar.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#define DEBUG_TYPE "convert-warp-pipeline"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;

namespace mlir::triton {
#define GEN_PASS_DEF_CONVERTWARPPIPELINE
#include "TritonAMDGPUToLLVM/Passes.h.inc"
} // namespace mlir::triton

namespace {

static BlockInfo buildBlockInfoFromBlock(Block *block, Allocation *allocation) {
BlockInfo info; // running fact for this block
for (Operation &opRef : *block) {
Operation *op = &opRef;
if (auto mei = dyn_cast<MemoryEffectOpInterface>(op)) {
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effs;
mei.getEffects(effs);
for (auto &eff : effs) {
if (Value v = eff.getValue()) {
for (auto bufId : allocation->getBufferIds(v)) {
if (bufId == Allocation::InvalidBufferId)
continue;
auto interval = allocation->getAllocatedInterval(bufId);
if (isa<MemoryEffects::Write>(eff.getEffect()))
info.syncWriteIntervals[interval].insert(op);
else if (isa<MemoryEffects::Read>(eff.getEffect()))
info.syncReadIntervals[interval].insert(op);
}
}
}
}
}
return info;
}

class ConvertWarpPipeline
: public mlir::triton::impl::ConvertWarpPipelineBase<ConvertWarpPipeline> {

void emitClusterBarrier(OpBuilder &b, Location loc, bool needLocal) {
b.create<ROCDL::SchedBarrier>(loc, 0);
if (needLocal)
b.create<mlir::triton::gpu::LocalBarrierOp>(loc);
else
b.create<ROCDL::SBarrierOp>(loc);
b.create<ROCDL::SchedBarrier>(loc, 0);
}

void emitPipelinedFor(OpBuilder &builder, Location loc, scf::ForOp forOp,
Allocation *allocation) {

// insert cond branch first,
builder.setInsertionPointAfter(forOp);
// Set barrier before starting the loop. This resolves any remaining
// required synchronization before beginning the specialized asymmetric
// synchronization.
auto preBarrier = builder.create<gpu::BarrierOp>(loc);
preBarrier->moveBefore(forOp);
builder.setInsertionPointAfter(preBarrier);

// Insert condbarrier::second_half before starting the loop
// FIXME : correctly calculate numbers by the given num_warps.
auto i32ty = builder.getIntegerType(32);
auto workIDX = builder.create<ROCDL::ThreadIdXOp>(loc, i32ty);
auto constZero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
auto constWarpSize = builder.create<arith::ConstantIntOp>(loc, 256, 32);
auto warpIDX = builder.create<arith::DivSIOp>(loc, workIDX, constWarpSize);
auto warpLow = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
warpIDX, constZero);
auto warpHigh = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
warpIDX, constZero);

// FIXME: duplicate condBarrier for lead_stages
auto condBarrierHigh =
builder.create<mlir::triton::amdgpu::CondBarrierOp>(loc, warpHigh);

// Insert condbarrier::first_half after the end of the loop
builder.setInsertionPointAfter(forOp);
auto condBarrierLow =
builder.create<mlir::triton::amdgpu::CondBarrierOp>(loc, warpLow);

// in case a loop begins with a barrier
bool barrierAtTop = false;

SmallVector<Block *> clusterBlocks;
SmallVector<Operation *> clusterOps;
SmallVector<bool> bars;
std::map<int, Operation *> existingBarrierMap;
Operation *terminatorOp;

for (auto &op : *forOp.getBody()) {
if (auto exeOp = dyn_cast<scf::ExecuteRegionOp>(op)) {
exeOp.setNoInline(false);
clusterOps.push_back(&op);
clusterBlocks.push_back(&exeOp->getRegion(0).front());
bars.push_back(false);
} else if (isa<ROCDL::BarrierOp, ROCDL::SBarrierOp,
triton::gpu::AsyncWaitOp>(op)) {
int currCluster = clusterBlocks.size();
if (existingBarrierMap.find(currCluster) != existingBarrierMap.end())
return; // FIXME: this is invalid. fail and cancel whole pass.

existingBarrierMap[currCluster] = &op;
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
terminatorOp = &op;
}
}

SmallVector<BlockInfo> clusterInfo;
for (auto cb : clusterBlocks)
clusterInfo.push_back(buildBlockInfoFromBlock(cb, allocation));

LDBG("cluster dependency analysis");
int numClusters = clusterInfo.size();
LDBG("total clusters : " << numClusters);

auto topBar = existingBarrierMap.find(0);
auto bottomBar = existingBarrierMap.find(numClusters);
if (bottomBar != existingBarrierMap.end()) {
if (topBar != existingBarrierMap.end())
return; // FIXME: unreachable
existingBarrierMap[0] = bottomBar->second;
existingBarrierMap.erase(bottomBar);
}

for (int j = 0; j < numClusters; j++) {
for (int i = 0; i < numClusters; i++) {
int next = (i + 2 + j) % numClusters;
int barLoc = (i + 1 + j) % numClusters;
int curr = (i + 1) % numClusters;
bool synced = false;
while (curr != i && curr != barLoc) {
if (bars[curr]) {
synced = true;
break;
}
curr = (curr + 1) % numClusters;
}
// also if next can already have a fence
if (bars[barLoc])
synced = true;

// synced between i and j, no need to check.
if (synced)
continue;

bool needFence =
clusterInfo[i].isIntersected(clusterInfo[next], nullptr);
if (needFence) {
// insert fence/barrier before this cluster
bars[barLoc] = true;
LDBG("cluster " << i << " need fence to " << next
<< " placing barrier at " << barLoc);
}
for (int i = 0; i < numClusters; i++)
LDBG("bars [" << i << "] = " << bars[i]);
}
}

for (int i = 0; i < numClusters; i++) {
if (auto exBar = existingBarrierMap.find(i);
exBar != existingBarrierMap.end()) {
if (bars[i]) {
auto exBarOp = exBar->second;
builder.setInsertionPointAfter(exBarOp);
emitClusterBarrier(builder, loc, true);
if (!isa<triton::gpu::AsyncWaitOp>(exBarOp))
exBarOp->erase();
} // else do nothing.
} else {
builder.setInsertionPoint(clusterOps[i]);
if (i == 0 && topBar == existingBarrierMap.end())
builder.setInsertionPoint(terminatorOp);
emitClusterBarrier(builder, loc, bars[i]);
}
}
}

public:
ConvertWarpPipeline() : ConvertWarpPipelineBase<ConvertWarpPipeline>() {}

void runOnOperation() override {

LDBG("cluster dependency analysis open");

ModuleOp m = getOperation();
OpBuilder builder(m);
ModuleAllocation moduleAllocation(m);

for (auto funcOp : m.getOps<mlir::triton::FuncOp>()) {
Allocation *allocation = moduleAllocation.getFuncData(funcOp);
funcOp.walk([&](scf::ForOp forOp) {
if (auto totalStages = forOp->getAttr("total_stages")) {
Location loc = forOp.getLoc();
emitPipelinedFor(builder, loc, forOp, allocation);
}
});
}
LDBG("cluster dependency analysis close");
}
};

} // namespace

namespace mlir::triton::AMD {

std::unique_ptr<OperationPass<ModuleOp>> createConvertWarpPipelinePass() {
return std::make_unique<ConvertWarpPipeline>();
}

} // namespace mlir::triton::AMD
3 changes: 2 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ struct ConvertTritonAMDGPUToLLVM

int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

auto &err = llvm::errs();
err << "membar run\n";
// Allocate shared memory and set barrier
ModuleAllocation allocation(mod);

Expand Down
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_triton_library(TritonAMDGPUTransforms
FoldTrueCmpIOp.cpp
UpdateAsyncWaitCount.cpp
Utility.cpp
WarpPipeliner.cpp

DEPENDS
TritonAMDGPUIR
Expand Down
Loading