diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 3cd9208bfbff55..ed23a69282327e 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -488,6 +488,11 @@ cc_library( ], ) +tf_proto_library( + name = "triton_call_args_proto", + srcs = ["triton_call_args.proto"], +) + cc_library( name = "triton_call", srcs = if_gpu_is_configured(["triton_call.cc"]), @@ -496,11 +501,14 @@ cc_library( "TENSORFLOW_USE_ROCM=1", ]), deps = [ + ":triton_call_args_proto_cc", "@com_google_absl//absl/strings:string_view", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index d16e7de8b78180..b5d4f012fda7f5 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1511,6 +1511,13 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall( *ir_emitter_context_, sanitized_kernel_name, kernel_arguments.args(), arg_size, launch_dimensions, &builder)); + // If value for waves_per_eu is given create corresponding ROCm func attr + if (call.waves_per_eu != 0) { + // Default value - same as no value is given. + kernel->addFnAttr("amdgpu-waves-per-eu", + std::to_string(call.waves_per_eu)); + } + // Move function body into kernel prototype. llvm::Function* prototype_func = builder.GetInsertBlock()->getParent(); prototype_func->splice(prototype_func->begin(), impl_fn); diff --git a/third_party/xla/xla/service/gpu/triton_call.cc b/third_party/xla/xla/service/gpu/triton_call.cc index 515145630ce4d4..99e807143f3cb1 100644 --- a/third_party/xla/xla/service/gpu/triton_call.cc +++ b/third_party/xla/xla/service/gpu/triton_call.cc @@ -24,6 +24,9 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/LLVM.h" +#include "xla/service/gpu/triton_call_args.pb.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/logging.h" namespace xla::gpu { @@ -44,7 +47,20 @@ TritonCall TritonCall::Parse(absl::string_view backend_config, attrs.getAs("num_stages").getValue().getSExtValue(); auto num_warps = attrs.getAs("num_warps").getValue().getSExtValue(); - return TritonCall{std::move(name), std::move(ir), num_stages, num_warps, + auto attr_smd = attrs.getAs("serialized_metadata"); + int64_t waves_per_eu = 0; + if (attr_smd) { + TritonCallArgs triton_call_args_proto; + auto sermetadata = attr_smd.getValue().str(); + if (tsl::protobuf::TextFormat::ParseFromString( + sermetadata, &triton_call_args_proto)) { + waves_per_eu = triton_call_args_proto.waves_per_eu(); + } else { + // Parsing error: set default value + waves_per_eu = 0; + } + } + return TritonCall{std::move(name), std::move(ir), num_stages, num_warps, waves_per_eu, grid_x, grid_y, grid_z}; } diff --git a/third_party/xla/xla/service/gpu/triton_call.h b/third_party/xla/xla/service/gpu/triton_call.h index d931bc93505a6e..0de8a9aacc8ea0 100644 --- a/third_party/xla/xla/service/gpu/triton_call.h +++ b/third_party/xla/xla/service/gpu/triton_call.h @@ -29,6 +29,7 @@ struct TritonCall { std::string ir; int64_t num_stages; int64_t num_warps; + int64_t waves_per_eu; int32_t grid_x; int32_t grid_y; int32_t grid_z; diff --git a/third_party/xla/xla/service/gpu/triton_call_args.proto b/third_party/xla/xla/service/gpu/triton_call_args.proto new file mode 100644 index 00000000000000..e65914a75f8bcd --- /dev/null +++ b/third_party/xla/xla/service/gpu/triton_call_args.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package xla.gpu; + +// Arguments for triton calls for XLA:GPU. + +message TritonCallArgs { + optional int32 waves_per_eu = 1; +} \ No newline at end of file