@@ -522,13 +522,62 @@ class MergeTileGroupsByRotation : public Pass {
522522 if (partitioned_dim == inferred_dims_.end ()) {
523523 return false ;
524524 }
525+
526+ bool box_is_bulk = bulk_groups_.count (partitioned_dim->box ) > 0 ;
527+ bool from1_is_bulk = bulk_groups_.count (from_.at (1 )) > 0 ;
528+
529+ std::cout << " [MergeTileGroupsByRotation::condition] Checking expr: "
530+ << expr->front ()->toString () << std::endl;
531+ std::cout << " from[0] (partitioned): " << from_.at (0 )->toString ()
532+ << " extent="
533+ << from_.at (0 )->front ()->as <IterDomain>()->extent ()->toString ()
534+ << std::endl;
535+ std::cout << " from[1]: " << from_.at (1 )->toString () << " extent="
536+ << from_.at (1 )->front ()->as <IterDomain>()->extent ()->toString ()
537+ << std::endl;
538+ std::cout << " to[0]: " << to_.at (0 )->toString () << " extent="
539+ << to_.at (0 )->front ()->as <IterDomain>()->extent ()->toString ()
540+ << std::endl;
541+ std::cout
542+ << " partitioned_dim->box extent: "
543+ << partitioned_dim->box ->front ()->as <IterDomain>()->extent ()->toString ()
544+ << std::endl;
545+ std::cout << " partitioned_dim->box is bulk: " << box_is_bulk << std::endl;
546+ std::cout << " from[1] is bulk: " << from1_is_bulk << std::endl;
547+
525548 if (bulk_groups_.count (partitioned_dim->box ) == 0 ||
526549 bulk_groups_.count (from_.at (1 )) == 0 ) {
550+ std::cout << " Result: FALSE (not both bulk)" << std::endl;
527551 return false ;
528552 }
529553 NVF_ERROR (
530554 partitioned_dim->tile == partitioned_dim->box &&
531555 partitioned_dim->stride == nullptr );
556+
557+ // Check if merging would exceed the 256-element hardware limit
558+ auto box_extent = partitioned_dim->box ->front ()->as <IterDomain>()->extent ();
559+ auto from1_extent = from_.at (1 )->front ()->as <IterDomain>()->extent ();
560+ Val* merged_extent =
561+ SimplifyingIrBuilder::mulExpr (box_extent, from1_extent);
562+
563+ constexpr int64_t largest_dim_size = 256 ; // Hardware limitation
564+ Val* too_large_after_merge = SimplifyingIrBuilder::gtExpr (
565+ merged_extent, IrBuilder::create<Val>(largest_dim_size));
566+
567+ std::cout << " Merged box size would be: " << box_extent->toString ()
568+ << " * " << from1_extent->toString () << " = "
569+ << merged_extent->toString () << std::endl;
570+
571+ if (simplifyExpr (too_large_after_merge)->isTrue ()) {
572+ std::cout << " Result: FALSE (would exceed 256 element limit)"
573+ << std::endl;
574+ return false ;
575+ }
576+
577+ std::cout
578+ << " Result: TRUE - WILL MERGE (pattern matches and within limits!)"
579+ << std::endl;
580+
532581 return true ;
533582 }
534583
@@ -540,7 +589,23 @@ class MergeTileGroupsByRotation : public Pass {
540589 inferred_dims_.begin (), inferred_dims_.end (), [&](const auto & dim) {
541590 return dim.partitioned == from_.at (0 );
542591 });
592+
593+ std::cout << " [MergeTileGroupsByRotation] Merging tile groups:"
594+ << std::endl;
595+ std::cout
596+ << " Old box extent: "
597+ << partitioned_dim->box ->front ()->as <IterDomain>()->extent ()->toString ()
598+ << std::endl;
599+ std::cout << " Merging with: "
600+ << from_.at (1 )->front ()->as <IterDomain>()->extent ()->toString ()
601+ << std::endl;
602+
543603 auto new_box = merge (&id_graph, partitioned_dim->tile , from_.at (1 ));
604+
605+ std::cout << " New box extent: "
606+ << new_box->front ()->as <IterDomain>()->extent ()->toString ()
607+ << std::endl;
608+
544609 bulk_groups_.insert (new_box);
545610 inferred_dims_.emplace_back ();
546611 inferred_dims_.back ().partitioned = to_.at (0 );
@@ -593,16 +658,56 @@ run(
593658 nonbulk_groups, inferred_dims);
594659 MergeTileGroupsByRotation merge_tile_pass (bulk_groups, inferred_dims);
595660
661+ std::cout << " \n [INFER_ROLES DEBUG] Starting infer_roles passes" << std::endl;
662+ int iteration = 0 ;
596663 bool changed = true ;
597664 while (changed) {
598665 changed = false ;
599- changed = changed || bulk_pass.run (exprs);
600- changed = changed || nonbulk_pass.run (exprs);
601- changed = changed || striding_split_pass.run (exprs);
602- changed = changed || boxing_split_pass.run (exprs);
603- changed = changed || move_partitioned_pass.run (exprs);
604- changed = changed || merge_tile_pass.run (exprs);
666+ std::cout << " \n [INFER_ROLES DEBUG] === Iteration " << iteration++
667+ << " ===" << std::endl;
668+ std::cout << " Remaining exprs: " << exprs.size () << std::endl;
669+ std::cout << " Current inferred_dims: " << inferred_dims.size ()
670+ << std::endl;
671+ for (const auto & dim : inferred_dims) {
672+ std::cout << " " << dim << std::endl;
673+ }
674+
675+ bool b = bulk_pass.run (exprs);
676+ if (b)
677+ std::cout << " bulk_pass made changes" << std::endl;
678+ changed = changed || b;
679+
680+ b = nonbulk_pass.run (exprs);
681+ if (b)
682+ std::cout << " nonbulk_pass made changes" << std::endl;
683+ changed = changed || b;
684+
685+ b = striding_split_pass.run (exprs);
686+ if (b)
687+ std::cout << " striding_split_pass made changes" << std::endl;
688+ changed = changed || b;
689+
690+ b = boxing_split_pass.run (exprs);
691+ if (b)
692+ std::cout << " boxing_split_pass made changes" << std::endl;
693+ changed = changed || b;
694+
695+ b = move_partitioned_pass.run (exprs);
696+ if (b)
697+ std::cout << " move_partitioned_pass made changes" << std::endl;
698+ changed = changed || b;
699+
700+ b = merge_tile_pass.run (exprs);
701+ if (b)
702+ std::cout << " merge_tile_pass made changes" << std::endl;
703+ changed = changed || b;
605704 }
705+
706+ std::cout << " \n [INFER_ROLES DEBUG] Final inferred_dims:" << std::endl;
707+ for (const auto & dim : inferred_dims) {
708+ std::cout << " " << dim << std::endl;
709+ }
710+
606711 return {bulk_groups, nonbulk_groups, inferred_dims};
607712}
608713
@@ -882,7 +987,25 @@ class DomainMerger {
882987 auto type1 = type (i + 1 );
883988
884989 bool may_increasing_box_size = (type0 == CB && type1 == CB);
990+
991+ std::cout << " [shouldMerge] Checking merge of dim " << i << " and "
992+ << (i + 1 ) << std::endl;
993+ std::cout << " Types: "
994+ << (type0 == P ? " P"
995+ : type0 == C ? " C"
996+ : type0 == SB ? " SB"
997+ : " CB" )
998+ << " + "
999+ << (type1 == P ? " P"
1000+ : type1 == C ? " C"
1001+ : type1 == SB ? " SB"
1002+ : " CB" )
1003+ << std::endl;
1004+ std::cout << " May increase box size: "
1005+ << (may_increasing_box_size ? " yes" : " no" ) << std::endl;
1006+
8851007 if (!may_increasing_box_size) {
1008+ std::cout << " Decision: MERGE (not increasing box size)" << std::endl;
8861009 return true ;
8871010 }
8881011
@@ -891,6 +1014,11 @@ class DomainMerger {
8911014 Val* merged_extent = SimplifyingIrBuilder::mulExpr (extent0, extent1);
8921015
8931016 bool merging_innermost = ((int64_t )size () == i + 2 );
1017+ std::cout << " Extents: " << extent0->toString () << " * "
1018+ << extent1->toString () << " = " << merged_extent->toString ()
1019+ << std::endl;
1020+ std::cout << " Merging innermost: " << (merging_innermost ? " yes" : " no" )
1021+ << std::endl;
8941022
8951023 // If merging makes the size of a dimension larger than 256, we should not
8961024 // merge.
@@ -899,6 +1027,8 @@ class DomainMerger {
8991027 Val* too_large_after_merge = SimplifyingIrBuilder::gtExpr (
9001028 merged_extent, IrBuilder::create<Val>(largest_dim_size));
9011029 if (simplifyExpr (too_large_after_merge)->isTrue ()) {
1030+ std::cout << " Decision: DON'T MERGE (exceeds 256 element limit)"
1031+ << std::endl;
9021032 return false ;
9031033 }
9041034
@@ -907,16 +1037,22 @@ class DomainMerger {
9071037 if (merging_innermost && swizzle_ != MmaInputSmemSwizzle::None) {
9081038 const int64_t swizzle_size =
9091039 getBytesFromSwizzle (swizzle_) / item_size_bytes_;
1040+ std::cout << " Swizzle size check: merged="
1041+ << merged_extent->toString () << " , swizzle=" << swizzle_size
1042+ << std::endl;
9101043 Val* merging_makes_gt_swizzle_size = SimplifyingIrBuilder::gtExpr (
9111044 merged_extent, IrBuilder::create<Val>(swizzle_size));
9121045 if (simplifyExpr (merging_makes_gt_swizzle_size)->isTrue ()) {
1046+ std::cout << " Decision: DON'T MERGE (exceeds swizzle size)"
1047+ << std::endl;
9131048 return false ;
9141049 }
9151050 }
9161051
9171052 // Because the shape is dynamic, we don't know if we should merge or
9181053 // not. For this case, we always assume merging is better than not
9191054 // merging.
1055+ std::cout << " Decision: MERGE (default/dynamic case)" << std::endl;
9201056 return true ;
9211057 }
9221058
@@ -925,10 +1061,39 @@ class DomainMerger {
9251061 auto type1 = type (i + 1 );
9261062 auto g0 = (*this )[i];
9271063 auto g1 = (*this )[i + 1 ];
1064+
1065+ // DEBUG: Print what's being merged
1066+ std::cout << " [TMA DEBUG] Merging dimension " << i << " and " << (i + 1 )
1067+ << std::endl;
1068+ std::cout << " Type[" << i << " ] = "
1069+ << (type0 == P ? " P"
1070+ : type0 == C ? " C"
1071+ : type0 == SB ? " SB"
1072+ : " CB" )
1073+ << std::endl;
1074+ std::cout << " Type[" << (i + 1 ) << " ] = "
1075+ << (type1 == P ? " P"
1076+ : type1 == C ? " C"
1077+ : type1 == SB ? " SB"
1078+ : " CB" )
1079+ << std::endl;
1080+ std::cout << " Extent[" << i
1081+ << " ] = " << g0->front ()->as <IterDomain>()->extent ()->toString ()
1082+ << std::endl;
1083+ std::cout << " Extent[" << (i + 1 )
1084+ << " ] = " << g1->front ()->as <IterDomain>()->extent ()->toString ()
1085+ << std::endl;
1086+ std::cout << " Contiguity[" << i
1087+ << " ] = " << (contiguity (i) ? " true" : " false" ) << std::endl;
1088+
9281089 domain_.merge (i);
9291090 contiguity_and_stride_.erase (contiguity_and_stride_.begin () + i);
9301091 const auto & g = (*this )[i];
9311092
1093+ std::cout << " Merged extent = "
1094+ << g->front ()->as <IterDomain>()->extent ()->toString ()
1095+ << std::endl;
1096+
9321097 // Update bulk_groups_ and nonbulk_groups_ by propagating through the merge.
9331098 if (bulk_groups_.count (g0) > 0 && bulk_groups_.count (g1) > 0 ) {
9341099 bulk_groups_.insert (g);
@@ -1002,7 +1167,33 @@ std::vector<TMADim> run(
10021167 dim_info,
10031168 swizzle,
10041169 item_size_bytes);
1170+
1171+ // DEBUG: Print initial TMA domain state
1172+ std::cout
1173+ << " \n [TMA DEBUG] ===== Initial TMA Domain (innermost to outermost) ====="
1174+ << std::endl;
1175+ std::cout << " Total dimensions: " << tma_domain.size () << std::endl;
1176+ for (int64_t i = 0 ; i < (int64_t )tma_domain.size (); i++) {
1177+ auto t = tma_domain.type (i);
1178+ std::cout << " Dim[" << i << " ]: type="
1179+ << (t == P ? " P"
1180+ : t == C ? " C"
1181+ : t == SB ? " SB"
1182+ : " CB" )
1183+ << " , extent="
1184+ << tma_domain[i]->front ()->as <IterDomain>()->extent ()->toString ()
1185+ << " , contiguous="
1186+ << (tma_domain.contiguity (i) ? " true" : " false" ) << std::endl;
1187+ }
1188+ std::cout << " Swizzle size (items): "
1189+ << (swizzle != MmaInputSmemSwizzle::None
1190+ ? getBytesFromSwizzle (swizzle) / item_size_bytes
1191+ : 0 )
1192+ << std::endl;
1193+
10051194 // merge contiguous C groups and CB groups
1195+ std::cout << " \n [TMA DEBUG] ===== Phase 1: Merging C-C and CB-CB ====="
1196+ << std::endl;
10061197 for (int64_t i = 0 ; i < (int64_t )tma_domain.size () - 1 ; i++) {
10071198 if (!tma_domain.contiguity (i)) {
10081199 continue ;
@@ -1016,6 +1207,8 @@ std::vector<TMADim> run(
10161207 }
10171208 }
10181209 // merge contiguous C with SB/CB
1210+ std::cout << " \n [TMA DEBUG] ===== Phase 2: Merging C with SB/CB ====="
1211+ << std::endl;
10191212 for (int64_t i = 0 ; i < (int64_t )tma_domain.size () - 1 ; i++) {
10201213 if (!tma_domain.contiguity (i)) {
10211214 continue ;
@@ -1032,7 +1225,24 @@ std::vector<TMADim> run(
10321225 // Compute the final TMA domain. As required by the hardware, tensors used by
10331226 // TMA must be in column major, so our final TMA domain is also from innermost
10341227 // to outermost.
1228+ std::cout << " \n [TMA DEBUG] ===== After All Merges =====" << std::endl;
1229+ std::cout << " Total dimensions: " << tma_domain.size () << std::endl;
1230+ for (int64_t i = 0 ; i < (int64_t )tma_domain.size (); i++) {
1231+ auto t = tma_domain.type (i);
1232+ std::cout << " Dim[" << i << " ]: type="
1233+ << (t == P ? " P"
1234+ : t == C ? " C"
1235+ : t == SB ? " SB"
1236+ : " CB" )
1237+ << " , extent="
1238+ << tma_domain[i]->front ()->as <IterDomain>()->extent ()->toString ()
1239+ << std::endl;
1240+ }
1241+
10351242 std::vector<TMADim> result;
1243+ std::cout << " \n [TMA DEBUG] ===== Creating Final TMA Domain (innermost to "
1244+ " outermost) ====="
1245+ << std::endl;
10361246 for (int64_t i = (int64_t )tma_domain.size () - 1 ; i >= 0 ; i--) {
10371247 const auto & g = tma_domain[i];
10381248 result.emplace_back ();
@@ -1063,7 +1273,53 @@ std::vector<TMADim> run(
10631273 }
10641274 result.back ().gmem_stride_bytes =
10651275 SimplifyingIrBuilder::mulExpr (tma_domain.stride (i), item_size_bytes);
1276+
1277+ // DEBUG: Print info about this dimension
1278+ std::cout << " Final Dim[" << (result.size () - 1 ) << " ]:" << std::endl;
1279+ std::cout << " partitioned="
1280+ << (result.back ().partitioned ? result.back ()
1281+ .partitioned ->front ()
1282+ ->as <IterDomain>()
1283+ ->extent ()
1284+ ->toString ()
1285+ : " nullptr" )
1286+ << std::endl;
1287+ std::cout << " box="
1288+ << (result.back ().box ? result.back ()
1289+ .box ->front ()
1290+ ->as <IterDomain>()
1291+ ->extent ()
1292+ ->toString ()
1293+ : " nullptr (size-one)" )
1294+ << std::endl;
1295+ std::cout << " tile="
1296+ << (result.back ().tile ? result.back ()
1297+ .tile ->front ()
1298+ ->as <IterDomain>()
1299+ ->extent ()
1300+ ->toString ()
1301+ : " nullptr" )
1302+ << std::endl;
1303+ std::cout << " stride="
1304+ << (result.back ().stride ? result.back ()
1305+ .stride ->front ()
1306+ ->as <IterDomain>()
1307+ ->extent ()
1308+ ->toString ()
1309+ : " nullptr" )
1310+ << std::endl;
10661311 }
1312+
1313+ std::cout << " \n [TMA DEBUG] ===== Final Box Sizes =====" << std::endl;
1314+ for (size_t i = 0 ; i < result.size (); i++) {
1315+ std::cout
1316+ << " Box[" << i << " ] = "
1317+ << (result[i].box
1318+ ? result[i].box ->front ()->as <IterDomain>()->extent ()->toString ()
1319+ : " 1 (implicit)" )
1320+ << std::endl;
1321+ }
1322+
10671323 return result;
10681324}
10691325
@@ -1158,6 +1414,9 @@ Val* TMAInfo::tensorMap() const {
11581414 dims_.end (),
11591415 std::back_inserter (box_sizes_inner_to_outer),
11601416 [](const TMADim& d) { return d.boxSize (); });
1417+ for (auto box_size : box_sizes_inner_to_outer) {
1418+ std::cout << " box_size: " << box_size->toString () << std::endl;
1419+ }
11611420
11621421 std::vector<Val*> element_strides_inner_to_outer;
11631422 std::transform (
0 commit comments