@@ -1152,6 +1152,33 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
11521152 << PrimitiveType_Name (b_type);
11531153 return false ;
11541154 }
1155+ if (a_type == F8E4M3FN && b_type == F8E4M3FN) {
1156+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1157+ << " into FP8 Custom Call. For "
1158+ << rocm_compute_capability.gfx_version ()
1159+ << " arch, one of the input types must be F8E4M3FN, but got "
1160+ << PrimitiveType_Name (a_type) << " and "
1161+ << PrimitiveType_Name (b_type);
1162+ return false ;
1163+ }
1164+ if (a_type == F8E5M2 && b_type == F8E4M3FN) {
1165+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1166+ << " into FP8 Custom Call. For "
1167+ << rocm_compute_capability.gfx_version ()
1168+ << " arch, one of the input types must be F8E4M3FN, but got "
1169+ << PrimitiveType_Name (a_type) << " and "
1170+ << PrimitiveType_Name (b_type);
1171+ return false ;
1172+ }
1173+ if (a_type == F8E4M3FN && b_type == F8E4M3FN) {
1174+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1175+ << " into FP8 Custom Call. For "
1176+ << rocm_compute_capability.gfx_version ()
1177+ << " arch, one of the input types must be F8E4M3FN, but got "
1178+ << PrimitiveType_Name (a_type) << " and "
1179+ << PrimitiveType_Name (b_type);
1180+ return false ;
1181+ }
11551182 if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
11561183 (b_type != F8E5M2 && b_type != F8E4M3FN)) {
11571184 VLOG (1 )
@@ -2632,4 +2659,4 @@ absl::StatusOr<bool> GemmRewriter::Run(
26322659}
26332660
26342661} // namespace gpu
2635- } // namespace xla
2662+ } // namespace xla
0 commit comments