Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
def SCFToEmitC : Pass<"convert-scf-to-emitc"> {
let summary = "Convert SCF dialect to EmitC dialect, maintaining structured"
" control flow";
let dependentDialects = ["emitc::EmitCDialect"];
let dependentDialects = ["emitc::EmitCDialect", "memref::MemRefDialect"];
}

//===----------------------------------------------------------------------===//
Expand Down
52 changes: 40 additions & 12 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
Expand Down Expand Up @@ -63,21 +65,41 @@ static SmallVector<Value> createVariablesForResults(T op,

for (OpResult result : op.getResults()) {
Type resultType = result.getType();
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
SmallVector<OpFoldResult> dimensions = {rewriter.getIndexAttr(1)};
memref::AllocaOp var =
rewriter.create<memref::AllocaOp>(loc, dimensions, resultType);
resultVariables.push_back(var);
}

return resultVariables;
}

// Create a series of assign ops assigning given values to given variables at
// Create a series of load ops reading the values of given variables at
// the current insertion point of given rewriter.
static SmallVector<Value> readValues(SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
Value zero;
if (variables.size() > 0)
zero = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
SmallVector<Value> values;
SmallVector<Value> indices = {zero};
for (Value var : variables)
values.push_back(
rewriter.create<memref::LoadOp>(loc, var, indices).getResult());
return values;
}

// Create a series of store ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
Value zero;
if (variables.size() > 0)
zero = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
for (auto [value, var] : llvm::zip(values, variables)) {
SmallVector<Value> indices = {zero};
rewriter.create<memref::StoreOp>(loc, value, var, indices);
}
}

static void lowerYield(SmallVector<Value> &resultVariables,
Expand All @@ -100,8 +122,6 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables =
createVariablesForResults(forOp, rewriter);
SmallVector<Value> iterArgsVariables =
createVariablesForResults(forOp, rewriter);

Expand All @@ -115,18 +135,25 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// Erase the auto-generated terminator for the lowered for op.
rewriter.eraseOp(loweredBody->getTerminator());

IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToEnd(loweredBody);
SmallVector<Value> iterArgsValues =
readValues(iterArgsVariables, rewriter, loc);
rewriter.restoreInsertionPoint(ip);

SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());

rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
lowerYield(iterArgsVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

// Copy iterArgs into results after the for loop.
assignValues(iterArgsVariables, resultVariables, rewriter, loc);
SmallVector<Value> resultValues =
readValues(iterArgsVariables, rewriter, loc);

rewriter.replaceOp(forOp, resultVariables);
rewriter.replaceOp(forOp, resultValues);
return success();
}

Expand Down Expand Up @@ -169,6 +196,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,

auto loweredIf =
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
SmallVector<Value> resultValues = readValues(resultVariables, rewriter, loc);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(thenRegion, loweredThenRegion);
Expand All @@ -178,7 +206,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
lowerRegion(elseRegion, loweredElseRegion);
}

rewriter.replaceOp(ifOp, resultVariables);
rewriter.replaceOp(ifOp, resultValues);
return success();
}

