diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUEmulateNarrowType.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUEmulateNarrowType.cpp
index 20353cb54fd9..a721aada736d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUEmulateNarrowType.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUEmulateNarrowType.cpp
@@ -10,6 +10,7 @@
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -50,6 +51,148 @@ struct ConvertRawBufferCast final
   }
 };
 
+/// Converts GatherToLDSOp when its memrefs change from sub-byte types
+/// (e.g. f4E2M1FN) to byte-sized types (i8) during narrow type emulation.
+/// The pattern linearizes multi-dimensional indices into the converted 1D
+/// memref space and adjusts the transfer type accordingly.
+struct ConvertGatherToLDS final : OpConversionPattern<amdgpu::GatherToLDSOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(amdgpu::GatherToLDSOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType origSrcType = op.getSrc().getType();
+    MemRefType origDstType = op.getDst().getType();
+    auto newSrcType = cast<MemRefType>(adaptor.getSrc().getType());
+    auto newDstType = cast<MemRefType>(adaptor.getDst().getType());
+
+    // Only convert sub-byte element types.
+    if (origSrcType.getElementTypeBitWidth() >= 8 &&
+        origDstType.getElementTypeBitWidth() >= 8) {
+      return failure();
+    }
+
+    // If types didn't change, nothing to do.
+    if (newSrcType == origSrcType && newDstType == origDstType) {
+      return failure();
+    }
+
+    Location loc = op.getLoc();
+    int origSrcBits = origSrcType.getElementTypeBitWidth();
+    int newSrcBits = newSrcType.getElementTypeBitWidth();
+    int origDstBits = origDstType.getElementTypeBitWidth();
+    int newDstBits = newDstType.getElementTypeBitWidth();
+
+    // Only convert when the transfer vector's total bits are a multiple of
+    // the new element bit width. E.g. vector<3xf4E2M1FN> (12 bits) cannot
+    // be cleanly packed into i8 elements.
+    if (auto vecType = dyn_cast<VectorType>(op.getTransferType())) {
+      int64_t totalBits =
+          vecType.getNumElements() * vecType.getElementTypeBitWidth();
+      if (totalBits % newSrcBits != 0) {
+        return rewriter.notifyMatchFailure(
+            op,
+            "transfer vector bit-width is not a multiple of the new element "
+            "bit width");
+      }
+    }
+
+    // Linearize source indices into a 1D byte-offset index.
+    Value srcIdx = linearizeAndPack(rewriter, loc, op.getSrcIndices(),
+                                    origSrcType, origSrcBits, newSrcBits);
+    if (!srcIdx) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to linearize source indices (dynamic or mismatched "
+              "strides/offset, or invalid bit-width ratio)");
+    }
+
+    // Linearize destination indices.
+    Value dstIdx = linearizeAndPack(rewriter, loc, op.getDstIndices(),
+                                    origDstType, origDstBits, newDstBits);
+    if (!dstIdx) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to linearize destination indices (dynamic or mismatched "
+              "strides/offset, or invalid bit-width ratio)");
+    }
+
+    // Adjust transfer type to use the new element type.
+    Type newTransferType = convertTransferType(
+        rewriter.getContext(), op.getTransferType(), origSrcBits, newSrcBits);
+
+    amdgpu::GatherToLDSOp::create(
+        rewriter, loc, adaptor.getSrc(), ValueRange{srcIdx}, adaptor.getDst(),
+        ValueRange{dstIdx}, TypeAttr::get(newTransferType), op.getAsyncAttr());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  // Linearizes multi-dimensional indices into a 1D index for the packed
+  // byte-addressable memref.
+  // linearIdx = offset + sum(idx[i] * stride[i])
+  // packedIdx = linearIdx / (newBits / origBits)
+  static Value linearizeAndPack(ConversionPatternRewriter &rewriter,
+                                Location loc, ValueRange indices,
+                                MemRefType origType, int origBits,
+                                int newBits) {
+    auto [strides, offset] = origType.getStridesAndOffset();
+
+    // Fail if the offset or any stride is dynamic. Dynamic offsets could be
+    // supported if guaranteed byte-aligned, but that hasn't been needed yet.
+    if (ShapedType::isDynamic(offset)) {
+      return nullptr;
+    }
+    for (int64_t stride : strides) {
+      if (ShapedType::isDynamic(stride)) {
+        return nullptr;
+      }
+    }
+
+    // Fail if the number of indices doesn't match the rank.
+    if (indices.size() != strides.size()) {
+      return nullptr;
+    }
+
+    // Linearize: offset + sum(idx[i] * stride[i]).
+    Value linearIdx = arith::ConstantIndexOp::create(rewriter, loc, offset);
+    for (auto [idx, stride] : llvm::zip(indices, strides)) {
+      Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
+      Value product = arith::MulIOp::create(rewriter, loc, idx, strideVal);
+      linearIdx = arith::AddIOp::create(rewriter, loc, linearIdx, product);
+    }
+
+    // Pack: convert from origBits-element units to newBits-element units.
+    if (origBits != newBits) {
+      if (newBits <= origBits || newBits % origBits != 0) {
+        return nullptr;
+      }
+      int64_t packRatio = newBits / origBits;
+      Value ratioVal = arith::ConstantIndexOp::create(rewriter, loc, packRatio);
+      linearIdx = arith::DivUIOp::create(rewriter, loc, linearIdx, ratioVal);
+    }
+
+    return linearIdx;
+  }
+
+  // Converts the transfer type from sub-byte elements to byte-sized elements,
+  // preserving the total transfer size in bits. The caller must ensure
+  // totalBits is a multiple of newBits (the op verifier enforces that
+  // transfer sizes are 8, 16, 32, 96, or 128 bits, all multiples of 8).
+  static Type convertTransferType(MLIRContext *context, Type origType,
+                                  int origBits, int newBits) {
+    if (auto vecType = dyn_cast<VectorType>(origType)) {
+      int64_t totalBits =
+          vecType.getNumElements() * vecType.getElementTypeBitWidth();
+      assert(totalBits % newBits == 0 &&
+             "transfer size must be a multiple of the new element bit width");
+      int64_t newElems = totalBits / newBits;
+      return VectorType::get({newElems}, IntegerType::get(context, newBits));
+    }
+    return IntegerType::get(context, newBits);
+  }
+};
+
 struct AMDGPUEmulateNarrowTypePass final
     : impl::AMDGPUEmulateNarrowTypePassBase<AMDGPUEmulateNarrowTypePass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -67,8 +210,8 @@ struct AMDGPUEmulateNarrowTypePass final
       };
       target.addDynamicallyLegalDialect<amdgpu::AMDGPUDialect>(
          opLegalCallback);
