diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUEmulateNarrowType.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUEmulateNarrowType.cpp index 6db75f623fb2..dea61960ff60 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUEmulateNarrowType.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/AMDGPUEmulateNarrowType.cpp @@ -91,7 +91,9 @@ struct ConvertGatherToLDS final : OpConversionPattern { vecType.getNumElements() * vecType.getElementTypeBitWidth(); if (totalBits % newSrcBits != 0) { return rewriter.notifyMatchFailure( - op, "transfer vector bit-width is not a multiple of byte width"); + op, + "transfer vector bit-width is not a multiple of the new element " + "bit width"); } } @@ -100,7 +102,8 @@ struct ConvertGatherToLDS final : OpConversionPattern { origSrcType, origSrcBits, newSrcBits); if (!srcIdx) { return rewriter.notifyMatchFailure( - op, "failed to linearize source indices (dynamic strides)"); + op, "failed to linearize source indices (dynamic or mismatched " + "strides/offset, or invalid bit-width ratio)"); } // Linearize destination indices. @@ -108,7 +111,8 @@ struct ConvertGatherToLDS final : OpConversionPattern { origDstType, origDstBits, newDstBits); if (!dstIdx) { return rewriter.notifyMatchFailure( - op, "failed to linearize destination indices (dynamic strides)"); + op, "failed to linearize destination indices (dynamic or mismatched " + "strides/offset, or invalid bit-width ratio)"); } // Adjust transfer type to use the new element type. @@ -129,29 +133,31 @@ struct ConvertGatherToLDS final : OpConversionPattern { private: // Linearizes multi-dimensional indices into a 1D index for the packed // byte-addressable memref. - // linearIdx = sum(idx[i] * stride[i]) + // linearIdx = offset + sum(idx[i] * stride[i]) // packedIdx = linearIdx * origBits / newBits static Value linearizeAndPack(ConversionPatternRewriter &rewriter, Location loc, ValueRange indices, MemRefType origType, int origBits, int newBits) { - if (origBits == newBits) { - // No packing needed; if also 1D, just pass through. - if (indices.size() == 1) { - return indices.front(); - } - } - auto [strides, offset] = origType.getStridesAndOffset(); + // Fail if the offset or any stride is dynamic. + if (ShapedType::isDynamic(offset)) { + return nullptr; + } for (int64_t stride : strides) { if (ShapedType::isDynamic(stride)) { return nullptr; } } - // Linearize: sum(idx[i] * stride[i]). - Value linearIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); + // Fail if the number of indices doesn't match the rank. + if (indices.size() != strides.size()) { + return nullptr; + } + + // Linearize: offset + sum(idx[i] * stride[i]). + Value linearIdx = arith::ConstantIndexOp::create(rewriter, loc, offset); for (auto [idx, stride] : llvm::zip(indices, strides)) { Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride); Value product = arith::MulIOp::create(rewriter, loc, idx, strideVal); @@ -160,7 +166,9 @@ struct ConvertGatherToLDS final : OpConversionPattern { // Pack: convert from origBits-element units to newBits-element units. if (origBits != newBits) { - assert(newBits > origBits && newBits % origBits == 0); + if (newBits <= origBits || newBits % origBits != 0) { + return nullptr; + } int64_t packRatio = newBits / origBits; Value ratioVal = arith::ConstantIndexOp::create(rewriter, loc, packRatio); linearIdx = arith::DivUIOp::create(rewriter, loc, linearIdx, ratioVal);