diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp index ef44a0ec68d9c..db6b472ff9733 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp @@ -24,10 +24,10 @@ using namespace mlir; namespace { -template -static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { - llvm::SmallVector> values; - for (auto operand : op.getDataClauseOperands()) { +static void collectPtrs(mlir::ValueRange operands, + llvm::SmallVector> &values, + bool hostToDevice) { + for (auto operand : operands) { Value varPtr = acc::getVarPtr(operand.getDefiningOp()); Value accPtr = acc::getAccPtr(operand.getDefiningOp()); if (varPtr && accPtr) { @@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { values.push_back({accPtr, varPtr}); } } +} + +template +static void collectAndReplaceInRegion(Op &op, bool hostToDevice) { + llvm::SmallVector> values; + + if constexpr (std::is_same_v) { + collectPtrs(op.getReductionOperands(), values, hostToDevice); + collectPtrs(op.getPrivateOperands(), values, hostToDevice); + } else { + collectPtrs(op.getDataClauseOperands(), values, hostToDevice); + if constexpr (!std::is_same_v) { + collectPtrs(op.getReductionOperands(), values, hostToDevice); + collectPtrs(op.getGangPrivateOperands(), values, hostToDevice); + collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice); + } + } for (auto p : values) replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion()); @@ -50,7 +67,7 @@ struct LegalizeDataInRegion bool replaceHostVsDevice = this->hostToDevice.getValue(); funcOp.walk([&](Operation *op) { - if (!isa(*op)) + if (!isa(*op) && !isa(*op)) return; if (auto parallelOp = dyn_cast(*op)) { @@ -59,6 +76,8 @@ struct LegalizeDataInRegion collectAndReplaceInRegion(serialOp, replaceHostVsDevice); } else if (auto kernelsOp = dyn_cast(*op)) { collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice); + } else if (auto loopOp = dyn_cast(*op)) { + collectAndReplaceInRegion(loopOp, replaceHostVsDevice); } }); } diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir index 4c86223c720a3..113fe90450ab7 100644 --- a/mlir/test/Dialect/OpenACC/legalize-data.mlir +++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir @@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) { // CHECK: } // CHECK: acc.yield // CHECK: } + +// ----- + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +func.func @test(%a: memref<10xf32>) { + %lb = arith.constant 0 : index + %st = arith.constant 1 : index + %c10 = arith.constant 10 : index + %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) { + acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { + %ci = memref.load %a[%i] : memref<10xf32> + acc.yield + } + acc.yield + } + return +} + +// CHECK: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>) +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32> +// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) { +// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { +// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> +// CHECK: acc.yield +// CHECK: } +// CHECK: acc.yield +// CHECK: } + +// ----- + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +func.func @test(%a: memref<10xf32>) { + %lb = arith.constant 0 : index + %st = arith.constant 1 : index + %c10 = arith.constant 10 : index + %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.parallel { + acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { + %ci = memref.load %a[%i] : memref<10xf32> + acc.yield + } + acc.yield + } + return +} + +// CHECK: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>) +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32> +// CHECK: acc.parallel { +// CHECK: acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { +// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> +// CHECK: acc.yield +// CHECK: } +// CHECK: acc.yield +// CHECK: } + +// ----- + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +func.func @test(%a: memref<10xf32>) { + %lb = arith.constant 0 : index + %st = arith.constant 1 : index + %c10 = arith.constant 10 : index + %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32> + acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) { + acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) { + %ci = memref.load %a[%i] : memref<10xf32> + acc.yield + } + acc.yield + } + return +} + +// CHECK: func.func @test +// CHECK-SAME: (%[[A:.*]]: memref<10xf32>) +// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32> +// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) { +// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) { +// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32> +// CHECK: acc.yield +// CHECK: } +// CHECK: acc.yield +// CHECK: }