Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions python/paddle/fluid/tests/unittests/test_compare_op.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ def test_api(self):
fetch_list=[out])
self.assertEqual((res == self.real_result).all(), True)

def test_api_float(self):
if self.op_type == "equal":
paddle.enable_static()
with program_guard(Program(), Program()):
x = fluid.data(name='x', shape=[4], dtype='int64')
y = fluid.data(name='y', shape=[1], dtype='int64')
op = eval("paddle.%s" % (self.op_type))
out = op(x, y)
exe = fluid.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
"y": 1.0},
fetch_list=[out])
self.real_result = np.array([1, 0, 0, 0]).astype(np.int64)
self.assertEqual((res == self.real_result).all(), True)

def test_dynamic_api(self):
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
Expand All @@ -105,6 +120,47 @@ def test_dynamic_api(self):
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()

def test_dynamic_api_int(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, 1)
self.real_result = np.array([1, 0, 0, 0]).astype(np.int64)
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()

def test_dynamic_api_float(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, 1.0)
self.real_result = np.array([1, 0, 0, 0]).astype(np.int64)
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()

def test_assert(self):
def test_dynamic_api_string(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, "1.0")
paddle.enable_static()

self.assertRaises(TypeError, test_dynamic_api_string)

def test_dynamic_api_bool(self):
if self.op_type == "equal":
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
op = eval("paddle.%s" % (self.op_type))
out = op(x, True)
self.real_result = np.array([1, 0, 0, 0]).astype(np.int64)
self.assertEqual((out.numpy() == self.real_result).all(), True)
paddle.enable_static()

def test_broadcast_api_1(self):
paddle.enable_static()
with program_guard(Program(), Program()):
Expand Down
8 changes: 8 additions & 0 deletions python/paddle/tensor/logic.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from paddle.common_ops_import import core
from paddle import _C_ops
from paddle.tensor.creation import full

__all__ = []

Expand Down Expand Up @@ -174,6 +175,13 @@ def equal(x, y, name=None):
result1 = paddle.equal(x, y)
print(result1) # result1 = [True False False]
"""
if not isinstance(y, (int, bool, float, Variable)):
raise TypeError(
"Type of input args must be float, bool, int or Tensor, but received type {}".
format(type(y)))
if not isinstance(y, Variable):
y = full(shape=[1], dtype=x.dtype, fill_value=y)

if in_dygraph_mode():
return _C_ops.equal(x, y)

Expand Down