Skip to content

Commit 3b44991

Browse files
CopilotIvanYashchuk
andcommitted
Address PR feedback: simplify tests based on reviewer suggestions
Co-authored-by: IvanYashchuk <[email protected]>
1 parent cfcde08 commit 3b44991

File tree

2 files changed

+8
-16
lines changed

2 files changed

+8
-16
lines changed

tests/cpp/test_alias.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,9 +1670,12 @@ TEST_F(AliasTest, BroadcastInDimNoRedundantSet) {
16701670
fusion->addInput(in);
16711671

16721672
// Call broadcast with all dims marked as non-broadcast
1673-
// This should not introduce a Set operation
1673+
// This should not introduce a Set operation and return the input directly
16741674
std::vector<bool> is_broadcast_dim = {false, false};
1675-
TensorView* out = broadcast(in, is_broadcast_dim);
1675+
TensorView* maybe_bcast = broadcast(in, is_broadcast_dim);
1676+
1677+
// Add an operation to ensure we have something to test
1678+
TensorView* out = abs(maybe_bcast);
16761679

16771680
fusion->addOutput(out);
16781681

tests/python/test_python_frontend.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -707,20 +707,9 @@ def fusion_with_expand(fd: FusionDefinition):
707707
self.assertEqual(eager_out, nvf_out_exp[0])
708708

709709
# Check that the broadcast_in_dim fusion doesn't have a redundant Set operation
710-
# by comparing the IR string representations
711-
bid_str = str(fd_bid)
712-
exp_str = str(fd_exp)
713-
714-
# Count tensor cast operations (not scalar casts)
715-
bid_tensor_casts = bid_str.count("fd.ops.cast(t")
716-
exp_tensor_casts = exp_str.count("fd.ops.cast(t")
717-
718-
# They should have the same number of tensor casts
719-
self.assertEqual(
720-
bid_tensor_casts,
721-
exp_tensor_casts,
722-
f"broadcast_in_dim has {bid_tensor_casts} tensor casts but expand has {exp_tensor_casts}"
723-
)
710+
# by comparing the IR string representations - they should be identical since
711+
# broadcast is a no-op in this case
712+
self.assertEqual(str(fd_bid), str(fd_exp))
724713

725714
# Testing a scenario where the broadcast is necessary to realize the output
726715
def test_tensor_shape_with_output_bcast(self):

0 commit comments

Comments
 (0)