-      patterns.add<ConvertRawBufferCast>(typeConverter,
-                                         patterns.getContext());
+      patterns.add<ConvertRawBufferCast, ConvertGatherToLDS>(
+          typeConverter, patterns.getContext());
    };
    if (failed(emulateNarrowType(getOperation(), /*disableAtomic=*/true,
                                 populateAMDGPUPatterns))) {
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_emulate_narrow_type.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_emulate_narrow_type.mlir
index 0beb8faaa1dd..835339f9493d 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_emulate_narrow_type.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_emulate_narrow_type.mlir
@@ -42,3 +42,108 @@ func.func @dim_resolution_with_vector_emulation(%size: index) {
 // Verify vector operations are emulated to i8 (8xi4 -> 4xi8)
 // CHECK: vector.load %{{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
 // CHECK: vector.store %{{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
+
+// -----
+
+// Test that gather_to_lds with sub-byte element types (i4) gets converted
+// to use byte-sized elements (i8), with indices packed and transfer type
+// adjusted to preserve the same number of transferred bits.
+func.func @gather_to_lds_i4_2d(
+    %src: memref<128x32xi4, #amdgpu.address_space<fat_raw_buffer>>,
+    %dst: memref<64xi4, #gpu.address_space<workgroup>>,
+    %i0: index, %i1: index, %j0: index) {
+  amdgpu.gather_to_lds %src[%i0, %i1], %dst[%j0]
+    : vector<8xi4>,
+      memref<128x32xi4, #amdgpu.address_space<fat_raw_buffer>>,
+      memref<64xi4, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-LABEL: func.func @gather_to_lds_i4_2d(
+// CHECK-SAME: %[[SRC:.*]]: memref<2048xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[DST:.*]]: memref<32xi8, #gpu.address_space<workgroup>>
+// CHECK-SAME: %[[I0:.*]]: index, %[[I1:.*]]: index, %[[J0:.*]]: index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: %[[MUL:.*]] = arith.muli %[[I0]], %[[C32]]
+// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[I1]]
+// CHECK: %[[SRC_IDX:.*]] = arith.divui %[[ADD]], %[[C2]]
+// CHECK: %[[DST_IDX:.*]] = arith.divui %[[J0]], %[[C2]]
+// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_IDX]]], %[[DST]][%[[DST_IDX]]]
+// CHECK-SAME: : vector<4xi8>
+
+// -----
+
+// Test gather_to_lds with async attribute is preserved after conversion.
+func.func @gather_to_lds_i4_async(
+    %src: memref<256xi4, #amdgpu.address_space<fat_raw_buffer>>,
+    %dst: memref<64xi4, #gpu.address_space<workgroup>>,
+    %idx: index, %jdx: index) {
+  amdgpu.gather_to_lds async %src[%idx], %dst[%jdx]
+    : vector<8xi4>,
+      memref<256xi4, #amdgpu.address_space<fat_raw_buffer>>,
+      memref<64xi4, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-LABEL: func.func @gather_to_lds_i4_async(
+// CHECK-SAME: %[[SRC:.*]]: memref<128xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[DST:.*]]: memref<32xi8, #gpu.address_space<workgroup>>
+// CHECK: amdgpu.gather_to_lds async
+// CHECK-SAME: : vector<4xi8>
+
+// -----
+
+// Test gather_to_lds with f4E2M1FN sub-byte type gets converted
+// (vector<8xf4E2M1FN> = 32 bits -> vector<4xi8>).
+func.func @gather_to_lds_f4E2M1FN_2d(
+    %src: memref<128x32xf4E2M1FN, #amdgpu.address_space<fat_raw_buffer>>,
+    %dst: memref<64xf4E2M1FN, #gpu.address_space<workgroup>>,
+    %i0: index, %i1: index, %j0: index) {
+  amdgpu.gather_to_lds %src[%i0, %i1], %dst[%j0]
+    : vector<8xf4E2M1FN>,
+      memref<128x32xf4E2M1FN, #amdgpu.address_space<fat_raw_buffer>>,
+      memref<64xf4E2M1FN, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-LABEL: func.func @gather_to_lds_f4E2M1FN_2d(
+// CHECK-SAME: %[[SRC:.*]]: memref<2048xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[DST:.*]]: memref<32xi8, #gpu.address_space<workgroup>>
+// CHECK-SAME: %[[I0:.*]]: index, %[[I1:.*]]: index, %[[J0:.*]]: index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: %[[MUL:.*]] = arith.muli %[[I0]], %[[C32]]
+// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[I1]]
+// CHECK: %[[SRC_IDX:.*]] = arith.divui %[[ADD]], %[[C2]]
+// CHECK: %[[DST_IDX:.*]] = arith.divui %[[J0]], %[[C2]]
+// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_IDX]]], %[[DST]][%[[DST_IDX]]]
+// CHECK-SAME: : vector<4xi8>
+
+// -----
+
+// Test gather_to_lds with i2 sub-byte type gets converted (vector<16xi2> =
+// 32 bits -> vector<4xi8>).
+func.func @gather_to_lds_i2_2d(
+    %src: memref<128x64xi2, #amdgpu.address_space<fat_raw_buffer>>,
+    %dst: memref<128xi2, #gpu.address_space<workgroup>>,
+    %i0: index, %i1: index, %j0: index) {
+  amdgpu.gather_to_lds %src[%i0, %i1], %dst[%j0]
+    : vector<16xi2>,
+      memref<128x64xi2, #amdgpu.address_space<fat_raw_buffer>>,
+      memref<128xi2, #gpu.address_space<workgroup>>
+  return
+}
+
+// CHECK-LABEL: func.func @gather_to_lds_i2_2d(
+// CHECK-SAME: %[[SRC:.*]]: memref<2048xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[DST:.*]]: memref<32xi8, #gpu.address_space<workgroup>>
+// CHECK-SAME: %[[I0:.*]]: index, %[[I1:.*]]: index, %[[J0:.*]]: index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+// CHECK: %[[MUL:.*]] = arith.muli %[[I0]], %[[C64]]
+// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[I1]]
+// CHECK: %[[SRC_IDX:.*]] = arith.divui %[[ADD]], %[[C4]]
+// CHECK: %[[DST_IDX:.*]] = arith.divui %[[J0]], %[[C4]]
+// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SRC_IDX]]], %[[DST]][%[[DST_IDX]]]
+// CHECK-SAME: : vector<4xi8>