Skip to content

Commit 5868dd7

Browse files
committed
check 256 size limitation in MergeTileGroupsByRotation
1 parent f24eea0 commit 5868dd7

File tree

1 file changed

+265
-6
lines changed

1 file changed

+265
-6
lines changed

csrc/device_lower/analysis/tma.cpp

Lines changed: 265 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -522,13 +522,62 @@ class MergeTileGroupsByRotation : public Pass {
522522
if (partitioned_dim == inferred_dims_.end()) {
523523
return false;
524524
}
525+
526+
bool box_is_bulk = bulk_groups_.count(partitioned_dim->box) > 0;
527+
bool from1_is_bulk = bulk_groups_.count(from_.at(1)) > 0;
528+
529+
std::cout << "[MergeTileGroupsByRotation::condition] Checking expr: "
530+
<< expr->front()->toString() << std::endl;
531+
std::cout << " from[0] (partitioned): " << from_.at(0)->toString()
532+
<< " extent="
533+
<< from_.at(0)->front()->as<IterDomain>()->extent()->toString()
534+
<< std::endl;
535+
std::cout << " from[1]: " << from_.at(1)->toString() << " extent="
536+
<< from_.at(1)->front()->as<IterDomain>()->extent()->toString()
537+
<< std::endl;
538+
std::cout << " to[0]: " << to_.at(0)->toString() << " extent="
539+
<< to_.at(0)->front()->as<IterDomain>()->extent()->toString()
540+
<< std::endl;
541+
std::cout
542+
<< " partitioned_dim->box extent: "
543+
<< partitioned_dim->box->front()->as<IterDomain>()->extent()->toString()
544+
<< std::endl;
545+
std::cout << " partitioned_dim->box is bulk: " << box_is_bulk << std::endl;
546+
std::cout << " from[1] is bulk: " << from1_is_bulk << std::endl;
547+
525548
if (bulk_groups_.count(partitioned_dim->box) == 0 ||
526549
bulk_groups_.count(from_.at(1)) == 0) {
550+
std::cout << " Result: FALSE (not both bulk)" << std::endl;
527551
return false;
528552
}
529553
NVF_ERROR(
530554
partitioned_dim->tile == partitioned_dim->box &&
531555
partitioned_dim->stride == nullptr);
556+
557+
// Check if merging would exceed the 256-element hardware limit
558+
auto box_extent = partitioned_dim->box->front()->as<IterDomain>()->extent();
559+
auto from1_extent = from_.at(1)->front()->as<IterDomain>()->extent();
560+
Val* merged_extent =
561+
SimplifyingIrBuilder::mulExpr(box_extent, from1_extent);
562+
563+
constexpr int64_t largest_dim_size = 256; // Hardware limitation
564+
Val* too_large_after_merge = SimplifyingIrBuilder::gtExpr(
565+
merged_extent, IrBuilder::create<Val>(largest_dim_size));
566+
567+
std::cout << " Merged box size would be: " << box_extent->toString()
568+
<< " * " << from1_extent->toString() << " = "
569+
<< merged_extent->toString() << std::endl;
570+
571+
if (simplifyExpr(too_large_after_merge)->isTrue()) {
572+
std::cout << " Result: FALSE (would exceed 256 element limit)"
573+
<< std::endl;
574+
return false;
575+
}
576+
577+
std::cout
578+
<< " Result: TRUE - WILL MERGE (pattern matches and within limits!)"
579+
<< std::endl;
580+
532581
return true;
533582
}
534583

@@ -540,7 +589,23 @@ class MergeTileGroupsByRotation : public Pass {
540589
inferred_dims_.begin(), inferred_dims_.end(), [&](const auto& dim) {
541590
return dim.partitioned == from_.at(0);
542591
});
592+
593+
std::cout << "[MergeTileGroupsByRotation] Merging tile groups:"
594+
<< std::endl;
595+
std::cout
596+
<< " Old box extent: "
597+
<< partitioned_dim->box->front()->as<IterDomain>()->extent()->toString()
598+
<< std::endl;
599+
std::cout << " Merging with: "
600+
<< from_.at(1)->front()->as<IterDomain>()->extent()->toString()
601+
<< std::endl;
602+
543603
auto new_box = merge(&id_graph, partitioned_dim->tile, from_.at(1));
604+
605+
std::cout << " New box extent: "
606+
<< new_box->front()->as<IterDomain>()->extent()->toString()
607+
<< std::endl;
608+
544609
bulk_groups_.insert(new_box);
545610
inferred_dims_.emplace_back();
546611
inferred_dims_.back().partitioned = to_.at(0);
@@ -593,16 +658,56 @@ run(
593658
nonbulk_groups, inferred_dims);
594659
MergeTileGroupsByRotation merge_tile_pass(bulk_groups, inferred_dims);
595660