Expand Down
64 changes: 38 additions & 26 deletions mlir/test/Conversion/SCFToEmitC/for.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,24 @@ func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32)
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) {
// CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK-NEXT: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK-NEXT: %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_7]] : f32
// CHECK-NEXT: emitc.assign %[[VAL_4]] : f32 to %[[VAL_8]] : f32
// CHECK-NEXT: emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_7]] : f32
// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
// CHECK-NEXT: %[[VAL_5:.*]] = memref.alloca() : memref<1xf32>
// CHECK-NEXT: %[[VAL_6:.*]] = memref.alloca() : memref<1xf32>
// CHECK-NEXT: %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-NEXT: memref.store %[[VAL_3]], %[[VAL_5]]{{\[}}%[[VAL_7]]] : memref<1xf32>
// CHECK-NEXT: memref.store %[[VAL_4]], %[[VAL_6]]{{\[}}%[[VAL_7]]] : memref<1xf32>
// CHECK-NEXT: emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
// CHECK-NEXT: %[[VAL_9:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<1xf32>
// CHECK-NEXT: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref<1xf32>
// CHECK-NEXT: %[[VAL_12:.*]] = arith.addf %[[VAL_10]], %[[VAL_11]] : f32
// CHECK-NEXT: %[[VAL_13:.*]] = arith.constant 0 : index
// CHECK-NEXT: memref.store %[[VAL_12]], %[[VAL_5]]{{\[}}%[[VAL_13]]] : memref<1xf32>
// CHECK-NEXT: memref.store %[[VAL_12]], %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
// CHECK-NEXT: emitc.assign %[[VAL_8]] : f32 to %[[VAL_6]] : f32
// CHECK-NEXT: return %[[VAL_5]], %[[VAL_6]] : f32, f32
// CHECK-NEXT: %[[VAL_14:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[VAL_15:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_14]]] : memref<1xf32>
// CHECK-NEXT: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref<1xf32>
// CHECK-NEXT: return %[[VAL_15]], %[[VAL_16]] : f32, f32
// CHECK-NEXT: }

func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
Expand All @@ -77,20 +81,28 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
// CHECK-LABEL: func.func @nested_for_yield(
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 {
// CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_5]] : f32
// CHECK-NEXT: %[[VAL_4:.*]] = memref.alloca() : memref<1xf32>
// CHECK-NEXT: %[[VAL_5:.*]] = arith.constant 0 : index
// CHECK-NEXT: memref.store %[[VAL_3]], %[[VAL_4]]{{\[}}%[[VAL_5]]] : memref<1xf32>
// CHECK-NEXT: emitc.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
// CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK-NEXT: %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK-NEXT: emitc.assign %[[VAL_5]] : f32 to %[[VAL_8]] : f32
// CHECK-NEXT: emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_8]] : f32
// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
// CHECK-NEXT: %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_7]]] : memref<1xf32>
// CHECK-NEXT: %[[VAL_9:.*]] = memref.alloca() : memref<1xf32>
// CHECK-NEXT: %[[VAL_10:.*]] = arith.constant 0 : index
// CHECK-NEXT: memref.store %[[VAL_8]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<1xf32>
// CHECK-NEXT: emitc.for %[[VAL_11:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
// CHECK-NEXT: %[[VAL_12:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<1xf32>
// CHECK-NEXT: %[[VAL_14:.*]] = arith.addf %[[VAL_13]], %[[VAL_13]] : f32
// CHECK-NEXT: %[[VAL_15:.*]] = arith.constant 0 : index
// CHECK-NEXT: memref.store %[[VAL_14]], %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: emitc.assign %[[VAL_8]] : f32 to %[[VAL_7]] : f32
// CHECK-NEXT: emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
// CHECK-NEXT: %[[VAL_16:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<1xf32>
// CHECK-NEXT: %[[VAL_18:.*]] = arith.constant 0 : index
// CHECK-NEXT: memref.store %[[VAL_17]], %[[VAL_4]]{{\[}}%[[VAL_18]]] : memref<1xf32>
// CHECK-NEXT: }
// CHECK-NEXT: emitc.assign %[[VAL_5]] : f32 to %[[VAL_4]] : f32
// CHECK-NEXT: return %[[VAL_4]] : f32
// CHECK-NEXT: %[[VAL_19:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[VAL_20:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_19]]] : memref<1xf32>
// CHECK-NEXT: return %[[VAL_20]] : f32
// CHECK-NEXT: }
21 changes: 13 additions & 8 deletions mlir/test/Conversion/SCFToEmitC/if.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,23 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
// CHECK-SAME: %[[VAL_0:.*]]: i1,
// CHECK-SAME: %[[VAL_1:.*]]: f32) {
// CHECK-NEXT: %[[VAL_2:.*]] = arith.constant 0 : i8
// CHECK-NEXT: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
// CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f64
// CHECK-NEXT: %[[VAL_3:.*]] = memref.alloca() : memref<1xi32>
// CHECK-NEXT: %[[VAL_4:.*]] = memref.alloca() : memref<1xf64>
// CHECK-NEXT: emitc.if %[[VAL_0]] {
// CHECK-NEXT: %[[VAL_5:.*]] = emitc.call_opaque "func_true_1"(%[[VAL_1]]) : (f32) -> i32
// CHECK-NEXT: %[[VAL_6:.*]] = emitc.call_opaque "func_true_2"(%[[VAL_1]]) : (f32) -> f64
// CHECK-NEXT: emitc.assign %[[VAL_5]] : i32 to %[[VAL_3]] : i32
// CHECK-NEXT: emitc.assign %[[VAL_6]] : f64 to %[[VAL_4]] : f64
// CHECK-NEXT: %[[VAL_7:.*]] = arith.constant 0 : index
// CHECK-NEXT: memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<1xi32>
// CHECK-NEXT: memref.store %[[VAL_6]], %[[VAL_4]]{{\[}}%[[VAL_7]]] : memref<1xf64>
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[VAL_7:.*]] = emitc.call_opaque "func_false_1"(%[[VAL_1]]) : (f32) -> i32
// CHECK-NEXT: %[[VAL_8:.*]] = emitc.call_opaque "func_false_2"(%[[VAL_1]]) : (f32) -> f64
// CHECK-NEXT: emitc.assign %[[VAL_7]] : i32 to %[[VAL_3]] : i32
// CHECK-NEXT: emitc.assign %[[VAL_8]] : f64 to %[[VAL_4]] : f64
// CHECK-NEXT: %[[VAL_8:.*]] = emitc.call_opaque "func_false_1"(%[[VAL_1]]) : (f32) -> i32
// CHECK-NEXT: %[[VAL_9:.*]] = emitc.call_opaque "func_false_2"(%[[VAL_1]]) : (f32) -> f64
// CHECK-NEXT: %[[VAL_10:.*]] = arith.constant 0 : index
// CHECK-NEXT: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_10]]] : memref<1xi32>
// CHECK-NEXT: memref.store %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_10]]] : memref<1xf64>
// CHECK-NEXT: }
// CHECK-NEXT: %[[VAL_11:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[VAL_12:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_11]]] : memref<1xi32>
// CHECK-NEXT: %[[VAL_13:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_11]]] : memref<1xf64>
// CHECK-NEXT: return
// CHECK-NEXT: }
Loading