@@ -17,9 +17,7 @@ limitations under the License.
1717
1818#include < cstdint>
1919#include < memory>
20- #include < optional>
2120#include < string>
22- #include < system_error> // NOLINT
2321#include < utility>
2422#include < vector>
2523
@@ -38,15 +36,13 @@ limitations under the License.
3836#include " llvm/IR/Module.h"
3937#include " llvm/Linker/Linker.h"
4038#include " llvm/Support/Debug.h"
41- #include " llvm/Support/FileSystem.h"
4239#include " llvm/Support/LogicalResult.h"
4340#include " llvm/Support/raw_ostream.h"
4441#include " llvm/TargetParser/Triple.h"
4542#include " mlir/Conversion/AffineToStandard/AffineToStandard.h"
4643#include " mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
4744#include " mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
4845#include " mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
49- #include " mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
5046#include " mlir/Dialect/Affine/IR/AffineOps.h"
5147#include " mlir/Dialect/Arith/IR/Arith.h"
5248#include " mlir/Dialect/Func/Extensions/InlinerExtension.h"
@@ -69,7 +65,6 @@ limitations under the License.
6965#include " mlir/IR/OwningOpRef.h"
7066#include " mlir/IR/Value.h"
7167#include " mlir/IR/Verifier.h"
72- #include " mlir/Pass/Pass.h"
7368#include " mlir/Pass/PassManager.h"
7469#include " mlir/Support/LLVM.h"
7570#include " mlir/Support/LogicalResult.h"
@@ -90,6 +85,7 @@ limitations under the License.
9085#include " xla/backends/gpu/codegen/triton/transforms/passes.h"
9186#include " xla/codegen/emitters/ir/xla_dialect.h"
9287#include " xla/codegen/emitters/transforms/passes.h"
88+ #include " xla/codegen/ir_printing.h"
9389#include " xla/codegen/xtile/ir/transforms/passes.h"
9490#include " xla/codegen/xtile/ir/xtile_dialect.h"
9591#include " xla/hlo/builder/xla_builder.h"
@@ -107,7 +103,6 @@ limitations under the License.
107103#include " xla/service/gpu/model/triton_emitter_constraints.h"
108104#include " xla/service/hlo_module_config.h"
109105#include " xla/service/llvm_ir/llvm_util.h"
110- #include " xla/status_macros.h"
111106#include " xla/stream_executor/cuda/cuda_compute_capability.h"
112107#include " xla/stream_executor/device_description.h"
113108#include " xla/stream_executor/gpu/tma_metadata.h"
@@ -119,7 +114,6 @@ limitations under the License.
119114#include " xla/util.h"
120115#include " xla/xla.pb.h"
121116#include " xla/xla_data.pb.h"
122- #include " tsl/platform/path.h"
123117#include " triton/Dialect/Triton/IR/Dialect.h"
124118#include " triton/Dialect/TritonGPU/IR/Dialect.h"
125119
@@ -374,52 +368,10 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
374368 should_verify = true ;
375369#endif
376370
377- bool should_dump_mlir_passes =
378- hlo_config.debug_options ().xla_enable_dumping () &&
379- DumpingEnabledForHloModule (hlo_module) &&
380- DumpingEnabledForEmitter (" triton-fusion" , hlo_config.debug_options ());
381-
382371 mlir::PassManager pm (&mlir_context);
372+ EnableIRPrintingIfRequested (pm, &mlir_context, hlo_module, kernel_name,
373+ " triton-to-llvm" );
383374 pm.enableVerifier (should_verify);
384-
385- std::optional<llvm::raw_fd_ostream> log_stream;
386- if (should_dump_mlir_passes) {
387- std::string outputs_dir = hlo_config.debug_options ().xla_dump_to ();
388- if (outputs_dir == " sponge" ) {
389- if (!tsl::io::GetTestUndeclaredOutputsDir (&outputs_dir)) {
390- LOG (ERROR) << " Failed to get test undeclared outputs dir. Lets skip "
391- " dumping triton passes." ;
392- outputs_dir = " " ;
393- }
394- }
395- if (!outputs_dir.empty ()) {
396- const std::string basename =
397- absl::StrCat (absl::string_view (tsl::io::Basename (hlo_module.name ())),
398- " ." , kernel_name, " .triton-passes.log" );
399- std::string path = tsl::io::JoinPath (outputs_dir, basename);
400- std::error_code err;
401- log_stream.emplace (path, err, llvm::sys::fs::OF_None);
402- if (err) {
403- log_stream.reset ();
404- LOG (ERROR) << " Failed to dump triton passes to " << path << " : "
405- << err.message ();
406- } else {
407- pm.getContext ()->disableMultithreading ();
408- auto print_always = [](mlir::Pass*, mlir::Operation*) { return true ; };
409- pm.enableIRPrinting (/* shouldPrintBeforePass=*/ print_always,
410- /* shouldPrintAfterPass=*/ print_always,
411- /* printModuleScope=*/ true ,
412- /* printAfterOnlyOnChange=*/ false ,
413- /* printAfterOnlyOnFailure=*/ true , *log_stream);
414- }
415- } else {
416- LOG (ERROR)
417- << " --xla_dump_emitter_re=triton-fusion is set, but neither "
418- << " the environment variable TEST_UNDECLARED_OUTPUTS_DIR nor the "
419- << " flag --xla_dump_to is set, so the llvm dumps are disabled." ;
420- }
421- }
422-
423375 CreateTritonXlaPipeline (&pm, gpu_cc, /* rewrite_int4=*/ is_xla_fusion,
424376 block_level_parameters.is_tma_allowed ,
425377 block_level_parameters.num_stages );
@@ -537,8 +489,11 @@ absl::Status LowerXTileToTriton(mlir::ModuleOp xtile_dialect_module,
537489 const HloFusionInstruction& fusion,
538490 const se::DeviceDescription& device_info) {
539491 {
492+ const HloModule& hlo_module = *fusion.GetModule ();
540493 // Convert xTile ops to Triton ops.
541494 mlir::PassManager pm (&mlir_context);
495+ EnableIRPrintingIfRequested (pm, &mlir_context, hlo_module, fusion.name (),
496+ " xtile-to-triton" );
542497 // Disable verifier because the Triton code may be invalid due to the
543498 // unsupported types.
544499 pm.enableVerifier (/* enabled=*/ false );
@@ -566,33 +521,40 @@ absl::Status LowerXTileToTriton(mlir::ModuleOp xtile_dialect_module,
566521 }
567522 }
568523
569- if (fusion.GetModule ()
570- ->config ()
571- .debug_options ()
572- .xla_gpu_experimental_scaled_dot_with_triton ()) {
573- // Convert unsupported types before verification.
524+ {
525+ if (fusion.GetModule ()
526+ ->config ()
527+ .debug_options ()
528+ .xla_gpu_experimental_scaled_dot_with_triton ()) {
529+ // Convert unsupported types before verification.
530+ mlir::PassManager pm (&mlir_context);
531+
532+ EnableIRPrintingIfRequested (pm, &mlir_context, *fusion.GetModule (),
533+ fusion.name (),
534+ " convert-scaled-dot-unsupported-types" );
535+ pm.addPass (
536+ mlir::triton::xla::CreateTritonXLAConvertUnsupportedTypesPass ());
537+ if (mlir::failed (pm.run (xtile_dialect_module))) {
538+ return CreateInternalError (
539+ " Failed to fix unsupported types in Triton module for fusion:" ,
540+ &fusion, xtile_dialect_module);
541+ }
542+ }
543+
544+ if (mlir::failed (mlir::verify (xtile_dialect_module))) {
545+ return CreateInternalError (" Failed to verify Triton module for fusion:" ,
546+ &fusion, xtile_dialect_module);
547+ }
574548 mlir::PassManager pm (&mlir_context);
575- pm.addPass (mlir::triton::xla::CreateTritonXLAConvertUnsupportedTypesPass ());
549+ EnableIRPrintingIfRequested (pm, &mlir_context, *fusion.GetModule (),
550+ fusion.name (), " canonicalize-cse" );
551+ pm.addPass (mlir::createCanonicalizerPass ());
552+ pm.addPass (mlir::createCSEPass ());
576553 if (mlir::failed (pm.run (xtile_dialect_module))) {
577- return CreateInternalError (
578- " Failed to fix unsupported types in Triton module for fusion:" ,
579- &fusion, xtile_dialect_module);
554+ return CreateInternalError (" Failed to create Triton module for fusion:" ,
555+ &fusion, xtile_dialect_module);
580556 }
581557 }
582-
583- if (mlir::failed (mlir::verify (xtile_dialect_module))) {
584- return CreateInternalError (" Failed to verify Triton module for fusion:" ,
585- &fusion, xtile_dialect_module);
586- }
587- mlir::PassManager pm (&mlir_context);
588-
589- pm.addPass (mlir::createCanonicalizerPass ());
590- pm.addPass (mlir::createCSEPass ());
591- if (mlir::failed (pm.run (xtile_dialect_module))) {
592- return CreateInternalError (" Failed to create Triton module for fusion:" ,
593- &fusion, xtile_dialect_module);
594- }
595-
596558 return absl::OkStatus ();
597559}
598560
0 commit comments