661+
std::cout << "\n[INFER_ROLES DEBUG] Starting infer_roles passes" << std::endl;
662+
int iteration = 0;
596663
bool changed = true;
597664
while (changed) {
598665
changed = false;
599-
changed = changed || bulk_pass.run(exprs);
600-
changed = changed || nonbulk_pass.run(exprs);
601-
changed = changed || striding_split_pass.run(exprs);
602-
changed = changed || boxing_split_pass.run(exprs);
603-
changed = changed || move_partitioned_pass.run(exprs);
604-
changed = changed || merge_tile_pass.run(exprs);
666+
std::cout << "\n[INFER_ROLES DEBUG] === Iteration " << iteration++
667+
<< " ===" << std::endl;
668+
std::cout << " Remaining exprs: " << exprs.size() << std::endl;
669+
std::cout << " Current inferred_dims: " << inferred_dims.size()
670+
<< std::endl;
671+
for (const auto& dim : inferred_dims) {
672+
std::cout << " " << dim << std::endl;
673+
}
674+
675+
bool b = bulk_pass.run(exprs);
676+
if (b)
677+
std::cout << " bulk_pass made changes" << std::endl;
678+
changed = changed || b;
679+
680+
b = nonbulk_pass.run(exprs);
681+
if (b)
682+
std::cout << " nonbulk_pass made changes" << std::endl;
683+
changed = changed || b;
684+
685+
b = striding_split_pass.run(exprs);
686+
if (b)
687+
std::cout << " striding_split_pass made changes" << std::endl;
688+
changed = changed || b;
689+
690+
b = boxing_split_pass.run(exprs);
691+
if (b)
692+
std::cout << " boxing_split_pass made changes" << std::endl;
693+
changed = changed || b;
694+
695+
b = move_partitioned_pass.run(exprs);
696+
if (b)
697+
std::cout << " move_partitioned_pass made changes" << std::endl;
698+
changed = changed || b;
699+
700+
b = merge_tile_pass.run(exprs);
701+
if (b)
702+
std::cout << " merge_tile_pass made changes" << std::endl;
703+
changed = changed || b;
605704
}
705+
706+
std::cout << "\n[INFER_ROLES DEBUG] Final inferred_dims:" << std::endl;
707+
for (const auto& dim : inferred_dims) {
708+
std::cout << " " << dim << std::endl;
709+
}
710+
606711
return {bulk_groups, nonbulk_groups, inferred_dims};
607712
}
608713

