Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand Down Expand Up @@ -50,6 +51,146 @@ struct ConvertRawBufferCast final
}
};

/// Converts GatherToLDSOp when its memrefs change from sub-byte types
/// (e.g. f4E2M1FN) to byte-sized types (i8) during narrow type emulation.
/// The pattern linearizes multi-dimensional indices into the converted 1D
/// memref space and adjusts the transfer type accordingly.
struct ConvertGatherToLDS final : OpConversionPattern<amdgpu::GatherToLDSOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(amdgpu::GatherToLDSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto origSrcType = cast<MemRefType>(op.getSrc().getType());
    auto origDstType = cast<MemRefType>(op.getDst().getType());
    auto newSrcType = cast<MemRefType>(adaptor.getSrc().getType());
    auto newDstType = cast<MemRefType>(adaptor.getDst().getType());

    // Only convert sub-byte element types.
    if (origSrcType.getElementTypeBitWidth() >= 8 &&
        origDstType.getElementTypeBitWidth() >= 8) {
      return failure();
    }

    // If types didn't change, nothing to do.
    if (newSrcType == origSrcType && newDstType == origDstType) {
      return failure();
    }

    Location loc = op.getLoc();
    int origSrcBits = origSrcType.getElementTypeBitWidth();
    int newSrcBits = newSrcType.getElementTypeBitWidth();
    int origDstBits = origDstType.getElementTypeBitWidth();
    int newDstBits = newDstType.getElementTypeBitWidth();

    // Only convert when the transfer's total bits are a multiple of the
    // new element bit width. E.g. vector<3xf4E2M1FN> (12 bits) cannot be
    // cleanly packed into i8 elements. This check covers scalar transfer
    // types as well: a bare sub-byte scalar (e.g. f4E2M1FN, 4 bits) must be
    // rejected here rather than silently widened to a full i8 transfer.
    int64_t transferBits = getTotalBits(op.getTransferType());
    if (transferBits % newSrcBits != 0) {
      return rewriter.notifyMatchFailure(
          op,
          "transfer type bit-width is not a multiple of the new element "
          "bit width");
    }

    // Linearize source indices into a 1D byte-offset index.
    Value srcIdx = linearizeAndPack(rewriter, loc, op.getSrcIndices(),
                                    origSrcType, origSrcBits, newSrcBits);
    if (!srcIdx) {
      return rewriter.notifyMatchFailure(
          op, "failed to linearize source indices (dynamic or mismatched "
              "strides/offset, or invalid bit-width ratio)");
    }

    // Linearize destination indices.
    Value dstIdx = linearizeAndPack(rewriter, loc, op.getDstIndices(),
                                    origDstType, origDstBits, newDstBits);
    if (!dstIdx) {
      return rewriter.notifyMatchFailure(
          op, "failed to linearize destination indices (dynamic or mismatched "
              "strides/offset, or invalid bit-width ratio)");
    }

    // Adjust transfer type to use the new element type.
    Type newTransferType = convertTransferType(
        rewriter.getContext(), op.getTransferType(), origSrcBits, newSrcBits);

    auto newOp = amdgpu::GatherToLDSOp::create(
        rewriter, loc, adaptor.getSrc(), ValueRange{srcIdx}, adaptor.getDst(),
        ValueRange{dstIdx}, TypeAttr::get(newTransferType));
    if (op.getAsync()) {
      newOp.setAsync(true);
    }

    rewriter.eraseOp(op);
    return success();
  }

