
Commit c2d3513

knwngjwu10003 authored and committed
[AMD] Implement Scale Preshuffling and opSel on GFX1250 (triton-lang#8576)
Following triton-lang#7603, this PR implements scale preshuffling on gfx1250 for efficient memory access and better WMMA codegen with `opSel`. As an example, in an mxfp GEMM kernel with `BLOCK_M x BLOCK_N x BLOCK_K`, scaleA's shape is `BLOCK_M x (BLOCK_K // 32)`. We preshuffle it to `(BLOCK_M // 128) x (BLOCK_K x 4)` outside the kernel for better vectorization, and 'unshuffle' it inside the kernel to obtain the canonical input to the `wmma_scaled` op. The same applies to scaleB. Besides, the 16x16x128 scaled WMMA instruction reads scales only from the first 16 lanes in a wave, which wastes read capacity. We therefore use `opSel` to control whether a WMMA instruction reads scales from the first or last 16 lanes of a wave, so that all lanes in a wave are used to read scales. To correctly issue WMMA instructions with `opSel`, we need to group 2 consecutive WMMA instruction tiles in a wave. This is done by introducing `tilesPerWarp` to `AMDWmmaEncodingAttr`, to avoid composing linear layouts in Gluon kernels all the time. This PR also adds support for inferring a padded shared layout for `MemDescReshapeOp`, because in the case of async/tensor loads we need to do the 'unshuffling' on a memory subview.
1 parent a9d7f2b commit c2d3513
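The shape bookkeeping in the commit message can be sanity-checked with a small NumPy sketch. This is only an illustration under stated assumptions: the variable names are hypothetical, and a plain `reshape` stands in for the actual element permutation, which the commit message does not fully specify. It shows that the canonical and preshuffled shapes hold the same number of scale values, so the in-kernel 'unshuffle' can recover the canonical layout.

```python
import numpy as np

# Hypothetical sketch of the scale-preshuffle shape bookkeeping.
# The real shuffle permutes elements for vectorized loads; here an
# identity permutation (plain reshape) stands in for it.
BLOCK_M, BLOCK_K = 256, 256

canonical = (BLOCK_M, BLOCK_K // 32)         # scaleA as wmma_scaled expects
preshuffled = (BLOCK_M // 128, BLOCK_K * 4)  # layout stored in memory

# Both shapes cover the same number of scale values (128 / 32 == 4).
assert canonical[0] * canonical[1] == preshuffled[0] * preshuffled[1]

scale = np.arange(np.prod(canonical), dtype=np.uint8).reshape(canonical)
stored = scale.reshape(preshuffled)          # "preshuffle" outside the kernel
recovered = stored.reshape(canonical)        # "unshuffle" inside the kernel
assert np.array_equal(recovered, scale)
```

With `BLOCK_M = BLOCK_K = 256`, the canonical `(256, 8)` tensor round-trips through the stored `(2, 1024)` shape without losing elements.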

File tree

12 files changed: +432 −107 lines changed


include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 4 additions & 2 deletions
@@ -125,8 +125,10 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<unsigned> warpsPerCTA);
 
 LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
-                                         ArrayRef<unsigned> warpsPerCTA,
-                                         ArrayRef<int64_t> dotOperandShape);
+                                         ArrayRef<int64_t> dotOperandShape,
+                                         unsigned wmmaMDim,
+                                         ArrayRef<unsigned> tilesPerWarp,
+                                         ArrayRef<unsigned> warpsPerCTA);
 
 LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx,
                                           ArrayRef<int64_t> shape, int opIdx,

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 45 additions & 2 deletions
@@ -1133,7 +1133,7 @@ Example 4:
 This example demonstrates semantics of tilesPerWarp parameter. The MFMA layout (with tilesPerWarp=[1,1])
 assumes that each warp within a CTA tile computes a single MFMA tile. When the tensor is larger than
 a single CTA tile, these tiles are repeated across the tensor. In this setup, the output tiles computed
-by each wave were strided by the number of warps per CTA tile in both row and column dimensions.
+by each warp were strided by the number of warps per CTA tile in both row and column dimensions.
 
 For instance, with 16 MFMA tiles and warpsPerCTA = [2, 2], the distribution of warps across the MFMA
 tiles looked like:
@@ -1214,11 +1214,12 @@ It is characterized by the following parameters:
   - 2: RDNA4; e.g., gfx1200, gfx1201
   - 3: gfx1250
 - `warpsPerCTA` indicates the warp layout in the block.
+- `tilesPerWarp` indicates the tile layout within a warp. Defaults to the unit tile layout, i.e., a single tile on all dimensions.
 - `instrShape` indicates the shape in the form of (M, N, K) of the matrix
   operation performed by a single WMMA instruction. Defaults to (16, 16, 16).
 - `isTransposed` indicates the layout of the result tensor is transposed.
 
-Example:
+Example 1:
 Suppose we have a tensor with shape [32, 64], `warpsPerCTA` set to [2, 2].
 Matrix elements represent which lane owns the element. Currently only wave32 mode
 is supported.
@@ -1292,20 +1293,59 @@ Row |
 .. | ...                                ...
 30 |[14 14 14 14 14 14 14 14 30 ... 30] [14 14 14 ... 30]
 31 |[15 15 15 15 15 15 15 15 31 ... 31] [15 15 15 ... 31]
+
+Example 2:
+This example demonstrates the tilesPerWarp parameter, which shares the same semantics with
+AMDMfmaEncodingAttr.
+
+By default, the WMMA layout assumes that each warp within a CTA tile computes a single WMMA tile.
+When the tensor is larger than a single CTA tile, these tiles are repeated across the tensor.
+In this setup, the output tiles computed by each warp are strided by the number of warps per CTA
+tile in both row and column dimensions.
+
+For instance, with 16 WMMA tiles and warpsPerCTA = [2, 2], the default (tilesPerWarp = [1, 1])
+distribution of warps across the WMMA tiles looks like:
+
+  w0 w1 w0 w1
+  w2 w3 w2 w3
+  w0 w1 w0 w1
+  w2 w3 w2 w3
+
+* Each unit represents a WMMA tile. w* shows which warp occupies that WMMA tile.
+
+The tilesPerWarp parameter allows each warp to compute contiguous WMMA tiles in the row and/or column dimensions.
+Using the same example with tilesPerWarp = [2, 2], the layout becomes:
+
+  w0 w0 w1 w1
+  w0 w0 w1 w1
+  w2 w2 w3 w3
+  w2 w2 w3 w3
 }];
 
 let parameters = (
   ins
   "unsigned": $version,
   "bool":$isTransposed,
   ArrayRefParameter<"unsigned">:$warpsPerCTA,
+  ArrayRefParameter<"unsigned">:$tilesPerWarp,
   "CTALayoutAttr":$CTALayout,
   ArrayRefParameter<"unsigned">:$instrShape
 );
 
 let genVerifyDecl = 1;
 let hasCustomAssemblyFormat = 1;
 
+let builders = [
+  AttrBuilder<(ins "unsigned":$version,
+                   "bool":$isTransposed,
+                   "ArrayRef<unsigned>":$warpsPerCTA,
+                   "CTALayoutAttr":$CTALayout,
+                   "ArrayRef<unsigned>":$instrShape), [{
+    SmallVector<unsigned> tilesPerWarp(warpsPerCTA.size(), 1);
+    return $_get(context, version, isTransposed, warpsPerCTA, tilesPerWarp, CTALayout, instrShape);
+  }]>
+];
+
 let extraClassDeclaration = extraDistributedDeclaration # [{
   SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kDim, int opIdx) const;
   SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
@@ -1314,6 +1354,9 @@ Row |
     return {16, 16, 16};
   }
 
+  // Check if tilesPerWarp is 1 in every dimension.
+  bool hasUnitTilesPerWarp() const;
+
   // Returns a swizzled shared layout matching this WMMA layout for the
   // dot operand at the given |operandIdx| with |operandShape|.
   SwizzledSharedEncodingAttr composeSharedLayoutForOperand(
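The warp grids in the Example 2 docstring above can be reproduced with a short sketch. The helper below is hypothetical (it is not part of the patch) and assumes a row-major warp order; it maps each WMMA tile in a 4x4 tile grid to the warp that computes it, for warpsPerCTA = [2, 2].

```python
# Hypothetical helper: which warp owns the WMMA tile at (row, col),
# assuming row-major warp order within the CTA.
def warp_for_tile(row, col, warps_per_cta, tiles_per_warp):
    wm, wn = warps_per_cta
    tm, tn = tiles_per_warp
    # Each warp owns a contiguous tm x tn block of tiles; the warp grid
    # then repeats over the remaining tiles of the tensor.
    return ((row // tm) % wm) * wn + (col // tn) % wn

def grid(tiles_per_warp):
    return [[warp_for_tile(r, c, (2, 2), tiles_per_warp) for c in range(4)]
            for r in range(4)]

# Default tilesPerWarp = [1, 1]: warps are strided tile by tile.
assert grid((1, 1)) == [[0, 1, 0, 1],
                        [2, 3, 2, 3],
                        [0, 1, 0, 1],
                        [2, 3, 2, 3]]

# tilesPerWarp = [2, 2]: each warp takes a contiguous 2x2 block of tiles.
assert grid((2, 2)) == [[0, 0, 1, 1],
                        [0, 0, 1, 1],
                        [2, 2, 3, 3],
                        [2, 2, 3, 3]]
```

Both assertions match the w0..w3 grids in the docstring, which is the grouping the opSel scheme relies on along the non-K dimension.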

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 31 additions & 8 deletions
@@ -1283,6 +1283,9 @@ LogicalResult AMDMfmaEncodingAttr::verify(
 //===----------------------------------------------------------------------===//
 // WMMA encoding
 //===----------------------------------------------------------------------===//
+bool AMDWmmaEncodingAttr::hasUnitTilesPerWarp() const {
+  return llvm::all_of(getTilesPerWarp(), [](int x) { return x == 1; });
+}
 
 Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
@@ -1299,6 +1302,7 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
   std::optional<SmallVector<unsigned>> CTAsPerCGA;
   std::optional<SmallVector<unsigned>> CTASplitNum;
   std::optional<SmallVector<unsigned>> CTAOrder;
+  SmallVector<unsigned> tilesPerWarp = {};
   SmallVector<unsigned> instrShape = getDefaultInstrShape();
 
   for (const NamedAttribute &attr : dict) {
@@ -1314,6 +1318,11 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
       if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
         return {};
     }
+    if (attr.getName() == "tilesPerWarp") {
+      if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp")
+              .failed())
+        return {};
+    }
     if (attr.getName() == "CTAsPerCGA") {
       if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA")
               .failed())
@@ -1342,9 +1351,12 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
   if (!CTALayout.has_value())
     return {};
 
-  return parser.getChecked<AMDWmmaEncodingAttr>(parser.getContext(), version,
-                                                isTransposed, warpsPerCTA,
-                                                *CTALayout, instrShape);
+  if (tilesPerWarp.empty())
+    tilesPerWarp = SmallVector<unsigned>(instrShape.size(), 1);
+
+  return parser.getChecked<AMDWmmaEncodingAttr>(
+      parser.getContext(), version, isTransposed, warpsPerCTA, tilesPerWarp,
+      *CTALayout, instrShape);
 }
 
 void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const {
@@ -1356,6 +1368,10 @@ void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const {
   maybePrintCTALayout(getContext(), printer, getCTALayout(),
                       /*rank=*/getWarpsPerCTA().size());
 
+  auto tilesPerWarp = getTilesPerWarp();
+  if (!hasUnitTilesPerWarp())
+    printer << ", tilesPerWarp = [" << getTilesPerWarp() << "]";
+
   if (getInstrShape() != ArrayRef(getDefaultInstrShape())) {
     printer << ", instrShape = [" << getInstrShape() << "]";
   }
@@ -1365,7 +1381,8 @@ void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const {
 LogicalResult AMDWmmaEncodingAttr::verify(
     function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
     bool isTransposed, llvm::ArrayRef<unsigned int> warpsPerCTA,
-    CTALayoutAttr ctaLayout, llvm::ArrayRef<unsigned> instrShape) {
+    llvm::ArrayRef<unsigned int> tilesPerWarp, CTALayoutAttr ctaLayout,
+    llvm::ArrayRef<unsigned> instrShape) {
   if (!(version >= 1 && version <= 3))
     return emitError() << "WMMA version must be in the [1, 3] range";
 
@@ -2172,7 +2189,7 @@ void AMDRotatingSharedEncodingAttr::print(AsmPrinter &printer) const {
 // TODO: there is a lot of common code with MmaEncoding here
 
 bool AMDMfmaEncodingAttr::hasUnitTilesPerWarp() const {
-  return !llvm::any_of(getTilesPerWarp(), [](int x) { return x != 1; });
+  return llvm::all_of(getTilesPerWarp(), [](int x) { return x == 1; });
 }
 
 SmallVector<int64_t>
@@ -2305,6 +2322,8 @@ AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape, int kDim,
 
   assert(operandTileShape.size() == 2);
   auto warpsPerCTA = getWarpsPerCTA();
+  auto tilesPerWarp = getTilesPerWarp();
+
   auto rank = operandShape.size();
   assert(rank == 2 || rank == 3);
   int numRepBatch =
@@ -2313,15 +2332,19 @@ AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape, int kDim,
     return {
         numRepBatch,
         std::max<int64_t>(1, operandShape[rank - 2] /
-                                 (operandTileShape[0] * warpsPerCTA[rank - 2])),
+                                 (operandTileShape[0] * tilesPerWarp[rank - 2] *
+                                  warpsPerCTA[rank - 2])) *
+            tilesPerWarp[rank - 2],
         std::max<int64_t>(1, operandShape[rank - 1] / operandTileShape[1])};
   else {
     assert(opIdx == 1);
     return {
        numRepBatch,
        std::max<int64_t>(1, operandShape[rank - 2] / operandTileShape[0]),
-       std::max<int64_t>(1, operandShape[rank - 1] / (operandTileShape[1] *
-                                                      warpsPerCTA[rank - 1]))};
+       std::max<int64_t>(1, operandShape[rank - 1] /
+                                (operandTileShape[1] * tilesPerWarp[rank - 1] *
+                                 warpsPerCTA[rank - 1])) *
+           tilesPerWarp[rank - 1]};
  }
 }
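The updated `getRepForOperand` arithmetic can be mirrored in a few lines. The sketch below is hypothetical (a Python restatement, not the patch itself) and covers the operand-A (opIdx 0) M-dimension case: the CTA-level repeat count now divides by `tilesPerWarp` as well, then multiplies it back so each warp still enumerates every tile it owns.

```python
# Hypothetical restatement of the opIdx == 0 M-dimension rep count from
# AMDWmmaEncodingAttr::getRepForOperand after this change.
def reps_m(shape_m, tile_m, tiles_per_warp_m, warps_per_cta_m):
    # Number of CTA-tile repeats along M (clamped to at least 1) ...
    cta_repeats = max(1, shape_m // (tile_m * tiles_per_warp_m * warps_per_cta_m))
    # ... times the contiguous tiles each warp owns along M.
    return cta_repeats * tiles_per_warp_m

# 128 rows, 16-row WMMA tiles, 2 warps along M:
assert reps_m(128, 16, 1, 2) == 4  # old behaviour: 128 / (16 * 2)
assert reps_m(128, 16, 2, 2) == 4  # same total reps, grouped in pairs
assert reps_m(256, 16, 2, 2) == 8
```

The total rep count is unchanged for shapes that divide evenly; what changes is how those reps are grouped, which is what the opSel pairing needs.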

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 84 additions & 30 deletions
@@ -772,6 +772,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
 
   StringAttr kRegister = S("register");
   StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
 
   // https://github.com/ROCm/amd_matrix_instruction_calculator can print the
   // register and lane layout for mfma instructions.
@@ -814,30 +815,54 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
           {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 8}}}},
          {outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]});
 
+  auto tilesPerWarp = getTilesPerWarp();
+  auto warpsPerCTA = getWarpsPerCTA();
+
+  const unsigned tilesPerWarpM = tilesPerWarp[mIndex];
+  const unsigned tilesPerWarpN = tilesPerWarp[nIndex];
+  const unsigned warpsPerCTAM = warpsPerCTA[mIndex];
+  const unsigned warpsPerCTAN = warpsPerCTA[nIndex];
+
+  auto warpOrder = getDefaultMmaOrder(*this);
+  auto dimM = outDimNames[warpOrder[1]];
+  auto dimN = outDimNames[warpOrder[0]];
+  tileLayout = tileLayout.transposeOuts({dimN, dimM});
+
+  // First, extend the layout along the N dimension:
+  // - registers are distributed across tilesPerWarpN
+  // - then across warpsPerCTAN in the N dimension.
+  tileLayout *= LinearLayout::identity1D(tilesPerWarpN, kRegister, dimN);
+  tileLayout *= LinearLayout::identity1D(warpsPerCTAN, kWarp, dimN);
+
+  // At this point, the layout is defined across the N dimension within a CTA
+  // tile. Instead of switching to the M dimension now, we continue extending
+  // the layout along the remaining N dimension, and only then proceed along M,
+  // following the tilesPerWarp configuration.
+  // If the N dimension is not large enough to span multiple CTA tiles (i.e.,
+  // the first argument is 0), an empty layout is created, so this identity
+  // layout will not introduce any new registers.
+  tileLayout *= LinearLayout::identity1D(
+      shape[nIndex] / (nDim * warpsPerCTAN * tilesPerWarpN), kRegister, dimN);
+  tileLayout *= LinearLayout::identity1D(tilesPerWarpM, kRegister, dimM);
+
+  // Finally, extend the layout across warps in the M dimension.
+  // After this step, the layout covers a sub-tensor of size ctaTileM × N,
+  // i.e., the full N dimension and a CTA tile's extent in M.
+  // The rest of the layout will be defined by combineCtaCgaWithShape.
+  tileLayout *= LinearLayout::identity1D(warpsPerCTAM, kWarp, dimM);
+
   if (hasBatchDim) {
     int batchIndex = 0;
     // Extend the base vector with one value to accommodate for the batch
     // dimension, which appears at the last.
     tileLayout *=
         LinearLayout::identity1D(1, kRegister, outDimNames[batchIndex]);
     tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[batchIndex]);
+    tileLayout *= LinearLayout::identity1D(warpsPerCTA[0], kWarp,
+                                           outDimNames[batchIndex]);
   }
 
-  // And each warp takes the same register and lane sub-layout. So multiply with
-  // an identity layout for the warp.
-  auto warpOrder = getDefaultMmaOrder(*this);
-  LinearLayout warpLayout =
-      identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder);
-  // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
-  // extension of layout
-  auto repOrder = getRepOrder();
-  SmallVector<StringAttr> repDimNames;
-  for (auto dim : repOrder)
-    repDimNames.push_back(outDimNames[dim]);
-  LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) *
-                           warpLayout.transposeOuts(repDimNames);
-
-  return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
+  return combineCtaCgaWithShape(tileLayout, getCTALayout(), shape);
 }
 
 LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
@@ -866,6 +891,13 @@ LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
 
   auto mnkDim = wmmaLayout.getInstrShape();
   auto kDim = mnkDim[2];
+  auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
+  auto tilesPerWarp = wmmaLayout.getTilesPerWarp();
+  auto nonKDimIndex = dotWmmaLayout.getOpIdx() == 0 ? rank - 2 : rank - 1;
+  auto tilePerWarpNonK = tilesPerWarp[nonKDimIndex];
+  auto kDimIndex = dotWmmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+  unsigned kSize = shape[kDimIndex];
+
   auto nonKDim = dotWmmaLayout.getOpIdx() == 0 ? mnkDim[0] : mnkDim[1];
   auto kWidth = dotWmmaLayout.getKWidth();
   constexpr int warpSize = 32;
@@ -883,8 +915,18 @@ LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
       LinearLayout::identity1D(nonKDim, kLane, dimNonK);
   tileLayout *= version == 1 ? LinearLayout::zeros1D(depth, kLane, dimK)
                              : LinearLayout::identity1D(depth, kLane, dimK);
-  tileLayout *=
-      LinearLayout::identity1D(kDim / (depth * kWidth), kRegister, dimK);
+
+  // When tilePerWarpNonK > 1, we can't rely on the traditional way to fill the
+  // block along K. Instead, we need to manually fill the whole kSize, then
+  // apply tilePerWarpNonK along nonK direction.
+  int kTileSize = depth * kWidth;
+  if (tilePerWarpNonK > 1) {
+    tileLayout *= LinearLayout::identity1D(std::max(kSize, kDim) / kTileSize,
+                                           kRegister, dimK);
+    tileLayout *= LinearLayout::identity1D(tilePerWarpNonK, kRegister, dimNonK);
+  } else {
+    tileLayout *= LinearLayout::identity1D(kDim / kTileSize, kRegister, dimK);
+  }
 
   if (hasBatchDim) {
     assert(order[2] == 0);
@@ -895,11 +937,9 @@ LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
   }
 
   // Generate warp layout
-  auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   auto warpOrder = getDefaultMmaOrder(wmmaLayout);
-  auto kDimIdx = dotWmmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
   LinearLayout warpLayout = broadcastedDotOperandLayout(
-      ctx, warpsPerCTA, warpOrder, kDimIdx, S("warp"));
+      ctx, warpsPerCTA, warpOrder, kDimIndex, S("warp"));
 
   // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
   // extension of layout
@@ -1428,8 +1468,10 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
 }
 
 LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
-                                         ArrayRef<unsigned> warpsPerCTA,
-                                         ArrayRef<int64_t> dotOperandShape) {
+                                         ArrayRef<int64_t> dotOperandShape,
+                                         unsigned wmmaMDim,
+                                         ArrayRef<unsigned> tilesPerWarp,
+                                         ArrayRef<unsigned> warpsPerCTA) {
   using basisT = std::vector<std::vector<int32_t>>;
   unsigned rank = dotOperandShape.size();
   auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true);
@@ -1449,18 +1491,30 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
   auto dimK = outDimNames[order[0]];
   auto dimNonK = outDimNames[order[1]];
 
-  // Each lane holds kWidth=4 consecutive values along the k dim.
-  // The first 16 lanes are distributed along the non-k dim. We are not using
-  // the remaining 16 lanes, so just let them duplicate values of the first 16
-  // lanes. If the shape along the k dim is larger than kWidth, repeat this
-  // pattern to fill the k dim.
+  // Each lane holds kWidth=4 consecutive values along the K dim.
+  // The first 16 lanes are distributed along the nonK dim.
   unsigned scaleKWidth = 4;
   auto kSize = dotOperandShape[1];
   LinearLayout tileLayout =
       LinearLayout::identity1D(scaleKWidth, kRegister, dimK) *
-      LinearLayout::identity1D(16, kLane, dimNonK) *
-      LinearLayout::zeros1D(2, kLane, dimK) *
-      LinearLayout::identity1D(kSize / scaleKWidth, kRegister, dimK);
+      LinearLayout::identity1D(16, kLane, dimNonK);
+
+  // If there's 1 tile per warp, we are not using the remaining 16 lanes, so
+  // just let them duplicate values of the first 16 lanes.
+  // Otherwise, we put consecutive values along the nonK dim in the remaining
+  // 16 lanes.
+  unsigned mnDim = dotOperandIdx == 0 ? rank - 2 : rank - 1;
+  unsigned tilePerWarpMN = tilesPerWarp[mnDim];
+  if (tilePerWarpMN > 1) {
+    assert(tilePerWarpMN == 2 && "TilesPerWarp > 2 is not supported.");
+    tileLayout *= LinearLayout::identity1D(tilePerWarpMN, kLane, dimNonK);
+  } else {
+    tileLayout *= LinearLayout::zeros1D(2, kLane, dimNonK);
+  }
+
+  // If the shape along the K dim is larger than kWidth, repeat this
+  // pattern to fill the K dim.
+  tileLayout *= LinearLayout::identity1D(kSize / scaleKWidth, kRegister, dimK);
 
   auto warpsPerCTANew = (dotOperandIdx == 1)
                             ? SmallVector{warpsPerCTA[1], warpsPerCTA[0]}
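The lane-assignment change in `chooseScaledWmmaScaleLayout` can be illustrated with a tiny sketch. This is a hypothetical restatement, not the patch: with one tile per warp, lanes 16-31 broadcast the values of lanes 0-15 (the `zeros1D` branch); with two tiles per warp along the non-K dim, lanes 16-31 instead hold the second tile's scales (the `identity1D` branch), and `opSel` selects which 16-lane half each WMMA instruction reads.

```python
# Hypothetical model of which non-K row a lane's scale value belongs to,
# for the two branches of the layout construction above.
def nonk_row(lane, tile_per_warp_nonk):
    if tile_per_warp_nonk == 1:
        # zeros1D branch: lanes 16-31 duplicate lanes 0-15.
        return lane % 16
    # identity1D branch: lanes 16-31 cover rows 16-31 (the second tile),
    # read via opSel instead of being wasted on duplicates.
    return lane

assert nonk_row(20, 1) == 4    # duplicate of lane 4's value
assert nonk_row(20, 2) == 20   # second tile's row, selected with opSel
```

In the first case half the wave's read capacity carries redundant data; in the second, all 32 lanes carry distinct scales, which is the waste the commit message describes eliminating.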
