Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions xla/service/gpu/transforms/priority_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ namespace gpu {

namespace {

// Returns true only when `instr` is a bitcast that keeps the bit width.
// Any non-bitcast instruction, and any bitcast for which
// hlo_instruction_utils::KeepsBitwidth() reports false, yields false.
bool IsFusibleBitcast(const HloInstruction& instr) {
  // Reject non-bitcasts up front so the bit-width check runs on bitcasts only.
  if (instr.opcode() != HloOpcode::kBitcast) {
    return false;
  }
  return hlo_instruction_utils::KeepsBitwidth(instr);
}

bool IsFusible(const HloInstruction& instr) {
// Side-effecting operations are not fusible.
if (!instr.IsFusible()) {
Expand All @@ -92,13 +98,16 @@ bool IsFusible(const HloInstruction& instr) {
return true;
}

// Bitcasts are fusible if they don't change the bit width.
if (IsFusibleBitcast(instr)) {
return true;
}

// Other non-elementwise ops also supported by elemental fusion.
switch (instr.opcode()) {
case HloOpcode::kFusion:
return IsGenericTritonFusion(instr) ||
instr.fusion_kind() != HloInstruction::FusionKind::kCustom;
case HloOpcode::kBitcast:
return hlo_instruction_utils::KeepsBitwidth(instr);
case HloOpcode::kCopy:
case HloOpcode::kIota:
case HloOpcode::kConstant:
Expand Down Expand Up @@ -265,7 +274,7 @@ class PriorityFusionQueue {
current_consumers_ = {*preferred_consumer};
}

if (HloPredicateIsOp<HloOpcode::kBitcast>(current_producer_)) {
if (IsFusibleBitcast(*current_producer_)) {
// We don't check if bitcasts can be fused with all consumers, so we
// have to do it here.
llvm::erase_if(current_consumers_, [&](HloInstruction* consumer) {
Expand Down Expand Up @@ -547,7 +556,7 @@ class PriorityFusionQueue {
preferred_consumer_.erase(producer);
}
// Bitcasts should always be fused first, since they are no-ops.
if (HloPredicateIsOp<HloOpcode::kBitcast>(producer)) {
if (IsFusibleBitcast(*producer)) {
return absl::InfiniteDuration();
}
// We always fuse constants, but the cost model doesn't handle them very
Expand Down Expand Up @@ -790,7 +799,7 @@ class PriorityFusionQueue {
return can_fuse_triton;
}

if (HloPredicateIsOp<HloOpcode::kBitcast>(consumer)) {
if (IsFusibleBitcast(*consumer)) {
return FusionDecision::Forbid(
"not fusing into a single bitcast as consumer");
}
Expand Down Expand Up @@ -926,7 +935,7 @@ class PriorityFusionQueue {
}
std::vector<HloInstruction*> possible_consumers;
for (const auto& user : producer->users()) {
if (HloPredicateIsOp<HloOpcode::kBitcast>(user)) {
if (IsFusibleBitcast(*user)) {
continue;
}
if (CanFuseTriton(producer, user, /*use_multi_output_fusion=*/true) &&
Expand Down Expand Up @@ -960,7 +969,7 @@ class PriorityFusionQueue {

bool has_non_bitcast_user = false;
for (const auto& user : producer->users()) {
if (HloPredicateIsOp<HloOpcode::kBitcast>(user)) {
if (IsFusibleBitcast(*user)) {
continue;
}
has_non_bitcast_user = true;
Expand Down Expand Up @@ -1181,7 +1190,7 @@ absl::StatusOr<bool> PriorityFusion::Run(
for (auto* consumer : consumers) {
// Don't fuse into single bitcasts. We ignore them in the check
// CanFuseWithAllNonBitcastUsers(), so we need to check it here.
if (HloPredicateIsOp<HloOpcode::kBitcast>(consumer)) {
if (IsFusibleBitcast(*consumer)) {
continue;
}
if (!ConsumeFuel(producer, consumer)) {
Expand Down
24 changes: 13 additions & 11 deletions xla/service/gpu/transforms/priority_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,19 @@ CHECK-NEXT: ROOT %{{.*}} = (f32[512]{0}, s32[512]{0}) tuple(%[[FUSION_F32]], %[[
}

TEST_F(PriorityFusionTest, DoNotFuseBitWidthChangingBitcast) {
EXPECT_TRUE(RunAndCheckHloRewrite(R"(
e {
a = s8[3,5,2]{2,1,0} parameter(0)
n = s8[3,5,2]{2,1,0} negate(a)
b = s16[3,5]{1,0} bitcast(n)
m = s16[3,5]{1,0} multiply(b, b)
})",
std::move(priority_fusion_),
/*expect_change=*/false)
.status()
.ok());
// `neg` is the producer that could be fused with `bitcast` and `mul`, but
// since `bitcast` changes the bit width, we don't fuse it.
auto module = *ParseAndReturnVerifiedModule(R"(
ENTRY main {
p0 = s8[3,5,2]{2,1,0} parameter(0)
neg = s8[3,5,2]{2,1,0} negate(p0)
bitcast = s16[3,5]{1,0} bitcast(neg)
mul = s8[3,5,2]{2,1,0} add(neg, neg)
ROOT result = (s16[3,5]{1,0}, s8[3,5,2]{2,1,0}) tuple(bitcast, mul)
})");

EXPECT_THAT(priority_fusion_.Run(module.get()),
absl_testing::IsOkAndHolds(false));
}

TEST_F(PriorityFusionTest, FuseConvertIntoReduce) {
Expand Down