xla/service/gpu/transforms/gemm_rewriter.cc: 35 changes (29 additions, 6 deletions)
old mode 100644
new mode 100755
@@ -363,7 +363,8 @@ HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim,
     permutation[batch_dim] = batch_dim;
   }
   // Identify the non-contracting dimension.
-  int non_contracting_dim;
+  // Initializing to -1 prevents a "maybe-uninitialized" warning.
+  int non_contracting_dim = -1;
   for (int i = 0; i < input_shape.dimensions().size(); ++i) {
     if (permutation[i] == -1 && contracting_dim != i) {
       non_contracting_dim = i;
@@ -947,6 +948,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
                      GemmOrCublasLtMatmul(&existing_gemm).WithOneUser())
                      .WithOneUser()),
                  m::Op(&bias).WithPredicate(is_not_broadcast)))) {
+      // ROCm FP8: avoid turning Add into matrix-bias fusion (beta=1).
+      if (IsRocm(gpu_version_) &&
+          IsCublasLtMatmulF8(*existing_gemm)) {
+        VLOG(1) << "[GEMM REWRITE] Skip FuseMatrixBiasAdd on ROCm FP8; keep beta==0.";
+        return absl::OkStatus();
+      }
       TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_backend_config,
                           existing_gemm->backend_config<GpuBackendConfig>());
       const GemmBackendConfig &gemm_backend_config =
@@ -987,6 +994,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
                          .WithOneUser()))
                      .WithOneUser(),
                  m::Op(&bias).WithPredicate(is_not_broadcast)))) {
+      // ROCm FP8: avoid matrix-bias fusion (beta=1) on FP8 GEMMs.
+      if (IsRocm(gpu_version_) &&
+          IsCublasLtMatmulF8(*existing_gemm)) {
+        VLOG(1) << "[GEMM REWRITE] Skip FuseMatrixBiasAdd (slice/bitcast) on ROCm FP8; keep beta==0.";
+        return absl::OkStatus();
+      }
       // The matrix bias must not be FP8, see
       // https://docs.nvidia.com/cuda/cublas/index.html.
       if (!IsF8Type(bias)) {
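
Both HandleAdd guards hinge on the IsRocm(gpu_version_) platform check. The sketch below shows roughly what such a check amounts to, assuming gpu_version_ is XLA's se::GpuComputeCapability variant; the helper name is made up and the real check in gemm_rewriter.cc may differ in detail.

// Rough sketch only; include path assumed from current XLA, where
// GpuComputeCapability is a std::variant of CudaComputeCapability and
// RocmComputeCapability.
#include <variant>

#include "xla/stream_executor/device_description.h"

bool IsRocmTarget(const stream_executor::GpuComputeCapability &gpu_version) {
  return std::holds_alternative<stream_executor::RocmComputeCapability>(
      gpu_version);
}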
@@ -1642,6 +1655,13 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
                                  const HloInstruction *gemm,
                                  HloInstruction *bitcast = nullptr,
                                  HloInstruction *slice = nullptr) {
+    // ROCm FP8: never fuse matrix bias to keep beta==0 for FP8 BF16/F16
+    // GEMMs. This avoids accumulation from C and improves numeric stability.
+    if (IsRocm(gpu_version_) && IsCublasLtMatmulF8(*gemm)) {
+      VLOG(1) << "[GEMM REWRITE] FuseMatrixBiasAdd disabled on ROCm FP8; "
+                 "keeping bias outside GEMM (beta==0).";
+      return absl::OkStatus();
+    }
     TF_RET_CHECK(Shape::Equal().IgnoreElementType()(bias->shape(),
                                                     bitcast ? bitcast->shape()
                                                     : slice ? slice->shape()
@@ -1652,7 +1672,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     if (gemm->shape().element_type() == S32) {
       return absl::OkStatus();
     }
-
+    // This avoids β=1 and the read of C inside GEMM.
+    if (IsCublasLtMatmulF8(*gemm)) {
+      VLOG(1) << "FP8 GEMM: skipping matrix-bias → beta fusion by default.";
+      return absl::OkStatus();  // Leave the Add as a separate op.
+    }
     // To ensure correctness, only slices that chop off the ends of dimensions
     // are supported.
     if (slice) {
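
The FP8 guards above keep beta at 0 for cublasLt FP8 GEMMs, so the matrix bias stays a separate HLO add instead of being folded into the custom call. As a rough, standalone reference for what beta controls (illustrative names only, not code from this PR or from any BLAS library), the GEMM epilogue computes D = alpha * (A x B) + beta * C:

#include <cstddef>
#include <vector>

// Illustrative reference only: with beta == 0 the bias/C buffer never has to
// be read inside the GEMM, which is the behavior the guards above preserve.
void ReferenceGemmEpilogue(const std::vector<float> &ab,  // precomputed A x B
                           const std::vector<float> &c,   // matrix bias C
                           float alpha, float beta, std::vector<float> &d) {
  for (size_t i = 0; i < ab.size(); ++i) {
    // Matrix-bias fusion sets beta to 1 and accumulates C here; with
    // beta == 0 the term vanishes and C stays untouched.
    d[i] = alpha * ab[i] + beta * c[i];
  }
}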
@@ -1691,8 +1715,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       return in_out_alias_config.ParameterHasAlias(bias->parameter_number(),
                                                    /*param_index=*/{});
     }();
-    bool want_to_fuse_bias = IsCublasLtMatmulF8(*gemm) ||
-                             IsCublasLtMatmul(*gemm) || can_overwrite_bias;
+    bool want_to_fuse_bias = IsCublasLtMatmul(*gemm) || can_overwrite_bias;
 
     auto gpu_config = gemm->backend_config<GpuBackendConfig>().value();
     GemmBackendConfig &config = *gpu_config.mutable_gemm_backend_config();
@@ -1707,7 +1730,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       return absl::OkStatus();
     }
 
-    config.set_beta(1.0);
+    config.set_beta(1);
 
     std::vector<HloInstruction *> operands(gemm->operands().begin(),
                                            gemm->operands().end());
@@ -1745,7 +1768,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     // true if those uses all come before this operation. But copy-insertion
     // runs before scheduling, so it can't know and has to conservatively insert
     // copies.)
-    if (IsLegacyCublasMatmul(*fused_op) || can_overwrite_bias) {
+    if (IsLegacyCublasMatmul(*fused_op)) {
       xla::Cast<HloCustomCallInstruction>(fused_op.get())
           ->set_output_to_operand_aliasing({{{}, {2, {}}}});
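
With can_overwrite_bias dropped from the condition, output-to-operand aliasing is declared only for legacy cuBLAS matmuls. For context on what the unchanged aliasing call encodes, here is a sketch assuming set_output_to_operand_aliasing takes pairs of (output ShapeIndex, (operand number, operand ShapeIndex)) as in current XLA, with an assumed include path and a made-up helper name: the literal {{{}, {2, {}}}} says the custom call's whole output may share a buffer with operand 2, the bias/C matrix, so the GEMM may write D in place over C.

#include <cstdint>
#include <utility>
#include <vector>

#include "xla/shape_util.h"  // for xla::ShapeIndex; path assumed

// Hypothetical helper, for illustration only: builds the same aliasing spec
// as the literal {{{}, {2, {}}}} passed above.
std::vector<std::pair<xla::ShapeIndex, std::pair<int64_t, xla::ShapeIndex>>>
BiasOutputAliasing() {
  // Output index {} (the root of the result) aliases operand 2 at index {}.
  return {{xla::ShapeIndex{}, {int64_t{2}, xla::ShapeIndex{}}}};
}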
}
Expand Down