Skip to content

Commit 8adfb9f

Browse files
committed
Simplify code per review comments
1 parent 5ac45db commit 8adfb9f

1 file changed

Lines changed: 13 additions & 25 deletions

File tree

compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -208,32 +208,20 @@ struct PropagateDispatchSizeBoundsPass final
208208
}
209209

210210
// Compute the subgroup ID bound: max total threads / min subgroup size.
211+
std::optional<int64_t> maxFlatWorkgroupSize;
211212
std::optional<int64_t> subgroupIdBound;
212-
if (minSubgroupSize) {
213-
int64_t maxTotalThreads = 1;
214-
bool allSizesKnown = true;
215-
for (std::optional<int64_t> size : workgroupSizes) {
216-
if (size) {
217-
maxTotalThreads *= *size;
218-
// Cap at the hardware thread-per-workgroup limit inside the loop
219-
// to avoid overflow from multiplying per-dimension maximums.
220-
if (target) {
221-
maxTotalThreads =
222-
std::min(maxTotalThreads,
223-
static_cast<int64_t>(
224-
target.getWgp().getMaxThreadCountPerWorkgroup()));
225-
}
226-
} else {
227-
allSizesKnown = false;
228-
break;
229-
}
230-
}
231-
if (!allSizesKnown && target) {
232-
maxTotalThreads = target.getWgp().getMaxThreadCountPerWorkgroup();
233-
}
234-
if (allSizesKnown || target) {
235-
subgroupIdBound = llvm::divideCeil(maxTotalThreads, *minSubgroupSize);
236-
}
213+
if (staticWorkgroupSize) {
214+
maxFlatWorkgroupSize = llvm::product_of(*staticWorkgroupSize);
215+
}
216+
if (target) {
217+
maxFlatWorkgroupSize = std::min(
218+
maxFlatWorkgroupSize.value_or(std::numeric_limits<int64_t>::max()),
219+
static_cast<int64_t>(
220+
target.getWgp().getMaxThreadCountPerWorkgroup()));
221+
}
222+
if (maxFlatWorkgroupSize && minSubgroupSize) {
223+
subgroupIdBound =
224+
llvm::divideCeil(*maxFlatWorkgroupSize, *minSubgroupSize);
237225
}
238226

239227
std::optional<int64_t> constantSubgroupSize;

0 commit comments

Comments
 (0)