@@ -2456,9 +2456,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
24562456 mlir::OperationState &result, mlir::Value lb,
24572457 mlir::Value ub, mlir::Value step, bool unordered,
24582458 bool finalCountValue, mlir::ValueRange iterArgs,
2459+ mlir::ValueRange reduceOperands,
2460+ llvm::ArrayRef<mlir::Attribute> reduceAttrs,
24592461 llvm::ArrayRef<mlir::NamedAttribute> attributes) {
24602462 result.addOperands ({lb, ub, step});
2463+ result.addOperands (reduceOperands);
24612464 result.addOperands (iterArgs);
2465+ result.addAttribute (getOperandSegmentSizeAttr (),
2466+ builder.getDenseI32ArrayAttr (
2467+ {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2468+ static_cast <int32_t >(iterArgs.size ())}));
24622469 if (finalCountValue) {
24632470 result.addTypes (builder.getIndexType ());
24642471 result.addAttribute (getFinalValueAttrName (result.name ),
@@ -2477,6 +2484,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
24772484 if (unordered)
24782485 result.addAttribute (getUnorderedAttrName (result.name ),
24792486 builder.getUnitAttr ());
2487+ if (!reduceAttrs.empty ())
2488+ result.addAttribute (getReduceAttrsAttrName (result.name ),
2489+ builder.getArrayAttr (reduceAttrs));
24802490 result.addAttributes (attributes);
24812491}
24822492
@@ -2502,24 +2512,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
25022512 if (mlir::succeeded (parser.parseOptionalKeyword (" unordered" )))
25032513 result.addAttribute (" unordered" , builder.getUnitAttr ());
25042514
2515+ // Parse the reduction arguments.
2516+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
2517+ llvm::SmallVector<mlir::Type> reduceArgTypes;
2518+ if (succeeded (parser.parseOptionalKeyword (" reduce" ))) {
2519+ // Parse reduction attributes and variables.
2520+ llvm::SmallVector<ReduceAttr> attributes;
2521+ if (failed (parser.parseCommaSeparatedList (
2522+ mlir::AsmParser::Delimiter::Paren, [&]() {
2523+ if (parser.parseAttribute (attributes.emplace_back ()) ||
2524+ parser.parseArrow () ||
2525+ parser.parseOperand (reduceOperands.emplace_back ()) ||
2526+ parser.parseColonType (reduceArgTypes.emplace_back ()))
2527+ return mlir::failure ();
2528+ return mlir::success ();
2529+ })))
2530+ return mlir::failure ();
2531+ // Resolve input operands.
2532+ for (auto operand_type : llvm::zip (reduceOperands, reduceArgTypes))
2533+ if (parser.resolveOperand (std::get<0 >(operand_type),
2534+ std::get<1 >(operand_type), result.operands ))
2535+ return mlir::failure ();
2536+ llvm::SmallVector<mlir::Attribute> arrayAttr (attributes.begin (),
2537+ attributes.end ());
2538+ result.addAttribute (getReduceAttrsAttrName (result.name ),
2539+ builder.getArrayAttr (arrayAttr));
2540+ }
2541+
25052542 // Parse the optional initial iteration arguments.
25062543 llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2507- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands ;
2544+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands ;
25082545 llvm::SmallVector<mlir::Type> argTypes;
25092546 bool prependCount = false ;
25102547 regionArgs.push_back (inductionVariable);
25112548
25122549 if (succeeded (parser.parseOptionalKeyword (" iter_args" ))) {
25132550 // Parse assignment list and results type list.
2514- if (parser.parseAssignmentList (regionArgs, operands ) ||
2551+ if (parser.parseAssignmentList (regionArgs, iterOperands ) ||
25152552 parser.parseArrowTypeList (result.types ))
25162553 return mlir::failure ();
2517- if (result.types .size () == operands .size () + 1 )
2554+ if (result.types .size () == iterOperands .size () + 1 )
25182555 prependCount = true ;
25192556 // Resolve input operands.
25202557 llvm::ArrayRef<mlir::Type> resTypes = result.types ;
2521- for (auto operand_type :
2522- llvm::zip (operands , prependCount ? resTypes.drop_front () : resTypes))
2558+ for (auto operand_type : llvm::zip (
2559+ iterOperands , prependCount ? resTypes.drop_front () : resTypes))
25232560 if (parser.resolveOperand (std::get<0 >(operand_type),
25242561 std::get<1 >(operand_type), result.operands ))
25252562 return mlir::failure ();
@@ -2530,6 +2567,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
25302567 prependCount = true ;
25312568 }
25322569
2570+ // Set the operandSegmentSizes attribute
2571+ result.addAttribute (getOperandSegmentSizeAttr (),
2572+ builder.getDenseI32ArrayAttr (
2573+ {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2574+ static_cast <int32_t >(iterOperands.size ())}));
2575+
25332576 if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
25342577 return mlir::failure ();
25352578
@@ -2606,6 +2649,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
26062649
26072650 i++;
26082651 }
2652+ auto reduceAttrs = getReduceAttrsAttr ();
2653+ if (getNumReduceOperands () != (reduceAttrs ? reduceAttrs.size () : 0 ))
2654+ return emitOpError (
2655+ " mismatch in number of reduction variables and reduction attributes" );
26092656 return mlir::success ();
26102657}
26112658
@@ -2615,6 +2662,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
26152662 << getUpperBound () << " step " << getStep ();
26162663 if (getUnordered ())
26172664 p << " unordered" ;
2665+ if (hasReduceOperands ()) {
2666+ p << " reduce(" ;
2667+ auto attrs = getReduceAttrsAttr ();
2668+ auto operands = getReduceOperands ();
2669+ llvm::interleaveComma (llvm::zip (attrs, operands), p, [&](auto it) {
2670+ p << std::get<0 >(it) << " -> " << std::get<1 >(it) << " : "
2671+ << std::get<1 >(it).getType ();
2672+ });
2673+ p << ' )' ;
2674+ printBlockTerminators = true ;
2675+ }
26182676 if (hasIterOperands ()) {
26192677 p << " iter_args(" ;
26202678 auto regionArgs = getRegionIterArgs ();
@@ -2628,8 +2686,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
26282686 p << " -> " << getResultTypes ();
26292687 printBlockTerminators = true ;
26302688 }
2631- p.printOptionalAttrDictWithKeyword ((*this )->getAttrs (),
2632- {" unordered" , " finalValue" });
2689+ p.printOptionalAttrDictWithKeyword (
2690+ (*this )->getAttrs (),
2691+ {" unordered" , " finalValue" , " reduceAttrs" , " operandSegmentSizes" });
26332692 p << ' ' ;
26342693 p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false ,
26352694 printBlockTerminators);
0 commit comments