Skip to content

Commit 5e1571d

Browse files
committed
add axis check for elementwise op while the dimension of x is equal to the dimension of tensor
1 parent d9f59fd commit 5e1571d

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

paddle/fluid/operators/elementwise/elementwise_op.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ class ElementwiseOp : public framework::OperatorWithKernel {
8585
auto y_dims = ctx->GetInputDim("Y");
8686
int max_dim = std::max(x_dims.size(), y_dims.size());
8787
int axis = ctx->Attrs().Get<int>("axis");
88+
if (x_dims.size() == y_dims.size()) {
89+
PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0), true,
90+
platform::errors::InvalidArgument(
91+
"axis should be -1 or 0 while the dimension of "
92+
"tensor X (%s) is equal to the dimension of "
93+
"tensor Y (%s), but received axis: %s",
94+
x_dims.size(), y_dims.size(), axis));
95+
}
8896
PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim), true,
8997
platform::errors::InvalidArgument(
9098
"The axis range must be [%s, %s), but axis is %s. "

0 commit comments

Comments
 (0)