diff --git a/compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp b/compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp index a3267d8b2cea..5547a67c8257 100644 --- a/compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp @@ -8,6 +8,7 @@ #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Transforms/Passes.h" @@ -25,7 +26,7 @@ static void foldConstantBounds( FunctionOpInterface funcOp, const std::optional> &staticWorkgroupSizes, ArrayRef staticWorkgroupCounts, - std::optional subgroupSize) { + std::optional subgroupSize) { IRRewriter rewriter(funcOp->getContext()); auto rewriteToConstant = [&](Operation *op, int64_t constant) { rewriter.setInsertionPoint(op); @@ -70,13 +71,24 @@ static void foldConstantBounds( static void applyBounds(FunctionOpInterface funcOp, ArrayRef> workgroupSizes, ArrayRef> workgroupCounts, - std::optional subgroupSize) { + std::optional maxSubgroupSize, + std::optional subgroupIdBound) { Builder b(funcOp->getContext()); funcOp->walk([&](Operation *op) { TypeSwitch(op) .Case([&](gpu::LaneIdOp laneIdOp) { - if (subgroupSize) { - laneIdOp.setUpperBoundAttr(b.getIndexAttr(*subgroupSize)); + if (maxSubgroupSize) { + laneIdOp.setUpperBoundAttr(b.getIndexAttr(*maxSubgroupSize)); + } + }) + .Case([&](gpu::SubgroupSizeOp subgroupSizeOp) { + if (maxSubgroupSize) { + subgroupSizeOp.setUpperBoundAttr(b.getIndexAttr(*maxSubgroupSize)); + } + }) + .Case([&](gpu::SubgroupIdOp subgroupIdOp) { + if (subgroupIdBound) { + subgroupIdOp.setUpperBoundAttr(b.getIndexAttr(*subgroupIdBound)); } }) .Case([&](gpu::ThreadIdOp tidOp) { @@ -143,7 +155,9 @@ struct PropagateDispatchSizeBoundsPass final std::optional> staticWorkgroupSize = getWorkgroupSize(funcOp); - std::optional subgroupSize = getGPUSubgroupSize(funcOp); + // Check if a specific subgroup size has been explicitly chosen via the + // codegen pipeline configuration. + std::optional staticSubgroupSize = getSubgroupSize(funcOp); // Late in codegen, we've reconciled the workgroup size onto the export op. if (std::optional exportOp = @@ -157,7 +171,25 @@ struct PropagateDispatchSizeBoundsPass final if (std::optional exportSubgroupSize = exportOp->getSubgroupSizeAsUInt()) { - subgroupSize = exportSubgroupSize; + staticSubgroupSize = static_cast(*exportSubgroupSize); + } + } + + // Determine min and max subgroup size bounds. When a specific subgroup + // size has been picked, min == max == that size. Otherwise, use the + // range from the GPU target's WGP info. + std::optional minSubgroupSize; + std::optional maxSubgroupSize; + if (staticSubgroupSize) { + minSubgroupSize = maxSubgroupSize = staticSubgroupSize; + } else if (target) { + assert(!target.getWgp().getSubgroupSizeChoices().empty() && + "GPU target must have at least one subgroup size choice"); + minSubgroupSize = target.getMinSubgroupSize(); + maxSubgroupSize = target.getMaxSubgroupSize(); + if (*minSubgroupSize == *maxSubgroupSize) { + // There's only one option, so we know what it is. + staticSubgroupSize = maxSubgroupSize; } } @@ -179,9 +211,27 @@ struct PropagateDispatchSizeBoundsPass final } } + // Compute the subgroup ID bound: max total threads / min subgroup size. + std::optional maxFlatWorkgroupSize; + std::optional subgroupIdBound; + if (staticWorkgroupSize) { + maxFlatWorkgroupSize = llvm::product_of(*staticWorkgroupSize); + } + if (target) { + maxFlatWorkgroupSize = std::min( + maxFlatWorkgroupSize.value_or(std::numeric_limits::max()), + static_cast( + target.getWgp().getMaxThreadCountPerWorkgroup())); + } + if (maxFlatWorkgroupSize && minSubgroupSize) { + subgroupIdBound = + llvm::divideCeil(*maxFlatWorkgroupSize, *minSubgroupSize); + } + foldConstantBounds(funcOp, staticWorkgroupSize, staticWorkgroupCounts, - subgroupSize); - applyBounds(funcOp, workgroupSizes, workgroupCounts, subgroupSize); + staticSubgroupSize); + applyBounds(funcOp, workgroupSizes, workgroupCounts, maxSubgroupSize, + subgroupIdBound); } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_size_bounds.mlir b/compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_size_bounds.mlir index 53ea67a81a20..0230df3d0319 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_size_bounds.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_size_bounds.mlir @@ -4,9 +4,9 @@ // Note: not the real target definition, missing types #executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target>}> +#pipeline_layout = #hal.pipeline.layout]> + +hal.executable private @gfx1100_variable_subgroup { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target) { + hal.executable.export public @gfx1100_variable_subgroup ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + hal.return %c128, %c1, %c1 : index, index, index + } attributes {workgroup_size = [128 : index, 1 : index, 1 : index]} + builtin.module { +// CHECK-LABEL: func.func @gfx1100_variable_subgroup() + func.func @gfx1100_variable_subgroup() { // CHECK-NEXT: gpu.lane_id upper_bound 64 %lane_id = gpu.lane_id +// CHECK-NEXT: gpu.subgroup_id upper_bound 4 : index + %subgroup_id = gpu.subgroup_id : index + +// CHECK-NEXT: gpu.subgroup_size upper_bound 64 : index + %subgroup_size = gpu.subgroup_size : index + + return + } + } + } +} + +// ----- + +// Test pseudo-variable subgroup sizes on gfx942 (subgroup_size_choices = [64]) +// with static workgroup sizes but no explicit subgroup_size selection in case +// that ever comes up. +#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", + {iree_codegen.target_info = #iree_gpu.target>}> +#pipeline_layout = #hal.pipeline.layout]> + +hal.executable private @gfx942_not_really_variable_subgroup { + hal.executable.variant public @rocm_hsaco_fb target(#executable_target) { + hal.executable.export public @gfx942_not_really_variable_subgroup ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) { + %c128 = arith.constant 128 : index + %c1 = arith.constant 1 : index + hal.return %c128, %c1, %c1 : index, index, index + } attributes {workgroup_size = [128 : index, 1 : index, 1 : index]} + builtin.module { +// CHECK-LABEL: func.func @gfx942_not_really_variable_subgroup() + func.func @gfx942_not_really_variable_subgroup() { +// CHECK-NEXT: gpu.lane_id upper_bound 64 + %lane_id = gpu.lane_id + +// CHECK-NEXT: gpu.subgroup_id upper_bound 2 : index + %subgroup_id = gpu.subgroup_id : index + // CHECK-NEXT: arith.constant 64 : index %subgroup_size = gpu.subgroup_size : index @@ -110,8 +200,8 @@ hal.executable private @manual_subgroup_size { #executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.target_info = #iree_gpu.target