Skip to content

Commit 4ce9326

Browse files
loisloGoogle-ML-Automation
authored andcommitted
[XLA:GPU] Refactor Triton pass dumping and improve error logging.
Extracts the logic for setting up MLIR pass dumping to a helper class and uses it in both `CompileTritonToLLVM` and `LowerXTileToTriton`. The public API is minimized to a single function EnableIRPrintingIfRequested. PiperOrigin-RevId: 853236668
1 parent f22c103 commit 4ce9326

6 files changed

Lines changed: 287 additions & 79 deletions

File tree

xla/backends/gpu/codegen/triton/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,13 +320,13 @@ cc_library(
320320
":lowering_util",
321321
":support",
322322
"//xla:autotuning_proto_cc",
323-
"//xla:status_macros",
324323
"//xla:util",
325324
"//xla:xla_data_proto_cc",
326325
"//xla:xla_proto_cc",
327326
"//xla/backends/gpu/codegen/emitters/ir:xla_gpu",
328327
"//xla/backends/gpu/codegen/triton/ir:triton_xla",
329328
"//xla/backends/gpu/codegen/triton/transforms:passes",
329+
"//xla/codegen:ir_printing",
330330
"//xla/codegen/emitters/ir:xla",
331331
"//xla/codegen/emitters/transforms:passes",
332332
"//xla/codegen/tiling:symbolic_tile_analysis",
@@ -383,7 +383,6 @@ cc_library(
383383
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
384384
"@llvm-project//mlir:Pass",
385385
"@llvm-project//mlir:ROCDLToLLVMIRTranslation",
386-
"@llvm-project//mlir:SCFToControlFlow",
387386
"@llvm-project//mlir:Support",
388387
"@llvm-project//mlir:TensorDialect",
389388
"@llvm-project//mlir:ToLLVMIRTranslation",

xla/backends/gpu/codegen/triton/triton_gemm_fusion_test.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,16 +388,25 @@ ENTRY e {
388388
}
389389
DebugOptions debug_options = verified_module->config().debug_options();
390390
debug_options.set_xla_dump_to(output_directory);
391-
debug_options.set_xla_dump_emitter_re("triton-fusion");
391+
debug_options.set_xla_dump_emitter_re("triton");
392392
verified_module->mutable_config().set_debug_options(debug_options);
393393

394394
EXPECT_TRUE(RunAndCompare(std::move(verified_module),
395395
ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
396396

397397
std::vector<std::string> paths;
398-
TF_EXPECT_OK(tsl::Env::Default()->GetMatchingPaths(
399-
tsl::io::JoinPath(output_directory, "*.triton-passes.log"), &paths));
398+
EXPECT_OK(tsl::Env::Default()->GetMatchingPaths(
399+
tsl::io::JoinPath(output_directory, "*.xtile-to-triton.txt"), &paths));
400400
EXPECT_EQ(paths.size(), 1);
401+
size_t file_size = 0;
402+
EXPECT_OK(tsl::Env::Default()->GetFileSize(paths[0], &file_size));
403+
EXPECT_GT(file_size, 10);
404+
EXPECT_OK(tsl::Env::Default()->GetMatchingPaths(
405+
tsl::io::JoinPath(output_directory, "*.triton-to-llvm.txt"), &paths));
406+
EXPECT_EQ(paths.size(), 1);
407+
file_size = 0;
408+
EXPECT_OK(tsl::Env::Default()->GetFileSize(paths[0], &file_size));
409+
EXPECT_GT(file_size, 10);
401410
}
402411

403412
TEST_F(TritonGemmTest, DotWithPredFromCompareProducesCorrectResult) {

xla/backends/gpu/codegen/triton/xtile_compiler.cc

Lines changed: 36 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ limitations under the License.
1717

1818
#include <cstdint>
1919
#include <memory>
20-
#include <optional>
2120
#include <string>
22-
#include <system_error> // NOLINT
2321
#include <utility>
2422
#include <vector>
2523

@@ -38,15 +36,13 @@ limitations under the License.
3836
#include "llvm/IR/Module.h"
3937
#include "llvm/Linker/Linker.h"
4038
#include "llvm/Support/Debug.h"
41-
#include "llvm/Support/FileSystem.h"
4239
#include "llvm/Support/LogicalResult.h"
4340
#include "llvm/Support/raw_ostream.h"
4441
#include "llvm/TargetParser/Triple.h"
4542
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
4643
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
4744
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
4845
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
49-
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
5046
#include "mlir/Dialect/Affine/IR/AffineOps.h"
5147
#include "mlir/Dialect/Arith/IR/Arith.h"
5248
#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"
@@ -69,7 +65,6 @@ limitations under the License.
6965
#include "mlir/IR/OwningOpRef.h"
7066
#include "mlir/IR/Value.h"
7167
#include "mlir/IR/Verifier.h"
72-
#include "mlir/Pass/Pass.h"
7368
#include "mlir/Pass/PassManager.h"
7469
#include "mlir/Support/LLVM.h"
7570
#include "mlir/Support/LogicalResult.h"
@@ -90,6 +85,7 @@ limitations under the License.
9085
#include "xla/backends/gpu/codegen/triton/transforms/passes.h"
9186
#include "xla/codegen/emitters/ir/xla_dialect.h"
9287
#include "xla/codegen/emitters/transforms/passes.h"
88+
#include "xla/codegen/ir_printing.h"
9389
#include "xla/codegen/xtile/ir/transforms/passes.h"
9490
#include "xla/codegen/xtile/ir/xtile_dialect.h"
9591
#include "xla/hlo/builder/xla_builder.h"
@@ -107,7 +103,6 @@ limitations under the License.
107103
#include "xla/service/gpu/model/triton_emitter_constraints.h"
108104
#include "xla/service/hlo_module_config.h"
109105
#include "xla/service/llvm_ir/llvm_util.h"
110-
#include "xla/status_macros.h"
111106
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
112107
#include "xla/stream_executor/device_description.h"
113108
#include "xla/stream_executor/gpu/tma_metadata.h"
@@ -119,7 +114,6 @@ limitations under the License.
119114
#include "xla/util.h"
120115
#include "xla/xla.pb.h"
121116
#include "xla/xla_data.pb.h"
122-
#include "tsl/platform/path.h"
123117
#include "triton/Dialect/Triton/IR/Dialect.h"
124118
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
125119

@@ -374,52 +368,10 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
374368
should_verify = true;
375369
#endif
376370

377-
bool should_dump_mlir_passes =
378-
hlo_config.debug_options().xla_enable_dumping() &&
379-
DumpingEnabledForHloModule(hlo_module) &&
380-
DumpingEnabledForEmitter("triton-fusion", hlo_config.debug_options());
381-
382371
mlir::PassManager pm(&mlir_context);
372+
EnableIRPrintingIfRequested(pm, &mlir_context, hlo_module, kernel_name,
373+
"triton-to-llvm");
383374
pm.enableVerifier(should_verify);
384-
385-
std::optional<llvm::raw_fd_ostream> log_stream;
386-
if (should_dump_mlir_passes) {
387-
std::string outputs_dir = hlo_config.debug_options().xla_dump_to();
388-
if (outputs_dir == "sponge") {
389-
if (!tsl::io::GetTestUndeclaredOutputsDir(&outputs_dir)) {
390-
LOG(ERROR) << "Failed to get test undeclared outputs dir. Lets skip "
391-
"dumping triton passes.";
392-
outputs_dir = "";
393-
}
394-
}
395-
if (!outputs_dir.empty()) {
396-
const std::string basename =
397-
absl::StrCat(absl::string_view(tsl::io::Basename(hlo_module.name())),
398-
".", kernel_name, ".triton-passes.log");
399-
std::string path = tsl::io::JoinPath(outputs_dir, basename);
400-
std::error_code err;
401-
log_stream.emplace(path, err, llvm::sys::fs::OF_None);
402-
if (err) {
403-
log_stream.reset();
404-
LOG(ERROR) << "Failed to dump triton passes to " << path << ": "
405-
<< err.message();
406-
} else {
407-
pm.getContext()->disableMultithreading();
408-
auto print_always = [](mlir::Pass*, mlir::Operation*) { return true; };
409-
pm.enableIRPrinting(/*shouldPrintBeforePass=*/print_always,
410-
/*shouldPrintAfterPass=*/print_always,
411-
/*printModuleScope=*/true,
412-
/*printAfterOnlyOnChange=*/false,
413-
/*printAfterOnlyOnFailure=*/true, *log_stream);
414-
}
415-
} else {
416-
LOG(ERROR)
417-
<< "--xla_dump_emitter_re=triton-fusion is set, but neither "
418-
<< "the environment variable TEST_UNDECLARED_OUTPUTS_DIR nor the "
419-
<< "flag --xla_dump_to is set, so the llvm dumps are disabled.";
420-
}
421-
}
422-
423375
CreateTritonXlaPipeline(&pm, gpu_cc, /*rewrite_int4=*/is_xla_fusion,
424376
block_level_parameters.is_tma_allowed,
425377
block_level_parameters.num_stages);
@@ -537,8 +489,11 @@ absl::Status LowerXTileToTriton(mlir::ModuleOp xtile_dialect_module,
537489
const HloFusionInstruction& fusion,
538490
const se::DeviceDescription& device_info) {
539491
{
492+
const HloModule& hlo_module = *fusion.GetModule();
540493
// Convert xTile ops to Triton ops.
541494
mlir::PassManager pm(&mlir_context);
495+
EnableIRPrintingIfRequested(pm, &mlir_context, hlo_module, fusion.name(),
496+
"xtile-to-triton");
542497
// Disable verifier because the Triton code may be invalid due to the
543498
// unsupported types.
544499
pm.enableVerifier(/*enabled=*/false);
@@ -566,33 +521,40 @@ absl::Status LowerXTileToTriton(mlir::ModuleOp xtile_dialect_module,
566521
}
567522
}
568523

569-
if (fusion.GetModule()
570-
->config()
571-
.debug_options()
572-
.xla_gpu_experimental_scaled_dot_with_triton()) {
573-
// Convert unsupported types before verification.
524+
{
525+
if (fusion.GetModule()
526+
->config()
527+
.debug_options()
528+
.xla_gpu_experimental_scaled_dot_with_triton()) {
529+
// Convert unsupported types before verification.
530+
mlir::PassManager pm(&mlir_context);
531+
532+
EnableIRPrintingIfRequested(pm, &mlir_context, *fusion.GetModule(),
533+
fusion.name(),
534+
"convert-scaled-dot-unsupported-types");
535+
pm.addPass(
536+
mlir::triton::xla::CreateTritonXLAConvertUnsupportedTypesPass());
537+
if (mlir::failed(pm.run(xtile_dialect_module))) {
538+
return CreateInternalError(
539+
"Failed to fix unsupported types in Triton module for fusion:",
540+
&fusion, xtile_dialect_module);
541+
}
542+
}
543+
544+
if (mlir::failed(mlir::verify(xtile_dialect_module))) {
545+
return CreateInternalError("Failed to verify Triton module for fusion:",
546+
&fusion, xtile_dialect_module);
547+
}
574548
mlir::PassManager pm(&mlir_context);
575-
pm.addPass(mlir::triton::xla::CreateTritonXLAConvertUnsupportedTypesPass());
549+
EnableIRPrintingIfRequested(pm, &mlir_context, *fusion.GetModule(),
550+
fusion.name(), "canonicalize-cse");
551+
pm.addPass(mlir::createCanonicalizerPass());
552+
pm.addPass(mlir::createCSEPass());
576553
if (mlir::failed(pm.run(xtile_dialect_module))) {
577-
return CreateInternalError(
578-
"Failed to fix unsupported types in Triton module for fusion:",
579-
&fusion, xtile_dialect_module);
554+
return CreateInternalError("Failed to create Triton module for fusion:",
555+
&fusion, xtile_dialect_module);
580556
}
581557
}
582-
583-
if (mlir::failed(mlir::verify(xtile_dialect_module))) {
584-
return CreateInternalError("Failed to verify Triton module for fusion:",
585-
&fusion, xtile_dialect_module);
586-
}
587-
mlir::PassManager pm(&mlir_context);
588-
589-
pm.addPass(mlir::createCanonicalizerPass());
590-
pm.addPass(mlir::createCSEPass());
591-
if (mlir::failed(pm.run(xtile_dialect_module))) {
592-
return CreateInternalError("Failed to create Triton module for fusion:",
593-
&fusion, xtile_dialect_module);
594-
}
595-
596558
return absl::OkStatus();
597559
}
598560

xla/codegen/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,24 @@ cc_library(
3535
],
3636
)
3737

38+
cc_library(
39+
name = "ir_printing",
40+
srcs = ["ir_printing.cc"],
41+
hdrs = ["ir_printing.h"],
42+
deps = [
43+
"//xla/hlo/ir:hlo",
44+
"//xla/service:dump",
45+
"//xla/service:hlo_module_config",
46+
"@com_google_absl//absl/log",
47+
"@com_google_absl//absl/strings",
48+
"@com_google_absl//absl/strings:string_view",
49+
"@llvm-project//llvm:Support",
50+
"@llvm-project//mlir:IR",
51+
"@llvm-project//mlir:Pass",
52+
"@tsl//tsl/platform:path",
53+
],
54+
)
55+
3856
cc_library(
3957
name = "kernel_spec",
4058
srcs = ["kernel_spec.cc"],

0 commit comments

Comments
 (0)