Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ struct ConvertGatherToLDS final : OpConversionPattern<amdgpu::GatherToLDSOp> {
vecType.getNumElements() * vecType.getElementTypeBitWidth();
if (totalBits % newSrcBits != 0) {
return rewriter.notifyMatchFailure(
op, "transfer vector bit-width is not a multiple of byte width");
op,
"transfer vector bit-width is not a multiple of the new element "
"bit width");
}
}

Expand All @@ -100,15 +102,17 @@ struct ConvertGatherToLDS final : OpConversionPattern<amdgpu::GatherToLDSOp> {
origSrcType, origSrcBits, newSrcBits);
if (!srcIdx) {
return rewriter.notifyMatchFailure(
op, "failed to linearize source indices (dynamic strides)");
op, "failed to linearize source indices (dynamic or mismatched "
"strides/offset, or invalid bit-width ratio)");
}

// Linearize destination indices.
Value dstIdx = linearizeAndPack(rewriter, loc, op.getDstIndices(),
origDstType, origDstBits, newDstBits);
if (!dstIdx) {
return rewriter.notifyMatchFailure(
op, "failed to linearize destination indices (dynamic strides)");
op, "failed to linearize destination indices (dynamic or mismatched "
"strides/offset, or invalid bit-width ratio)");
}

// Adjust transfer type to use the new element type.
Expand All @@ -129,29 +133,31 @@ struct ConvertGatherToLDS final : OpConversionPattern<amdgpu::GatherToLDSOp> {
private:
// Linearizes multi-dimensional indices into a 1D index for the packed
// byte-addressable memref.
// linearIdx = sum(idx[i] * stride[i])
// linearIdx = offset + sum(idx[i] * stride[i])
// packedIdx = linearIdx * origBits / newBits
static Value linearizeAndPack(ConversionPatternRewriter &rewriter,
Location loc, ValueRange indices,
MemRefType origType, int origBits,
int newBits) {
if (origBits == newBits) {
// No packing needed; if also 1D, just pass through.
if (indices.size() == 1) {
return indices.front();
}
}

auto [strides, offset] = origType.getStridesAndOffset();

// Fail if the offset or any stride is dynamic.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... Why is this a failure?

if (ShapedType::isDynamic(offset)) {
return nullptr;
}
for (int64_t stride : strides) {
if (ShapedType::isDynamic(stride)) {
return nullptr;
}
}

// Linearize: sum(idx[i] * stride[i]).
Value linearIdx = arith::ConstantIndexOp::create(rewriter, loc, 0);
// Fail if the number of indices doesn't match the rank.
if (indices.size() != strides.size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can this happen?

return nullptr;
}

// Linearize: offset + sum(idx[i] * stride[i]).
Value linearIdx = arith::ConstantIndexOp::create(rewriter, loc, offset);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there a utility function for this these days?

for (auto [idx, stride] : llvm::zip(indices, strides)) {
Value strideVal = arith::ConstantIndexOp::create(rewriter, loc, stride);
Value product = arith::MulIOp::create(rewriter, loc, idx, strideVal);
Expand All @@ -160,7 +166,9 @@ struct ConvertGatherToLDS final : OpConversionPattern<amdgpu::GatherToLDSOp> {

// Pack: convert from origBits-element units to newBits-element units.
if (origBits != newBits) {
assert(newBits > origBits && newBits % origBits == 0);
if (newBits <= origBits || newBits % origBits != 0) {
return nullptr;
}
int64_t packRatio = newBits / origBits;
Value ratioVal = arith::ConstantIndexOp::create(rewriter, loc, packRatio);
linearIdx = arith::DivUIOp::create(rewriter, loc, linearIdx, ratioVal);
Expand Down