Skip to content

Commit e84b2e9

Browse files
authored
Add bcast semantics checks at C++ level to BroadcastTensorsOp (#34874)
1 parent 28279f6 commit e84b2e9

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

paddle/fluid/operators/broadcast_tensors_op.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
3838

3939
int target_rank = 0;
4040
const auto& input_dims = ctx->GetInputsDim("X");
41+
4142
// 1. Find Output rank = max(Inputs rank)
4243
for (const auto& input_ddim : input_dims) {
4344
target_rank = std::max(target_rank, input_ddim.size());
@@ -64,6 +65,14 @@ class BroadcastTensorsOp : public framework::OperatorWithKernel {
6465
dim_size = input_ddim[axis];
6566
}
6667

68+
if (target_dim_size != 1 && dim_size != 1 &&
69+
target_dim_size != dim_size) {
70+
PADDLE_THROW(platform::errors::InvalidArgument(
71+
"BroadcastTensorsOp inputs does not satisfy bcast semantics,"
72+
"Please check axis = %d in reverse order",
73+
index));
74+
}
75+
6776
// We performed bcast semantics check at python level
6877
// So input tensors should all have legal shape
6978
target_dim_size = std::max(target_dim_size, dim_size);

python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,5 +192,47 @@ def test_bcast_semantics():
192192
self.assertRaises(TypeError, test_bcast_semantics)
193193

194194

195+
class TestRaiseBroadcastTensorsErrorDyGraph(unittest.TestCase):
196+
def test_errors(self):
197+
def test_type():
198+
inputs = [
199+
paddle.to_tensor(
200+
np.ones(
201+
shape=[1, 1, 1, 1], dtype='float32', name="x4")),
202+
paddle.to_tensor(
203+
np.ones(
204+
shape=[1, 4, 1, 1], dtype='float64', name="x5"))
205+
]
206+
paddle.broadcast_tensors(inputs)
207+
208+
def test_dtype():
209+
inputs = [
210+
paddle.to_tensor(
211+
np.ones(
212+
shape=[1, 1, 1, 1], dtype='int8', name="x6")),
213+
paddle.to_tensor(
214+
np.ones(
215+
shape=[1, 4, 1, 1], dtype='int8', name="x7"))
216+
]
217+
paddle.broadcast_tensors(inputs)
218+
219+
def test_bcast_semantics():
220+
inputs = [
221+
paddle.to_tensor(
222+
np.ones(
223+
shape=[1, 3, 1, 1], dtype='float32', name="x9")),
224+
paddle.to_tensor(
225+
np.ones(
226+
shape=[1, 8, 1, 1], dtype='float32', name="x10"))
227+
]
228+
paddle.broadcast_tensors(inputs)
229+
230+
paddle.disable_static()
231+
self.assertRaises(TypeError, test_type)
232+
self.assertRaises(TypeError, test_dtype)
233+
self.assertRaises(TypeError, test_bcast_semantics)
234+
paddle.enable_static()
235+
236+
195237
if __name__ == '__main__':
196238
unittest.main()

0 commit comments

Comments
 (0)