@@ -882,7 +987,25 @@ class DomainMerger {
882987
auto type1 = type(i + 1);
883988

884989
bool may_increasing_box_size = (type0 == CB && type1 == CB);
990+
991+
std::cout << " [shouldMerge] Checking merge of dim " << i << " and "
992+
<< (i + 1) << std::endl;
993+
std::cout << " Types: "
994+
<< (type0 == P ? "P"
995+
: type0 == C ? "C"
996+
: type0 == SB ? "SB"
997+
: "CB")
998+
<< " + "
999+
<< (type1 == P ? "P"
1000+
: type1 == C ? "C"
1001+
: type1 == SB ? "SB"
1002+
: "CB")
1003+
<< std::endl;
1004+
std::cout << " May increase box size: "
1005+
<< (may_increasing_box_size ? "yes" : "no") << std::endl;
1006+
8851007
if (!may_increasing_box_size) {
1008+
std::cout << " Decision: MERGE (not increasing box size)" << std::endl;
8861009
return true;
8871010
}
8881011

@@ -891,6 +1014,11 @@ class DomainMerger {
8911014
Val* merged_extent = SimplifyingIrBuilder::mulExpr(extent0, extent1);
8921015

8931016
bool merging_innermost = ((int64_t)size() == i + 2);
1017+
std::cout << " Extents: " << extent0->toString() << " * "
1018+
<< extent1->toString() << " = " << merged_extent->toString()
1019+
<< std::endl;
1020+
std::cout << " Merging innermost: " << (merging_innermost ? "yes" : "no")
1021+
<< std::endl;
8941022

8951023
// If merging makes the size of a dimension larger than 256, we should not
8961024
// merge.
@@ -899,6 +1027,8 @@ class DomainMerger {
8991027
Val* too_large_after_merge = SimplifyingIrBuilder::gtExpr(
9001028
merged_extent, IrBuilder::create<Val>(largest_dim_size));
9011029
if (simplifyExpr(too_large_after_merge)->isTrue()) {
1030+
std::cout << " Decision: DON'T MERGE (exceeds 256 element limit)"
1031+
<< std::endl;
9021032
return false;
9031033
}
9041034

@@ -907,16 +1037,22 @@ class DomainMerger {
9071037
if (merging_innermost && swizzle_ != MmaInputSmemSwizzle::None) {
9081038
const int64_t swizzle_size =
9091039
getBytesFromSwizzle(swizzle_) / item_size_bytes_;
1040+
std::cout << " Swizzle size check: merged="
1041+
<< merged_extent->toString() << ", swizzle=" << swizzle_size
1042+
<< std::endl;
9101043
Val* merging_makes_gt_swizzle_size = SimplifyingIrBuilder::gtExpr(
9111044
merged_extent, IrBuilder::create<Val>(swizzle_size));
9121045
if (simplifyExpr(merging_makes_gt_swizzle_size)->isTrue()) {
1046+
std::cout << " Decision: DON'T MERGE (exceeds swizzle size)"
1047+
<< std::endl;
9131048
return false;
9141049
}
9151050
}
9161051

9171052
// Because the shape is dynamic, we don't know if we should merge or
9181053
// not. For this case, we always assume merging is better than not
9191054
// merging.
1055+
std::cout << " Decision: MERGE (default/dynamic case)" << std::endl;
9201056
return true;
9211057
}
9221058

@@ -925,10 +1061,39 @@ class DomainMerger {
9251061
auto type1 = type(i + 1);
9261062
auto g0 = (*this)[i];
9271063
auto g1 = (*this)[i + 1];
1064+
1065+
// DEBUG: Print what's being merged
1066+
std::cout << "[TMA DEBUG] Merging dimension " << i << " and " << (i + 1)
1067+
<< std::endl;
1068+
std::cout << " Type[" << i << "] = "
1069+
<< (type0 == P ? "P"
1070+
: type0 == C ? "C"
1071+
: type0 == SB ? "SB"
1072+
: "CB")
1073+
<< std::endl;
1074+
std::cout << " Type[" << (i + 1) << "] = "
1075+
<< (type1 == P ? "P"
1076+
: type1 == C ? "C"
1077+
: type1 == SB ? "SB"
1078+
: "CB")
1079+
<< std::endl;
1080+
std::cout << " Extent[" << i
1081+
<< "] = " << g0->front()->as<IterDomain>()->extent()->toString()
1082+
<< std::endl;
1083+
std::cout << " Extent[" << (i + 1)
1084+
<< "] = " << g1->front()->as<IterDomain>()->extent()->toString()
1085+
<< std::endl;
1086+
std::cout << " Contiguity[" << i
1087+
<< "] = " << (contiguity(i) ? "true" : "false") << std::endl;
1088+
9281089
domain_.merge(i);
9291090
contiguity_and_stride_.erase(contiguity_and_stride_.begin() + i);
9301091
const auto& g = (*this)[i];
9311092

1093+
std::cout << " Merged extent = "
1094+
<< g->front()->as<IterDomain>()->extent()->toString()
1095+
<< std::endl;
1096+
9321097
// Update bulk_groups_ and nonbulk_groups_ by propagating through the merge.
9331098
if (bulk_groups_.count(g0) > 0 && bulk_groups_.count(g1) > 0) {
9341099
bulk_groups_.insert(g);
@@ -1002,7 +1167,33 @@ std::vector<TMADim> run(
10021167
dim_info,
10031168
swizzle,
10041169
item_size_bytes);
1170+
1171+
// DEBUG: Print initial TMA domain state
1172+
std::cout
1173+
<< "\n[TMA DEBUG] ===== Initial TMA Domain (innermost to outermost) ====="
1174+
<< std::endl;
1175+
std::cout << " Total dimensions: " << tma_domain.size() << std::endl;
1176+
for (int64_t i = 0; i < (int64_t)tma_domain.size(); i++) {
1177+
auto t = tma_domain.type(i);
1178+
std::cout << " Dim[" << i << "]: type="
1179+
<< (t == P ? "P"
1180+
: t == C ? "C"
1181+
: t == SB ? "SB"
1182+
: "CB")
1183+
<< ", extent="
1184+
<< tma_domain[i]->front()->as<IterDomain>()->extent()->toString()
1185+
<< ", contiguous="
1186+
<< (tma_domain.contiguity(i) ? "true" : "false") << std::endl;
1187+
}
1188+
std::cout << " Swizzle size (items): "
1189+
<< (swizzle != MmaInputSmemSwizzle::None
1190+
? getBytesFromSwizzle(swizzle) / item_size_bytes
1191+
: 0)
1192+
<< std::endl;
1193+
10051194
// merge contiguous C groups and CB groups
1195+
std::cout << "\n[TMA DEBUG] ===== Phase 1: Merging C-C and CB-CB ====="
1196+
<< std::endl;
10061197
for (int64_t i = 0; i < (int64_t)tma_domain.size() - 1; i++) {
10071198
if (!tma_domain.contiguity(i)) {
10081199
continue;
@@ -1016,6 +1207,8 @@ std::vector<TMADim> run(
10161207
}
10171208
}
10181209
// merge contiguous C with SB/CB
1210+
std::cout << "\n[TMA DEBUG] ===== Phase 2: Merging C with SB/CB ====="
1211+
<< std::endl;
10191212
for (int64_t i = 0; i < (int64_t)tma_domain.size() - 1; i++) {
10201213
if (!tma_domain.contiguity(i)) {
10211214
continue;
@@ -1032,7 +1225,24 @@ std::vector<TMADim> run(
10321225
// Compute the final TMA domain. As required by the hardware, tensors used by
10331226
// TMA must be in column major, so our final TMA domain is also from innermost
10341227
// to outermost.
1228+
std::cout << "\n[TMA DEBUG] ===== After All Merges =====" << std::endl;
1229+
std::cout << " Total dimensions: " << tma_domain.size() << std::endl;
1230+
for (int64_t i = 0; i < (int64_t)tma_domain.size(); i++) {
1231+
auto t = tma_domain.type(i);
1232+
std::cout << " Dim[" << i << "]: type="
1233+
<< (t == P ? "P"
1234+
: t == C ? "C"
1235+
: t == SB ? "SB"
1236+
: "CB")
1237+
<< ", extent="
1238+
<< tma_domain[i]->front()->as<IterDomain>()->extent()->toString()
1239+
<< std::endl;
1240+
}
1241+
10351242
std::vector<TMADim> result;
1243+
std::cout << "\n[TMA DEBUG] ===== Creating Final TMA Domain (innermost to "
1244+
"outermost) ====="
1245+
<< std::endl;
10361246
for (int64_t i = (int64_t)tma_domain.size() - 1; i >= 0; i--) {
10371247
const auto& g = tma_domain[i];
10381248
result.emplace_back();
@@ -1063,7 +1273,53 @@ std::vector<TMADim> run(
10631273
}
10641274
result.back().gmem_stride_bytes =
10651275
SimplifyingIrBuilder::mulExpr(tma_domain.stride(i), item_size_bytes);
1276+
1277+
// DEBUG: Print info about this dimension
1278+
std::cout << " Final Dim[" << (result.size() - 1) << "]:" << std::endl;
1279+
std::cout << " partitioned="
1280+
<< (result.back().partitioned ? result.back()
1281+
.partitioned->front()
1282+
->as<IterDomain>()
1283+
->extent()
1284+
->toString()
1285+
: "nullptr")
1286+
<< std::endl;
1287+
std::cout << " box="
1288+
<< (result.back().box ? result.back()
1289+
.box->front()
1290+
->as<IterDomain>()
1291+
->extent()
1292+
->toString()
1293+
: "nullptr (size-one)")
1294+
<< std::endl;
1295+
std::cout << " tile="
1296+
<< (result.back().tile ? result.back()
1297+
.tile->front()
1298+
->as<IterDomain>()
1299+
->extent()
1300+
->toString()
1301+
: "nullptr")
1302+
<< std::endl;
1303+
std::cout << " stride="
1304+
<< (result.back().stride ? result.back()
1305+
.stride->front()
1306+
->as<IterDomain>()
1307+
->extent()
1308+
->toString()
1309+
: "nullptr")
1310+
<< std::endl;
10661311
}
1312+
1313+
std::cout << "\n[TMA DEBUG] ===== Final Box Sizes =====" << std::endl;
1314+
for (size_t i = 0; i < result.size(); i++) {
1315+
std::cout
1316+
<< " Box[" << i << "] = "
1317+
<< (result[i].box
1318+
? result[i].box->front()->as<IterDomain>()->extent()->toString()
1319+
: "1 (implicit)")
1320+
<< std::endl;
1321+
}
1322+
10671323
return result;
10681324
}
10691325

@@ -1158,6 +1414,9 @@ Val* TMAInfo::tensorMap() const {
11581414
dims_.end(),
11591415
std::back_inserter(box_sizes_inner_to_outer),
11601416
[](const TMADim& d) { return d.boxSize(); });
1417+
for (auto box_size : box_sizes_inner_to_outer) {
1418+
std::cout << "box_size: " << box_size->toString() << std::endl;
1419+
}
11611420

11621421
std::vector<Val*> element_strides_inner_to_outer;
11631422
std::transform(

0 commit comments

Comments
 (0)