@@ -208,32 +208,20 @@ struct PropagateDispatchSizeBoundsPass final
208208 }
209209
210210 // Compute the subgroup ID bound: max total threads / min subgroup size.
211+ std::optional<int64_t > maxFlatWorkgroupSize;
211212 std::optional<int64_t > subgroupIdBound;
212- if (minSubgroupSize) {
213- int64_t maxTotalThreads = 1 ;
214- bool allSizesKnown = true ;
215- for (std::optional<int64_t > size : workgroupSizes) {
216- if (size) {
217- maxTotalThreads *= *size;
218- // Cap at the hardware thread-per-workgroup limit inside the loop
219- // to avoid overflow from multiplying per-dimension maximums.
220- if (target) {
221- maxTotalThreads =
222- std::min (maxTotalThreads,
223- static_cast <int64_t >(
224- target.getWgp ().getMaxThreadCountPerWorkgroup ()));
225- }
226- } else {
227- allSizesKnown = false ;
228- break ;
229- }
230- }
231- if (!allSizesKnown && target) {
232- maxTotalThreads = target.getWgp ().getMaxThreadCountPerWorkgroup ();
233- }
234- if (allSizesKnown || target) {
235- subgroupIdBound = llvm::divideCeil (maxTotalThreads, *minSubgroupSize);
236- }
213+ if (staticWorkgroupSize) {
214+ maxFlatWorkgroupSize = llvm::product_of (*staticWorkgroupSize);
215+ }
216+ if (target) {
217+ maxFlatWorkgroupSize = std::min (
218+ maxFlatWorkgroupSize.value_or (std::numeric_limits<int64_t >::max ()),
219+ static_cast <int64_t >(
220+ target.getWgp ().getMaxThreadCountPerWorkgroup ()));
221+ }
222+ if (maxFlatWorkgroupSize && minSubgroupSize) {
223+ subgroupIdBound =
224+ llvm::divideCeil (*maxFlatWorkgroupSize, *minSubgroupSize);
237225 }
238226
239227 std::optional<int64_t > constantSubgroupSize;
0 commit comments