Skip to content

Commit bc1d816

Browse files
Added support for waves_per_eu function attribute. (#181)
1 parent d2c8356 commit bc1d816

5 files changed

Lines changed: 42 additions & 1 deletion

File tree

xla/service/gpu/BUILD

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,11 @@ cc_library(
488488
],
489489
)
490490

491+
tf_proto_library(
492+
name = "triton_call_args_proto",
493+
srcs = ["triton_call_args.proto"],
494+
)
495+
491496
cc_library(
492497
name = "triton_call",
493498
srcs = if_gpu_is_configured(["triton_call.cc"]),
@@ -496,11 +501,14 @@ cc_library(
496501
"TENSORFLOW_USE_ROCM=1",
497502
]),
498503
deps = [
504+
":triton_call_args_proto_cc",
499505
"@com_google_absl//absl/strings:string_view",
500506
"@llvm-project//mlir:AsmParser",
501507
"@llvm-project//mlir:IR",
502508
"@llvm-project//mlir:Parser",
503509
"@llvm-project//mlir:Support",
510+
"@tsl//tsl/platform:logging",
511+
"@tsl//tsl/platform:protobuf",
504512
],
505513
)
506514

xla/service/gpu/ir_emitter_unnested.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,6 +1456,13 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
14561456
*ir_emitter_context_, sanitized_kernel_name,
14571457
kernel_arguments.args(), arg_size, launch_dimensions, &builder));
14581458

1459+
// If value for waves_per_eu is given create corresponding ROCm func attr
1460+
if (call.waves_per_eu != 0) {
1461+
// Default value - same as no value is given.
1462+
kernel->addFnAttr("amdgpu-waves-per-eu",
1463+
std::to_string(call.waves_per_eu));
1464+
}
1465+
14591466
// Move function body into kernel prototype.
14601467
llvm::Function* prototype_func = builder.GetInsertBlock()->getParent();
14611468
prototype_func->splice(prototype_func->begin(), impl_fn);

xla/service/gpu/triton_call.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ limitations under the License.
2424
#include "mlir/IR/MLIRContext.h"
2525
#include "mlir/Parser/Parser.h"
2626
#include "mlir/Support/LLVM.h"
27+
#include "xla/service/gpu/triton_call_args.pb.h"
28+
#include "tsl/platform/protobuf.h"
29+
#include "tsl/platform/logging.h"
2730

2831
namespace xla::gpu {
2932

@@ -44,7 +47,20 @@ TritonCall TritonCall::Parse(absl::string_view backend_config,
4447
attrs.getAs<mlir::IntegerAttr>("num_stages").getValue().getSExtValue();
4548
auto num_warps =
4649
attrs.getAs<mlir::IntegerAttr>("num_warps").getValue().getSExtValue();
47-
return TritonCall{std::move(name), std::move(ir), num_stages, num_warps,
50+
auto attr_smd = attrs.getAs<mlir::StringAttr>("serialized_metadata");
51+
int64_t waves_per_eu = 0;
52+
if (attr_smd) {
53+
TritonCallArgs triton_call_args_proto;
54+
auto sermetadata = attr_smd.getValue().str();
55+
if (tsl::protobuf::TextFormat::ParseFromString(
56+
sermetadata, &triton_call_args_proto)) {
57+
waves_per_eu = triton_call_args_proto.waves_per_eu();
58+
} else {
59+
// Parsing error: set default value
60+
waves_per_eu = 0;
61+
}
62+
}
63+
return TritonCall{std::move(name), std::move(ir), num_stages, num_warps, waves_per_eu,
4864
grid_x, grid_y, grid_z};
4965
}
5066

xla/service/gpu/triton_call.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ struct TritonCall {
2929
std::string ir;
3030
int64_t num_stages;
3131
int64_t num_warps;
32+
int64_t waves_per_eu;
3233
int32_t grid_x;
3334
int32_t grid_y;
3435
int32_t grid_z;
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
syntax = "proto3";
2+
3+
package xla.gpu;
4+
5+
// Arguments for triton calls for XLA:GPU.
6+
7+
message TritonCallArgs {
8+
optional int32 waves_per_eu = 1;
9+
}

0 commit comments

Comments
 (0)