@@ -772,6 +772,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
772772
773773 StringAttr kRegister = S (" register" );
774774 StringAttr kLane = S (" lane" );
775+ StringAttr kWarp = S (" warp" );
775776
776777 // https://github.com/ROCm/amd_matrix_instruction_calculator can print the
777778 // register and lane layout for mfma instructions.
@@ -814,30 +815,54 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
814815 {kLane , {{1 , 0 }, {2 , 0 }, {4 , 0 }, {8 , 0 }, /* gap*/ {0 , 8 }}}},
815816 {outDimNames[threadOrder[0 ]], outDimNames[threadOrder[1 ]]});
816817
818+ auto tilesPerWarp = getTilesPerWarp ();
819+ auto warpsPerCTA = getWarpsPerCTA ();
820+
821+ const unsigned tilesPerWarpM = tilesPerWarp[mIndex ];
822+ const unsigned tilesPerWarpN = tilesPerWarp[nIndex];
823+ const unsigned warpsPerCTAM = warpsPerCTA[mIndex ];
824+ const unsigned warpsPerCTAN = warpsPerCTA[nIndex];
825+
826+ auto warpOrder = getDefaultMmaOrder (*this );
827+ auto dimM = outDimNames[warpOrder[1 ]];
828+ auto dimN = outDimNames[warpOrder[0 ]];
829+ tileLayout = tileLayout.transposeOuts ({dimN, dimM});
830+
831+ // First, extend the layout along the N dimension:
832+ // - registers are distributed across tilesPerWarpN
833+ // - then across warpsPerCTAN in the N dimension.
834+ tileLayout *= LinearLayout::identity1D (tilesPerWarpN, kRegister , dimN);
835+ tileLayout *= LinearLayout::identity1D (warpsPerCTAN, kWarp , dimN);
836+
837+ // At this point, the layout is defined across the N dimension within a CTA
838+ // tile. Instead of switching to the M dimension now, we continue extending
839+ // the layout along the remaining N dimension, and only then proceed along M,
840+ // following the tilesPerWarp configuration.
841+ // If the N dimension is not large enough to span multiple CTA tiles (i.e.,
842+ // the first argument is 0), an empty layout is created, so this identity
843+ // layout will not introduce any new registers.
844+ tileLayout *= LinearLayout::identity1D (
845+ shape[nIndex] / (nDim * warpsPerCTAN * tilesPerWarpN), kRegister , dimN);
846+ tileLayout *= LinearLayout::identity1D (tilesPerWarpM, kRegister , dimM);
847+
848+ // Finally, extend the layout across warps in the M dimension.
849+ // After this step, the layout covers a sub-tensor of size ctaTileM × N,
850+ // i.e., the full N dimension and a CTA tile's extent in M.
851+ // The rest of the layout will be defined by combineCtaCgaWithShape.
852+ tileLayout *= LinearLayout::identity1D (warpsPerCTAM, kWarp , dimM);
853+
817854 if (hasBatchDim) {
818855 int batchIndex = 0 ;
819856 // Extend the base vector with one value to accommodate the batch
820857 // dimension, which appears last.
821858 tileLayout *=
822859 LinearLayout::identity1D (1 , kRegister , outDimNames[batchIndex]);
823860 tileLayout *= LinearLayout::identity1D (1 , kLane , outDimNames[batchIndex]);
861+ tileLayout *= LinearLayout::identity1D (warpsPerCTA[0 ], kWarp ,
862+ outDimNames[batchIndex]);
824863 }
825864
826- // And each warp takes the same register and lane sub-layout. So multiply with
827- // an identity layout for the warp.
828- auto warpOrder = getDefaultMmaOrder (*this );
829- LinearLayout warpLayout =
830- identityStandardND (S (" warp" ), getWarpsPerCTA (), warpOrder);
831- // reorder dim names in rep order, so combineCtaCgaWithShape generate proper
832- // extension of layout
833- auto repOrder = getRepOrder ();
834- SmallVector<StringAttr> repDimNames;
835- for (auto dim : repOrder)
836- repDimNames.push_back (outDimNames[dim]);
837- LinearLayout ctaLayout = tileLayout.transposeOuts (repDimNames) *
838- warpLayout.transposeOuts (repDimNames);
839-
840- return combineCtaCgaWithShape (ctaLayout, getCTALayout (), shape);
865+ return combineCtaCgaWithShape (tileLayout, getCTALayout (), shape);
841866}
842867
843868LinearLayout wmmaDotOperandToLinearLayout (DotOperandEncodingAttr dotWmmaLayout,
@@ -866,6 +891,13 @@ LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
866891
867892 auto mnkDim = wmmaLayout.getInstrShape ();
868893 auto kDim = mnkDim[2 ];
894+ auto warpsPerCTA = wmmaLayout.getWarpsPerCTA ();
895+ auto tilesPerWarp = wmmaLayout.getTilesPerWarp ();
896+ auto nonKDimIndex = dotWmmaLayout.getOpIdx () == 0 ? rank - 2 : rank - 1 ;
897+ auto tilePerWarpNonK = tilesPerWarp[nonKDimIndex];
898+ auto kDimIndex = dotWmmaLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
899+ unsigned kSize = shape[kDimIndex ];
900+
869901 auto nonKDim = dotWmmaLayout.getOpIdx () == 0 ? mnkDim[0 ] : mnkDim[1 ];
870902 auto kWidth = dotWmmaLayout.getKWidth ();
871903 constexpr int warpSize = 32 ;
@@ -883,8 +915,18 @@ LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
883915 LinearLayout::identity1D (nonKDim, kLane , dimNonK);
884916 tileLayout *= version == 1 ? LinearLayout::zeros1D (depth, kLane , dimK)
885917 : LinearLayout::identity1D (depth, kLane , dimK);
886- tileLayout *=
887- LinearLayout::identity1D (kDim / (depth * kWidth ), kRegister , dimK);
918+
919+ // When tilePerWarpNonK > 1, we can't rely on the traditional way to fill the
920+ // block along K. Instead, we need to manually fill the whole kSize, then
921+ // apply tilePerWarpNonK along the nonK direction.
922+ int kTileSize = depth * kWidth ;
923+ if (tilePerWarpNonK > 1 ) {
924+ tileLayout *= LinearLayout::identity1D (std::max (kSize , kDim ) / kTileSize ,
925+ kRegister , dimK);
926+ tileLayout *= LinearLayout::identity1D (tilePerWarpNonK, kRegister , dimNonK);
927+ } else {
928+ tileLayout *= LinearLayout::identity1D (kDim / kTileSize , kRegister , dimK);
929+ }
888930
889931 if (hasBatchDim) {
890932 assert (order[2 ] == 0 );
@@ -895,11 +937,9 @@ LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
895937 }
896938
897939 // Generate warp layout
898- auto warpsPerCTA = wmmaLayout.getWarpsPerCTA ();
899940 auto warpOrder = getDefaultMmaOrder (wmmaLayout);
900- auto kDimIdx = dotWmmaLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
901941 LinearLayout warpLayout = broadcastedDotOperandLayout (
902- ctx, warpsPerCTA, warpOrder, kDimIdx , S (" warp" ));
942+ ctx, warpsPerCTA, warpOrder, kDimIndex , S (" warp" ));
903943
904944 // Reorder dim names in rep order, so combineCtaCgaWithShape generates the
905945 // proper extension of the layout.
@@ -1428,8 +1468,10 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef<int64_t> shape,
14281468}
14291469
14301470LinearLayout chooseScaledWmmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
1431- ArrayRef<unsigned > warpsPerCTA,
1432- ArrayRef<int64_t > dotOperandShape) {
1471+ ArrayRef<int64_t > dotOperandShape,
1472+ unsigned wmmaMDim,
1473+ ArrayRef<unsigned > tilesPerWarp,
1474+ ArrayRef<unsigned > warpsPerCTA) {
14331475 using basisT = std::vector<std::vector<int32_t >>;
14341476 unsigned rank = dotOperandShape.size ();
14351477 auto order = mlir::triton::gpu::getMatrixOrder (rank, /* rowMajor=*/ true );
@@ -1449,18 +1491,30 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
14491491 auto dimK = outDimNames[order[0 ]];
14501492 auto dimNonK = outDimNames[order[1 ]];
14511493
1452- // Each lane holds kWidth=4 consecutive values along the k dim.
1453- // The first 16 lanes are distributed along the non-k dim. We are not using
1454- // the remaining 16 lanes, so just let them duplicate values of the first 16
1455- // lanes. If the shape along the k dim is larger than kWidth, repeat this
1456- // pattern to fill the k dim.
1494+ // Each lane holds kWidth=4 consecutive values along the K dim.
1495+ // The first 16 lanes are distributed along the nonK dim.
14571496 unsigned scaleKWidth = 4 ;
14581497 auto kSize = dotOperandShape[1 ];
14591498 LinearLayout tileLayout =
14601499 LinearLayout::identity1D (scaleKWidth, kRegister , dimK) *
1461- LinearLayout::identity1D (16 , kLane , dimNonK) *
1462- LinearLayout::zeros1D (2 , kLane , dimK) *
1463- LinearLayout::identity1D (kSize / scaleKWidth, kRegister , dimK);
1500+ LinearLayout::identity1D (16 , kLane , dimNonK);
1501+
1502+ // If there's 1 tile per warp, we are not using the remaining 16 lanes, so
1503+ // just let them duplicate values of the first 16 lanes.
1504+ // Otherwise, we put consecutive values along the nonK dim in the remaining
1505+ // 16 lanes.
1506+ unsigned mnDim = dotOperandIdx == 0 ? rank - 2 : rank - 1 ;
1507+ unsigned tilePerWarpMN = tilesPerWarp[mnDim];
1508+ if (tilePerWarpMN > 1 ) {
1509+ assert (tilePerWarpMN == 2 && " TilesPerWarp > 2 is not supported." );
1510+ tileLayout *= LinearLayout::identity1D (tilePerWarpMN, kLane , dimNonK);
1511+ } else {
1512+ tileLayout *= LinearLayout::zeros1D (2 , kLane , dimNonK);
1513+ }
1514+
1515+ // If the shape along the K dim is larger than kWidth, repeat this
1516+ // pattern to fill the K dim.
1517+ tileLayout *= LinearLayout::identity1D (kSize / scaleKWidth, kRegister , dimK);
14641518
14651519 auto warpsPerCTANew = (dotOperandIdx == 1 )
14661520 ? SmallVector{warpsPerCTA[1 ], warpsPerCTA[0 ]}
0 commit comments