@@ -139,6 +139,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
139139 cl::desc (" Run the verifier after each transformation pass" ),
140140 cl::location (verifyPassesFlag), cl::init (true ));
141141
142+ static cl::opt<bool , /* ExternalStorage=*/ true > verifyRoundtrip (
143+ " verify-roundtrip" ,
144+ cl::desc (" Round-trip the IR after parsing and ensure it succeeds" ),
145+ cl::location (verifyRoundtripFlag), cl::init (false ));
146+
142147 static cl::list<std::string> passPlugins (
143148 " load-pass-plugin" , cl::desc (" Load passes from plugin library" ));
144149 // / Set the callback to load a pass plugin.
@@ -213,6 +218,104 @@ void MlirOptMainConfigCLOptions::setDialectPluginsCallback(
213218 });
214219}
215220
221+ LogicalResult loadIRDLDialects (StringRef irdlFile, MLIRContext &ctx) {
222+ DialectRegistry registry;
223+ registry.insert <irdl::IRDLDialect>();
224+ ctx.appendDialectRegistry (registry);
225+
226+ // Set up the input file.
227+ std::string errorMessage;
228+ std::unique_ptr<MemoryBuffer> file = openInputFile (irdlFile, &errorMessage);
229+ if (!file) {
230+ emitError (UnknownLoc::get (&ctx)) << errorMessage;
231+ return failure ();
232+ }
233+
234+ // Give the buffer to the source manager.
235+ // This will be picked up by the parser.
236+ SourceMgr sourceMgr;
237+ sourceMgr.AddNewSourceBuffer (std::move (file), SMLoc ());
238+
239+ SourceMgrDiagnosticHandler sourceMgrHandler (sourceMgr, &ctx);
240+
241+ // Parse the input file.
242+ OwningOpRef<ModuleOp> module (parseSourceFile<ModuleOp>(sourceMgr, &ctx));
243+
244+ // Load IRDL dialects.
245+ return irdl::loadDialects (module .get ());
246+ }
247+
248+ // Return success if the module can correctly round-trip. This intended to test
249+ // that the custom printers/parsers are complete.
250+ static LogicalResult doVerifyRoundTrip (Operation *op,
251+ const MlirOptMainConfig &config,
252+ bool useBytecode) {
253+ // We use a new context to avoid resource handle renaming issue in the diff.
254+ MLIRContext roundtripContext;
255+ OwningOpRef<Operation *> roundtripModule;
256+ roundtripContext.appendDialectRegistry (
257+ op->getContext ()->getDialectRegistry ());
258+ if (op->getContext ()->allowsUnregisteredDialects ())
259+ roundtripContext.allowUnregisteredDialects ();
260+ StringRef irdlFile = config.getIrdlFile ();
261+ if (!irdlFile.empty () && failed (loadIRDLDialects (irdlFile, roundtripContext)))
262+ return failure ();
263+
264+ // Print a first time with custom format (or bytecode) and parse it back to
265+ // the roundtripModule.
266+ {
267+ std::string buffer;
268+ llvm::raw_string_ostream ostream (buffer);
269+ if (useBytecode) {
270+ if (failed (writeBytecodeToFile (op, ostream))) {
271+ op->emitOpError () << " failed to write bytecode, cannot verify round-trip.\n " ;
272+ return failure ();
273+ }
274+ } else {
275+ op->print (ostream,
276+ OpPrintingFlags ().printGenericOpForm (false ).enableDebugInfo ());
277+ }
278+ FallbackAsmResourceMap fallbackResourceMap;
279+ ParserConfig parseConfig (&roundtripContext, /* verifyAfterParse=*/ true ,
280+ &fallbackResourceMap);
281+ roundtripModule =
282+ parseSourceString<Operation *>(ostream.str (), parseConfig);
283+ if (!roundtripModule) {
284+ op->emitOpError () << " failed to parse bytecode back, cannot verify round-trip.\n " ;
285+ return failure ();
286+ }
287+ }
288+
289+ // Print in the generic form for the reference module and the round-tripped
290+ // one and compare the outputs.
291+ std::string reference, roundtrip;
292+ {
293+ llvm::raw_string_ostream ostreamref (reference);
294+ op->print (ostreamref,
295+ OpPrintingFlags ().printGenericOpForm ().enableDebugInfo ());
296+ llvm::raw_string_ostream ostreamrndtrip (roundtrip);
297+ roundtripModule.get ()->print (
298+ ostreamrndtrip,
299+ OpPrintingFlags ().printGenericOpForm ().enableDebugInfo ());
300+ }
301+ if (reference != roundtrip) {
302+ // TODO implement a diff.
303+ return op->emitOpError () << " roundTrip testing roundtripped module differs from reference:\n <<<<<<Reference\n "
304+ << reference << " \n =====\n "
305+ << roundtrip << " \n >>>>>roundtripped\n " ;
306+ }
307+
308+ return success ();
309+ }
310+
311+ static LogicalResult doVerifyRoundTrip (Operation *op,
312+ const MlirOptMainConfig &config) {
313+ // Textual round-trip isn't fully robust at the moment (for example implicit
314+ // terminator are losing location informations).
315+
316+ return doVerifyRoundTrip (op, config, /* useBytecode=*/ true );
317+ }
318+
216319// / Perform the actions on the input file indicated by the command line flags
217320// / within the specified context.
218321// /
@@ -247,10 +350,16 @@ performActions(raw_ostream &os,
247350 TimingScope parserTiming = timing.nest (" Parser" );
248351 OwningOpRef<Operation *> op = parseSourceFileForTool (
249352 sourceMgr, parseConfig, !config.shouldUseExplicitModule ());
250- context-> enableMultithreading (wasThreadingEnabled );
353+ parserTiming. stop ( );
251354 if (!op)
252355 return failure ();
253- parserTiming.stop ();
356+
357+ // Perform round-trip verification if requested
358+ if (config.shouldVerifyRoundtrip () &&
359+ failed (doVerifyRoundTrip (op.get (), config)))
360+ return failure ();
361+
362+ context->enableMultithreading (wasThreadingEnabled);
254363
255364 // Prepare the pass manager, applying command-line and reproducer options.
256365 PassManager pm (op.get ()->getName (), PassManager::Nesting::Implicit);
@@ -286,33 +395,6 @@ performActions(raw_ostream &os,
286395 return success ();
287396}
288397
289- LogicalResult loadIRDLDialects (StringRef irdlFile, MLIRContext &ctx) {
290- DialectRegistry registry;
291- registry.insert <irdl::IRDLDialect>();
292- ctx.appendDialectRegistry (registry);
293-
294- // Set up the input file.
295- std::string errorMessage;
296- std::unique_ptr<MemoryBuffer> file = openInputFile (irdlFile, &errorMessage);
297- if (!file) {
298- emitError (UnknownLoc::get (&ctx)) << errorMessage;
299- return failure ();
300- }
301-
302- // Give the buffer to the source manager.
303- // This will be picked up by the parser.
304- SourceMgr sourceMgr;
305- sourceMgr.AddNewSourceBuffer (std::move (file), SMLoc ());
306-
307- SourceMgrDiagnosticHandler sourceMgrHandler (sourceMgr, &ctx);
308-
309- // Parse the input file.
310- OwningOpRef<ModuleOp> module (parseSourceFile<ModuleOp>(sourceMgr, &ctx));
311-
312- // Load IRDL dialects.
313- return irdl::loadDialects (module .get ());
314- }
315-
316398// / Parses the memory buffer. If successfully, run a series of passes against
317399// / it and print the result.
318400static LogicalResult processBuffer (raw_ostream &os,
0 commit comments