@@ -327,12 +327,13 @@ static int canonicalize(mlir::MLIRContext &context,
327327 optPM.addPass (mlir::createAffineScalarReplacementPass ());
328328 }
329329 if (mlir::failed (pm.run (module .get ()))) {
330- llvm::errs () << " Canonicalization failed. Module: ***\n " ;
330+ llvm::errs () << " *** Canonicalization failed. Module: ***\n " ;
331331 module ->dump ();
332332 return 4 ;
333333 }
334334 if (mlir::failed (mlir::verify (module .get ()))) {
335- llvm::errs () << " Verification after canonicalization failed. Module: ***\n " ;
335+ llvm::errs ()
336+ << " *** Verification after canonicalization failed. Module: ***\n " ;
336337 module ->dump ();
337338 return 5 ;
338339 }
@@ -384,12 +385,12 @@ static int optimize(mlir::MLIRContext &context,
384385 }
385386
386387 if (mlir::failed (pm.run (module .get ()))) {
387- llvm::errs () << " Optimize failed. Module: ***\n " ;
388+ llvm::errs () << " *** Optimize failed. Module: ***\n " ;
388389 module ->dump ();
389390 return 6 ;
390391 }
391392 if (mlir::failed (mlir::verify (module .get ()))) {
392- llvm::errs () << " Verification after optimization failed. Module: ***\n " ;
393+ llvm::errs () << " *** Verification after optimization failed. Module: ***\n " ;
393394 module ->dump ();
394395 return 7 ;
395396 }
@@ -462,12 +463,13 @@ static int optimizeCUDA(mlir::MLIRContext &context,
462463 noptPM2.addPass (mlir::createAffineScalarReplacementPass ());
463464 }
464465 if (mlir::failed (pm.run (module .get ()))) {
465- llvm::errs () << " Optimize CUDA failed. Module: ***\n " ;
466+ llvm::errs () << " *** Optimize CUDA failed. Module: ***\n " ;
466467 module ->dump ();
467468 return 8 ;
468469 }
469- if (mlir::failed (mlir::verify (module .get ()))) {
470- llvm::errs () << " Verification after CUDA optimization failed. Module: ***\n " ;
470+ if (mlir::failed (mlir::verify (module .get ()))) {
471+ llvm::errs ()
472+ << " *** Verification after CUDA optimization failed. Module: ***\n " ;
471473 module ->dump ();
472474 return 9 ;
473475 }
@@ -479,80 +481,89 @@ static int optimizeCUDA(mlir::MLIRContext &context,
479481 return 0 ;
480482}
481483
482- static int finalize (mlir::MLIRContext &context,
483- mlir::OwningOpRef<mlir::ModuleOp> &module ,
484- llvm::DataLayout &DL, bool &LinkOMP) {
484+ static void finalizeCUDA (mlir::PassManager &pm) {
485+ if (!CudaLower)
486+ return ;
487+
488+ mlir::OpPassManager &optPM = pm.nest <mlir::func::FuncOp>();
489+
485490 constexpr int unrollSize = 32 ;
486491 GreedyRewriteConfig canonicalizerConfig;
487492 canonicalizerConfig.maxIterations = CanonicalizeIterations;
488493
489- mlir::PassManager pm (&context);
490- mlir::OpPassManager &optPM = pm.nest <mlir::func::FuncOp>();
494+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
495+ optPM.addPass (mlir::createCSEPass ());
496+ optPM.addPass (polygeist::createMem2RegPass ());
497+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
498+ optPM.addPass (mlir::createCSEPass ());
499+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
500+ optPM.addPass (polygeist::createCanonicalizeForPass ());
501+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
491502
492- if (CudaLower) {
503+ if (RaiseToAffine) {
504+ optPM.addPass (polygeist::createCanonicalizeForPass ());
493505 optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
494- optPM.addPass (mlir::createCSEPass ());
495- optPM.addPass (polygeist::createMem2RegPass ());
506+ if (ParallelLICM)
507+ optPM.addPass (polygeist::createParallelLICMPass ());
508+ else
509+ optPM.addPass (mlir::createLoopInvariantCodeMotionPass ());
510+ optPM.addPass (polygeist::createRaiseSCFToAffinePass ());
496511 optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
497- optPM.addPass (mlir::createCSEPass ());
512+ optPM.addPass (polygeist::replaceAffineCFGPass ());
498513 optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
514+ if (ScalarReplacement)
515+ optPM.addPass (mlir::createAffineScalarReplacementPass ());
516+ }
517+ if (ToCPU == " continuation" ) {
518+ optPM.addPass (polygeist::createBarrierRemovalContinuation ());
519+ // pm.nest<mlir::FuncOp>().addPass(mlir::createCanonicalizerPass());
520+ } else if (ToCPU.size () != 0 ) {
521+ optPM.addPass (polygeist::createCPUifyPass (ToCPU));
522+ }
523+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
524+ optPM.addPass (mlir::createCSEPass ());
525+ optPM.addPass (polygeist::createMem2RegPass ());
526+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
527+ optPM.addPass (mlir::createCSEPass ());
528+ if (RaiseToAffine) {
499529 optPM.addPass (polygeist::createCanonicalizeForPass ());
500530 optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
501-
502- if (RaiseToAffine) {
503- optPM.addPass (polygeist::createCanonicalizeForPass ());
504- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
505- if (ParallelLICM)
506- optPM.addPass (polygeist::createParallelLICMPass ());
507- else
508- optPM.addPass (mlir::createLoopInvariantCodeMotionPass ());
509- optPM.addPass (polygeist::createRaiseSCFToAffinePass ());
510- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
511- optPM.addPass (polygeist::replaceAffineCFGPass ());
512- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
513- if (ScalarReplacement)
514- optPM.addPass (mlir::createAffineScalarReplacementPass ());
515- }
516- if (ToCPU == " continuation" ) {
517- optPM.addPass (polygeist::createBarrierRemovalContinuation ());
518- // pm.nest<mlir::FuncOp>().addPass(mlir::createCanonicalizerPass());
519- } else if (ToCPU.size () != 0 ) {
520- optPM.addPass (polygeist::createCPUifyPass (ToCPU));
521- }
531+ if (ParallelLICM)
532+ optPM.addPass (polygeist::createParallelLICMPass ());
533+ else
534+ optPM.addPass (mlir::createLoopInvariantCodeMotionPass ());
535+ optPM.addPass (polygeist::createRaiseSCFToAffinePass ());
536+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
537+ optPM.addPass (polygeist::replaceAffineCFGPass ());
538+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
539+ if (LoopUnroll)
540+ optPM.addPass (mlir::createLoopUnrollPass (unrollSize, false , true ));
522541 optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
523542 optPM.addPass (mlir::createCSEPass ());
524543 optPM.addPass (polygeist::createMem2RegPass ());
525544 optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
526- optPM.addPass (mlir::createCSEPass ());
527- if (RaiseToAffine) {
528- optPM.addPass (polygeist::createCanonicalizeForPass ());
529- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
530- if (ParallelLICM)
531- optPM.addPass (polygeist::createParallelLICMPass ());
532- else
533- optPM.addPass (mlir::createLoopInvariantCodeMotionPass ());
534- optPM.addPass (polygeist::createRaiseSCFToAffinePass ());
535- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
536- optPM.addPass (polygeist::replaceAffineCFGPass ());
537- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
538- if (LoopUnroll)
539- optPM.addPass (mlir::createLoopUnrollPass (unrollSize, false , true ));
540- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
541- optPM.addPass (mlir::createCSEPass ());
542- optPM.addPass (polygeist::createMem2RegPass ());
543- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
544- if (ParallelLICM)
545- optPM.addPass (polygeist::createParallelLICMPass ());
546- else
547- optPM.addPass (mlir::createLoopInvariantCodeMotionPass ());
548- optPM.addPass (polygeist::createRaiseSCFToAffinePass ());
549- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
550- optPM.addPass (polygeist::replaceAffineCFGPass ());
551- optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
552- if (ScalarReplacement)
553- optPM.addPass (mlir::createAffineScalarReplacementPass ());
554- }
545+ if (ParallelLICM)
546+ optPM.addPass (polygeist::createParallelLICMPass ());
547+ else
548+ optPM.addPass (mlir::createLoopInvariantCodeMotionPass ());
549+ optPM.addPass (polygeist::createRaiseSCFToAffinePass ());
550+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
551+ optPM.addPass (polygeist::replaceAffineCFGPass ());
552+ optPM.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
553+ if (ScalarReplacement)
554+ optPM.addPass (mlir::createAffineScalarReplacementPass ());
555555 }
556+ }
557+
558+ static int finalize (mlir::MLIRContext &context,
559+ mlir::OwningOpRef<mlir::ModuleOp> &module ,
560+ llvm::DataLayout &DL, bool &LinkOMP) {
561+ mlir::PassManager pm (&context);
562+ GreedyRewriteConfig canonicalizerConfig;
563+ canonicalizerConfig.maxIterations = CanonicalizeIterations;
564+
565+ finalizeCUDA (pm);
566+
556567 pm.addPass (mlir::createSymbolDCEPass ());
557568
558569 if (EmitLLVM || !EmitAssembly || EmitOpenMPIR) {
@@ -562,10 +573,21 @@ static int finalize(mlir::MLIRContext &context,
562573
563574 // pm.nest<mlir::FuncOp>().addPass(mlir::createConvertMathToLLVMPass());
564575 if (mlir::failed (pm.run (module .get ()))) {
565- llvm::errs () << " Finalize failed. Module: ***\n " ;
576+ llvm::errs () << " *** Finalize failed (phase 1). Module: ***\n " ;
566577 module ->dump ();
567578 return 10 ;
568579 }
580+ if (mlir::failed (mlir::verify (module .get ()))) {
581+ llvm::errs () << " *** Verification after finalization failed (phase 1). "
582+ " Module: ***\n " ;
583+ module ->dump ();
584+ return 11 ;
585+ }
586+ LLVM_DEBUG ({
587+ llvm::dbgs () << " *** Module after finalize (phase 1) ***\n " ;
588+ module ->dump ();
589+ });
590+
569591 mlir::PassManager pm2 (&context);
570592 if (SCFOpenMP) {
571593 pm2.addPass (createConvertSCFToOpenMPPass ());
@@ -579,10 +601,21 @@ static int finalize(mlir::MLIRContext &context,
579601 pm2.addPass (mlir::createCSEPass ());
580602 pm2.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
581603 if (mlir::failed (pm2.run (module .get ()))) {
582- llvm::errs () << " Finalize failed. Module: ***\n " ;
604+ llvm::errs () << " *** Finalize failed (phase 2). Module: ***\n " ;
583605 module ->dump ();
584- return 11 ;
606+ return 12 ;
607+ }
608+ if (mlir::failed (mlir::verify (module .get ()))) {
609+ llvm::errs () << " *** Verification after finalization failed (phase 2). "
610+ " Module: ***\n " ;
611+ module ->dump ();
612+ return 13 ;
585613 }
614+ LLVM_DEBUG ({
615+ llvm::dbgs () << " *** Module after finalize (phase 2) ***\n " ;
616+ module ->dump ();
617+ });
618+
586619 if (!EmitOpenMPIR) {
587620 module ->walk ([&](mlir::omp::ParallelOp) { LinkOMP = true ; });
588621 mlir::PassManager pm3 (&context);
@@ -594,29 +627,38 @@ static int finalize(mlir::MLIRContext &context,
594627 // pm3.addPass(mlir::createLowerFuncToLLVMPass(options));
595628 pm3.addPass (mlir::createCanonicalizerPass (canonicalizerConfig, {}, {}));
596629 if (mlir::failed (pm3.run (module .get ()))) {
597- llvm::errs () << " Finalize failed. Module: ***\n " ;
630+ llvm::errs () << " *** Finalize failed (phase 3). Module: ***\n " ;
631+ module ->dump ();
632+ return 14 ;
633+ }
634+ if (mlir::failed (mlir::verify (module .get ()))) {
635+ llvm::errs () << " *** Verification after finalization failed (phase 3). "
636+ " Module: ***\n " ;
598637 module ->dump ();
599- return 12 ;
638+ return 15 ;
600639 }
640+ LLVM_DEBUG ({
641+ llvm::dbgs () << " *** Module after finalize (phase 3) ***\n " ;
642+ module ->dump ();
643+ });
601644 }
602645 } else {
603646 if (mlir::failed (pm.run (module .get ()))) {
604- llvm::errs () << " Finalize failed. Module: ***\n " ;
647+ llvm::errs () << " *** Finalize failed. Module: ***\n " ;
605648 module ->dump ();
606- return 13 ;
649+ return 16 ;
607650 }
651+ if (mlir::failed (mlir::verify (module .get ()))) {
652+ llvm::errs () << " *** Verification after finalization failed. "
653+ " Module: ***\n " ;
654+ module ->dump ();
655+ return 17 ;
656+ }
657+ LLVM_DEBUG ({
658+ llvm::dbgs () << " *** Module after finalize ***\n " ;
659+ module ->dump ();
660+ });
608661 }
609-
610- if (mlir::failed (mlir::verify (module .get ()))) {
611- llvm::errs () << " Verification after finalization failed. Module: ***\n " ;
612- module ->dump ();
613- return 14 ;
614- }
615-
616- LLVM_DEBUG ({
617- llvm::dbgs () << " *** Module after finalize ***\n " ;
618- module ->dump ();
619- });
620662
621663 return 0 ;
622664}
0 commit comments