Skip to content

Commit fdb0fdc

Browse files
committed
Force FP8 gemms into F16 dot
1 parent 72cb133 commit fdb0fdc

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

xla/service/gpu/transforms/gemm_rewriter.cc

Lines changed: 28 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1152,6 +1152,33 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
11521152
<< PrimitiveType_Name(b_type);
11531153
return false;
11541154
}
1155+
if (a_type == F8E4M3FN && b_type == F8E4M3FN) {
1156+
VLOG(1) << "Failed to rewrite " << instr->ToShortString()
1157+
<< " into FP8 Custom Call. For "
1158+
<< rocm_compute_capability.gfx_version()
1159+
<< " arch, one of the input types must be F8E4M3FN, but got "
1160+
<< PrimitiveType_Name(a_type) << " and "
1161+
<< PrimitiveType_Name(b_type);
1162+
return false;
1163+
}
1164+
if (a_type == F8E5M2 && b_type == F8E4M3FN) {
1165+
VLOG(1) << "Failed to rewrite " << instr->ToShortString()
1166+
<< " into FP8 Custom Call. For "
1167+
<< rocm_compute_capability.gfx_version()
1168+
<< " arch, one of the input types must be F8E4M3FN, but got "
1169+
<< PrimitiveType_Name(a_type) << " and "
1170+
<< PrimitiveType_Name(b_type);
1171+
return false;
1172+
}
1173+
if (a_type == F8E4M3FN && b_type == F8E4M3FN) {
1174+
VLOG(1) << "Failed to rewrite " << instr->ToShortString()
1175+
<< " into FP8 Custom Call. For "
1176+
<< rocm_compute_capability.gfx_version()
1177+
<< " arch, one of the input types must be F8E4M3FN, but got "
1178+
<< PrimitiveType_Name(a_type) << " and "
1179+
<< PrimitiveType_Name(b_type);
1180+
return false;
1181+
}
11551182
if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
11561183
(b_type != F8E5M2 && b_type != F8E4M3FN)) {
11571184
VLOG(1)
@@ -2632,4 +2659,4 @@ absl::StatusOr<bool> GemmRewriter::Run(
26322659
}
26332660

26342661
} // namespace gpu
2635-
} // namespace xla
2662+
} // namespace xla

0 commit comments

Comments
 (0)