5 changes: 0 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -98,11 +98,6 @@ class TargetInfoBase {
virtual bool supportStMatrix() const { return false; }
virtual bool isCuda() const { return false; }

// Annotate target specific information to local store operations during
// lowering to LLVM.
virtual void localStoreOpAnnotation(triton::gpu::LocalStoreOp op,
size_t localStoreOpCount,
Type type) const {}
// Annotate target specific information to local load operations during
// lowering to LLVM. `llLoadOp` is the generated LLVM load op.
virtual void localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp,
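Note: with `localStoreOpAnnotation` gone, `localLoadOpAnnotation` is the only per-op annotation hook left on `TargetInfoBase`. A minimal sketch of a target-side override — the subclass name and attribute string are hypothetical, and the type of the second parameter is assumed from the comment above:

```cpp
// Hypothetical override: attach target-specific metadata to the LLVM load
// generated for a triton::gpu::LocalLoadOp during lowering.
struct MyTargetInfo : public TargetInfoBase {
  void localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp,
                             Operation *llLoadOp) const override {
    llLoadOp->setAttr("my.sched.hint", UnitAttr::get(llLoadOp->getContext()));
  }
};
```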
11 changes: 6 additions & 5 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -548,11 +548,12 @@ SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
Location loc, RewriterBase &rewriter,
const TargetInfoBase &target);

-void storeDistributedToShared(
-    triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
-    ArrayRef<Value> srcVals, const SharedMemoryObject &smemObj, Location loc,
-    RewriterBase &rewriter, const TargetInfoBase &target,
-    std::pair<size_t, Type> *const llvmOpCount = nullptr);
+void storeDistributedToShared(triton::gpu::MemDescType dstTy,
+                              RankedTensorType srcTy, Type elemLlvmTy,
+                              ArrayRef<Value> srcVals,
+                              const SharedMemoryObject &smemObj, Location loc,
+                              RewriterBase &rewriter,
+                              const TargetInfoBase &target);

// Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
// We might want to merge them at some point, but having to support
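With the out-parameter gone, a call site reduces to the following — a sketch using the names from the lowering code in this PR:

```cpp
// Store unpacked register values into shared memory; the (count, type)
// out-parameter no longer exists.
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
                         targetInfo);
```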
19 changes: 8 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -15,18 +15,19 @@ using namespace mlir::triton::gpu;
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflict. Normally used for
// A/B operands of dots.
-void lowerDistributedToShared(
-    Location loc, Value src, Value dst, Value adaptorSrc,
-    const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
-    ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
-    std::pair<size_t, Type> *const llvmOpCount = nullptr) {
+void lowerDistributedToShared(Location loc, Value src, Value dst,
+                              Value adaptorSrc,
+                              const SharedMemoryObject &smemObj,
+                              const LLVMTypeConverter *typeConverter,
+                              ConversionPatternRewriter &rewriter,
+                              const TargetInfoBase &targetInfo) {
auto srcTy = cast<RankedTensorType>(src.getType());
auto dstTy = cast<MemDescType>(dst.getType());
auto elemTy = typeConverter->convertType(srcTy.getElementType());

auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
-                           targetInfo, llvmOpCount);
+                           targetInfo);
}

LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
@@ -245,20 +246,16 @@ struct LocalStoreOpConversion
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
llvmElemTy, rewriter);
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    std::pair<size_t, Type> llvmOpCount;
if (targetInfo.isCuda()) {
if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
typeConverter, rewriter, targetInfo))) {
return failure();
}
} else {
lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
-                               smemObj, typeConverter, rewriter, targetInfo,
-                               &llvmOpCount);
+                               smemObj, typeConverter, rewriter, targetInfo);
}

-    targetInfo.localStoreOpAnnotation(op, llvmOpCount.first,
-                                      llvmOpCount.second);
rewriter.eraseOp(op);
return success();
}
7 changes: 1 addition & 6 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -791,8 +791,7 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
ArrayRef<Value> srcVals,
const SharedMemoryObject &smemObj, Location loc,
RewriterBase &rewriter,
-                              const TargetInfoBase &target,
-                              std::pair<size_t, Type> *const llvmOpCount) {
+                              const TargetInfoBase &target) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
bool success = emitTransferBetweenRegistersAndShared(
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
@@ -807,10 +806,6 @@ void storeDistributedToShared(triton::gpu::MemDescType dstTy,
b.store(vec, vecAddr)
.setAlignment(vecTy.getNumElements() *
elemLlvmTy.getIntOrFloatBitWidth() / 8);
-        if (llvmOpCount) {
-          ++(llvmOpCount->first);
-          llvmOpCount->second = vecTy;
-        }
});

if (!success)
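The alignment set on each generated store above is the byte width of the vector transfer. A quick worked instance (values assumed for illustration):

```cpp
// Sketch of the alignment arithmetic from storeDistributedToShared: a
// vector<4xf16> store gets 4 * 16 / 8 = 8-byte alignment.
unsigned numElems = 4;   // vecTy.getNumElements()
unsigned bitWidth = 16;  // elemLlvmTy.getIntOrFloatBitWidth() for f16
unsigned alignBytes = numElems * bitWidth / 8;  // = 8
```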
96 changes: 0 additions & 96 deletions test/TritonGPU/amd/amd-instruction-sched.mlir
@@ -1,13 +1,7 @@
// RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1" -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=1" -triton-amdgpu-insert-instruction-sched-hints="variant=local_prefetch" -tritongpu-reduce-data-duplication -optimize-amd-lds-usage="target-arch=gfx942" -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm="arch=gfx942" -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
// RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1" -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=2" -triton-amdgpu-insert-instruction-sched-hints="variant=local_prefetch" -tritongpu-reduce-data-duplication -optimize-amd-lds-usage="target-arch=gfx942" -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm="arch=gfx942" -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
// RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritonamdgpu-accelerate-matmul="arch-generation-name=gfx942 matrix-instruction-size=16 kPack=1" -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=2" -triton-amdgpu-insert-instruction-sched-hints="variant=local_prefetch" -tritongpu-reduce-data-duplication -optimize-amd-lds-usage="target-arch=gfx942" -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm="arch=gfx942" -triton-amdgpu-lower-insert-instruction-sched-hints="arch=gfx942 num_stages=2" -debug-only="lower-insert-instruction-sched-hints" -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
// RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=1" | FileCheck %s -check-prefix=LABELING_PS_1
// RUN: triton-opt %s -convert-triton-to-tritongpu="target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64" -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline="num_stages=2" | FileCheck %s -check-prefix=LABELING_PS_2

module {
// INSTR_COUNT_NS1-LABEL: @test_dot_op
// INSTR_COUNT_NS2-LABEL: @test_dot_op
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: @test_dot_op
// LABELING_PS_1-LABEL: @test_dot_op
// LABELING_PS_2-LABEL: @test_dot_op
tt.func @test_dot_op(%lb : index, %ub : index, %step : index,
@@ -40,96 +34,6 @@ module {
%a = tt.load %a_ptr : tensor<128x32x!tt.ptr<f16>>
%b = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>>

// INSTR_COUNT_NS1: amdgpu.instruction_sched_hint
// INSTR_COUNT_NS1-SAME: isBufferLoadsAEnabled = false
// INSTR_COUNT_NS1-SAME: isBufferLoadsBEnabled = false
// INSTR_COUNT_NS1-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>>
// INSTR_COUNT_NS1-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>>
// INSTR_COUNT_NS1-SAME: numDsWritesA = #amdgpu.InstCounter<0, none>
// INSTR_COUNT_NS1-SAME: numDsWritesB = #amdgpu.InstCounter<0, none>
// INSTR_COUNT_NS1-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS1-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS1-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

// INSTR_COUNT_NS2: amdgpu.instruction_sched_hint
// INSTR_COUNT_NS2-SAME: isBufferLoadsAEnabled = false
// INSTR_COUNT_NS2-SAME: isBufferLoadsBEnabled = false
// INSTR_COUNT_NS2-SAME: numDsReadsA = #amdgpu.InstCounter<8, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numDsReadsB = #amdgpu.InstCounter<32, vector<1xf16>>
// INSTR_COUNT_NS2-SAME: numDsWritesA = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numDsWritesB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numGlobalLoadsA = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numGlobalLoadsB = #amdgpu.InstCounter<4, vector<4xf16>>
// INSTR_COUNT_NS2-SAME: numMMAs = #amdgpu.InstCounter<16, tensor<32x32x8xf16>>

// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD:.+]]
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE:512]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA:8]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ:32]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_WRITE]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[VMEM_READ]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ:256]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[DS_READ]], 2, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.group.barrier [[MFMA]], 1, 0
// USE_LOCAL_PREFETCH_GLOBAL_LOAD: rocdl.sched.barrier [[SCHED_GUARD]]


// LABELING_PS_1: scf.for
// LABELING_PS_1: %[[REG0_OP0:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<0>}
// LABELING_PS_1: %[[REG0_OP1:.+]] = tt.load {{.*}} {OpIdx = #amdgpu.OpIdx<1>}
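For context on the removed USE_LOCAL_PREFETCH_GLOBAL_LOAD checks: the mask operands they captured (512, 8, 32, 256) appear to match the AMDGPU backend's scheduling-group mask bits. A sketch of the assumed mapping:

```cpp
// Assumed rocdl.sched.group.barrier mask bits behind the removed checks.
enum SchedGroupMask : unsigned {
  MFMA = 1u << 3,       // 8:   matrix-core (MFMA) instructions
  VMEM_READ = 1u << 5,  // 32:  global/buffer loads
  DS_READ = 1u << 8,    // 256: LDS reads
  DS_WRITE = 1u << 9,   // 512: LDS writes
};
```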
8 changes: 0 additions & 8 deletions third_party/amd/backend/compiler.py
@@ -59,10 +59,6 @@ class HIPOptions:
#
# Current experimental scheduling variants:
#
# local-prefetch: implements instruction scheduling similar to the one from the ROCm Composable
# Kernel library. Note, this variant requires the use of buffer load/store ops
# and a special software pipelining style - i.e., 1x LDS and 1x register
# prefetch buffers for each GEMM tile.
# attention: enables a bunch of optimizations for attention kernels, including:
# - iglp 2 and sched.barrier around it
# - sink-insts-to-avoid-spills flag to avoid register spills
@@ -237,10 +233,6 @@ def make_ttgir(mod, metadata, options):
local_prefetch = knobs.amd.local_prefetch
use_async_copy = knobs.amd.use_async_copy

# The `local-prefetch` scheduling variant requires turning on buffer ops.
if options.schedule_hint == "local-prefetch":
global_prefetch = local_prefetch = 1

amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy)
if use_async_copy:
amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch)
@@ -254,28 +254,7 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
interleave for better instruction level parallelism.
}];

-  let arguments = (ins
-    TritonAMDGPU_SchedHintVariantAttr:$variant,
-    TritonAMDGPU_InstCounter:$numDsReadsA,
-    TritonAMDGPU_InstCounter:$numDsReadsB,
-    TritonAMDGPU_InstCounter:$numDsWritesA,
-    TritonAMDGPU_InstCounter:$numDsWritesB,
-    TritonAMDGPU_InstCounter:$numGlobalLoadsA,
-    TritonAMDGPU_InstCounter:$numGlobalLoadsB,
-    BoolAttr:$isBufferLoadsAEnabled,
-    BoolAttr:$isBufferLoadsBEnabled,
-    TritonAMDGPU_InstCounter:$numMMAs
-  );
-
-  let builders = [
-    OpBuilder<(ins "amdgpu::SchedHint":$variant), [{
-      auto ctx = $_state.getContext();
-      auto noneType = NoneType::get(ctx);
-      auto emptyAttr = amdgpu::InstCounterAttr::get(ctx, 0, noneType);
-      build($_builder, $_state, variant, emptyAttr, emptyAttr, emptyAttr, emptyAttr,
-            emptyAttr, emptyAttr, false, false, emptyAttr);
-    }]>
-  ];
+  let arguments = (ins TritonAMDGPU_SchedHintVariantAttr:$variant);

let assemblyFormat = [{ attr-dict }];
}
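With the counters and the custom builder removed, creating the hint supplies only the variant. A hedged sketch of a call site — the exact C++ namespaces and the enum value are assumptions:

```cpp
// Hypothetical call site: only the scheduling variant survives on the op.
auto variant = triton::amdgpu::SchedHintVariantAttr::get(
    rewriter.getContext(), triton::amdgpu::SchedHint::attention);
rewriter.create<triton::amdgpu::InstructionSchedHint>(loc, variant);
```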
@@ -400,10 +400,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,

for (auto op : tensor.getUsers()) {
if (auto localLoadOp = llvm::dyn_cast<triton::gpu::LocalLoadOp>(op)) {
-      const size_t numDsReadsCount =
-          repB * numRepNonK * numRepK * loadsPerThread;
-      setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy);

for (auto llLoad : llLoads) {
AMD::addLocalLoadNoAliasScope(localLoadOp, llLoad);
}
@@ -249,14 +249,6 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
}
}

-  for (auto op : tensor.getUsers()) {
-    if (auto localLoadOp = llvm::dyn_cast<triton::gpu::LocalLoadOp>(op)) {
-      const size_t numDsReadsCount =
-          repB * numRepNonK * numRepK * loadsPerThread;
-      setNumGeneratedDsReads(localLoadOp, numDsReadsCount, loadVecTy);
-    }
-  }

MLIRContext *ctx = wmmaLayout.getContext();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(loadedValues.size(), loadedValues[0].getType()));
4 changes: 0 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -237,10 +237,6 @@ struct DotOpMFMAConversionHelper {
ctx, SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

setNumGeneratedMMAs(op, mmaCount, maybeMfmaIntrinsic->mDim,
maybeMfmaIntrinsic->nDim, maybeMfmaIntrinsic->kDim,
elemtTy);

rewriter.replaceOp(op, res);
}

5 changes: 0 additions & 5 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
@@ -395,11 +395,6 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor,
wmmaLayout.getContext(), SmallVector<Type>(fc.size(), dstElemTy));
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);

const size_t mmaCount = numRepB * numRepM * numRepN * numRepK;
setNumGeneratedMMAs(op, mmaCount, maybeWmmaIntrinsic->mDim,
maybeWmmaIntrinsic->nDim, maybeWmmaIntrinsic->kDim,
aElemTy);

rewriter.replaceOp(op, res);
return success();
}