private:
  // Returns the total bit-width of a scalar or (1-D static) vector type.
  static int64_t getTotalBits(Type type) {
    if (auto vecType = dyn_cast<VectorType>(type)) {
      return vecType.getNumElements() * vecType.getElementTypeBitWidth();
    }
    return type.getIntOrFloatBitWidth();
  }

  // Linearizes multi-dimensional indices into a 1D index for the packed
  // byte-addressable memref.
  // linearIdx = offset + sum(idx[i] * stride[i])
  // packedIdx = linearIdx / (newBits / origBits)
  // Returns a null Value when the layout is dynamic, the index count does
  // not match the rank, or the bit-width ratio is not an integer > 1.
  static Value linearizeAndPack(ConversionPatternRewriter &rewriter,
                                Location loc, ValueRange indices,
                                MemRefType origType, int origBits,
                                int newBits) {
    auto [strides, offset] = origType.getStridesAndOffset();

    // Fail if the offset or any stride is dynamic.
    if (ShapedType::isDynamic(offset)) {
      return nullptr;
    }
    for (int64_t stride : strides) {
      if (ShapedType::isDynamic(stride)) {
        return nullptr;
      }
    }

    // Fail if the number of indices doesn't match the rank.
    if (indices.size() != strides.size()) {
      return nullptr;
    }

    // Linearize: offset + sum(idx[i] * stride[i]).
    Value linearIdx = arith::ConstantIndexOp::create(rewriter, loc, offset);
    for (auto [idx, stride] : llvm::zip(indices, strides)) {
      Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
      Value product = arith::MulIOp::create(rewriter, loc, idx, strideVal);
      linearIdx = arith::AddIOp::create(rewriter, loc, linearIdx, product);
    }

    // Pack: convert from origBits-element units to newBits-element units.
    // NOTE: this assumes the runtime linear index is aligned to the pack
    // ratio; a misaligned index would be truncated by the division.
    if (origBits != newBits) {
      if (newBits <= origBits || newBits % origBits != 0) {
        return nullptr;
      }
      int64_t packRatio = newBits / origBits;
      Value ratioVal = arith::ConstantIndexOp::create(rewriter, loc, packRatio);
      linearIdx = arith::DivUIOp::create(rewriter, loc, linearIdx, ratioVal);
    }

    return linearIdx;
  }

  // Converts the transfer type from sub-byte elements to byte-sized elements,
  // preserving the total transfer size in bits. The caller has already
  // checked that the total bits divide evenly by `newBits`.
  static Type convertTransferType(MLIRContext *context, Type origType,
                                  int origBits, int newBits) {
    // If the source element width didn't change, the transfer type is still
    // valid as-is; avoid gratuitous retyping (e.g. f16 -> i16).
    if (origBits == newBits) {
      return origType;
    }
    if (auto vecType = dyn_cast<VectorType>(origType)) {
      int64_t totalBits =
          vecType.getNumElements() * vecType.getElementTypeBitWidth();
      int64_t newElems = totalBits / newBits;
      return VectorType::get({newElems}, IntegerType::get(context, newBits));
    }
    // Scalar transfer type: preserve the total bit count. (Previously any
    // scalar collapsed to a single `newBits` integer regardless of width.)
    int64_t totalBits = getTotalBits(origType);
    if (totalBits == newBits) {
      return IntegerType::get(context, newBits);
    }
    return VectorType::get({totalBits / newBits},
                           IntegerType::get(context, newBits));
  }
};

struct AMDGPUEmulateNarrowTypePass final
: impl::AMDGPUEmulateNarrowTypePassBase<AMDGPUEmulateNarrowTypePass> {
void getDependentDialects(DialectRegistry &registry) const override {
Expand All @@ -67,8 +208,8 @@ struct AMDGPUEmulateNarrowTypePass final
};
target.addDynamicallyLegalDialect<amdgpu::AMDGPUDialect>(
opLegalCallback);
patterns.add<ConvertRawBufferCast>(typeConverter,
patterns.getContext());
patterns.add<ConvertRawBufferCast, ConvertGatherToLDS>(
typeConverter, patterns.getContext());
};
if (failed(emulateNarrowType(getOperation(), /*disableAtomic=*/true,
populateAMDGPUPatterns))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,108 @@ func.func @dim_resolution_with_vector_emulation(%size: index) {
// Verify vector operations are emulated to i8 (8xi4 -> 4xi8)
// CHECK: vector.load %{{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
// CHECK: vector.store %{{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>

// -----

// Test that gather_to_lds with sub-byte element types (i4) gets converted
// to use byte-sized elements (i8), with indices packed and transfer type
// adjusted to preserve the same number of transferred bits.
// 2-D i4 source: indices linearize as row * 32 + col, then pack by 2
// (two i4 per i8); transfer shrinks from 8 x i4 to 4 x i8.
func.func @gather_to_lds_i4_2d(
    %global: memref<128x32xi4, #amdgpu.address_space<fat_raw_buffer>>,
    %lds: memref<64xi4, #gpu.address_space<workgroup>>,
    %row: index, %col: index, %wr: index) {
  amdgpu.gather_to_lds %global[%row, %col], %lds[%wr]
      : vector<8xi4>,
        memref<128x32xi4, #amdgpu.address_space<fat_raw_buffer>>,
        memref<64xi4, #gpu.address_space<workgroup>>
  return
}

// CHECK-LABEL: func.func @gather_to_lds_i4_2d(
// CHECK-SAME: %[[SRC:.*]]: memref<2048xi8, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[DST:.*]]: memref<32xi8, #gpu.address_space<workgroup>>
// CHECK-SAME: %[[I0:.*]]: index, %[[I1:.*]]: index, %[[J0:.*]]: index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[MUL:.*]] = arith.muli %[[I0]], %[[C32]]
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[I1]]
// CHECK: %[[SRC_IDX:.*]] = arith.divui %[[ADD]], %[[C2]]
// CHECK: %[[DST_IDX:.*]] = arith.divui %[[J0]], %[[C2]]
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_IDX]]], %[[DST]][%[[DST_IDX]]]
// CHECK-SAME: : vector<4xi8>

// -----

// Test gather_to_lds with async attribute is preserved after conversion.
// 1-D i4 case with the `async` flag set; the flag must survive conversion.
func.func @gather_to_lds_i4_async(
    %global: memref<256xi4, #amdgpu.address_space<fat_raw_buffer>>,
    %lds: memref<64xi4, #gpu.address_space<workgroup>>,
    %rd: index, %wr: index) {
  amdgpu.gather_to_lds async %global[%rd], %lds[%wr]
      : vector<8xi4>,
        memref<256xi4, #amdgpu.address_space<fat_raw_buffer>>,
        memref<64xi4, #gpu.address_space<workgroup>>
  return
}

// CHECK-LABEL: func.func @gather_to_lds_i4_async(
// CHECK-SAME: %[[SRC:.*]]: memref<128xi8, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[DST:.*]]: memref<32xi8, #gpu.address_space<workgroup>>
// CHECK: amdgpu.gather_to_lds async
// CHECK-SAME: : vector<4xi8>

// -----

// Test gather_to_lds with f4E2M1FN sub-byte type gets converted
// (vector<8xf4E2M1FN> = 32 bits -> vector<4xi8>).
// Same shape as the i4 case but with the f4E2M1FN float type:
// vector<8xf4E2M1FN> is 32 bits, repacked as vector<4xi8>.
func.func @gather_to_lds_f4E2M1FN_2d(
    %global: memref<128x32xf4E2M1FN, #amdgpu.address_space<fat_raw_buffer>>,
    %lds: memref<64xf4E2M1FN, #gpu.address_space<workgroup>>,
    %row: index, %col: index, %wr: index) {
  amdgpu.gather_to_lds %global[%row, %col], %lds[%wr]
      : vector<8xf4E2M1FN>,
        memref<128x32xf4E2M1FN, #amdgpu.address_space<fat_raw_buffer>>,
        memref<64xf4E2M1FN, #gpu.address_space<workgroup>>
  return
}

// CHECK-LABEL: func.func @gather_to_lds_f4E2M1FN_2d(
// CHECK-SAME: %[[SRC:.*]]: memref<2048xi8, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[DST:.*]]: memref<32xi8, #gpu.address_space<workgroup>>
// CHECK-SAME: %[[I0:.*]]: index, %[[I1:.*]]: index, %[[J0:.*]]: index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[MUL:.*]] = arith.muli %[[I0]], %[[C32]]
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[I1]]
// CHECK: %[[SRC_IDX:.*]] = arith.divui %[[ADD]], %[[C2]]
// CHECK: %[[DST_IDX:.*]] = arith.divui %[[J0]], %[[C2]]
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_IDX]]], %[[DST]][%[[DST_IDX]]]
// CHECK-SAME: : vector<4xi8>

// -----

// Test gather_to_lds with i2 sub-byte type gets converted (vector<16xi2> =
// 32 bits -> vector<4xi8>).
// i2 elements pack four-to-a-byte: indices divide by 4 and the
// 16 x i2 transfer (32 bits) becomes 4 x i8.
func.func @gather_to_lds_i2_2d(
    %global: memref<128x64xi2, #amdgpu.address_space<fat_raw_buffer>>,
    %lds: memref<128xi2, #gpu.address_space<workgroup>>,
    %row: index, %col: index, %wr: index) {
  amdgpu.gather_to_lds %global[%row, %col], %lds[%wr]
      : vector<16xi2>,
        memref<128x64xi2, #amdgpu.address_space<fat_raw_buffer>>,
        memref<128xi2, #gpu.address_space<workgroup>>
  return
}

// CHECK-LABEL: func.func @gather_to_lds_i2_2d(
// CHECK-SAME: %[[SRC:.*]]: memref<2048xi8, #amdgpu.address_space<fat_raw_buffer>>
// CHECK-SAME: %[[DST:.*]]: memref<32xi8, #gpu.address_space<workgroup>>
// CHECK-SAME: %[[I0:.*]]: index, %[[I1:.*]]: index, %[[J0:.*]]: index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
// CHECK: %[[MUL:.*]] = arith.muli %[[I0]], %[[C64]]
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[I1]]
// CHECK: %[[SRC_IDX:.*]] = arith.divui %[[ADD]], %[[C4]]
// CHECK: %[[DST_IDX:.*]] = arith.divui %[[J0]], %[[C4]]
// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_IDX]]], %[[DST]][%[[DST_IDX]]]
// CHECK-SAME: : vector<4xi8>
Loading