@@ -96,6 +96,21 @@ def test_api(self):
9696 fetch_list = [out ])
9797 self .assertEqual ((res == self .real_result ).all (), True )
9898
99+ def test_api_float (self ):
100+ if self .op_type == "equal" :
101+ paddle .enable_static ()
102+ with program_guard (Program (), Program ()):
103+ x = fluid .data (name = 'x' , shape = [4 ], dtype = 'int64' )
104+ y = fluid .data (name = 'y' , shape = [1 ], dtype = 'int64' )
105+ op = eval ("paddle.%s" % (self .op_type ))
106+ out = op (x , y )
107+ exe = fluid .Executor (self .place )
108+ res , = exe .run (feed = {"x" : self .input_x ,
109+ "y" : 1.0 },
110+ fetch_list = [out ])
111+ self .real_result = np .array ([1 , 0 , 0 , 0 ]).astype (np .int64 )
112+ self .assertEqual ((res == self .real_result ).all (), True )
113+
99114 def test_dynamic_api (self ):
100115 paddle .disable_static ()
101116 x = paddle .to_tensor (self .input_x )
@@ -105,6 +120,47 @@ def test_dynamic_api(self):
105120 self .assertEqual ((out .numpy () == self .real_result ).all (), True )
106121 paddle .enable_static ()
107122
123+ def test_dynamic_api_int (self ):
124+ if self .op_type == "equal" :
125+ paddle .disable_static ()
126+ x = paddle .to_tensor (self .input_x )
127+ op = eval ("paddle.%s" % (self .op_type ))
128+ out = op (x , 1 )
129+ self .real_result = np .array ([1 , 0 , 0 , 0 ]).astype (np .int64 )
130+ self .assertEqual ((out .numpy () == self .real_result ).all (), True )
131+ paddle .enable_static ()
132+
133+ def test_dynamic_api_float (self ):
134+ if self .op_type == "equal" :
135+ paddle .disable_static ()
136+ x = paddle .to_tensor (self .input_x )
137+ op = eval ("paddle.%s" % (self .op_type ))
138+ out = op (x , 1.0 )
139+ self .real_result = np .array ([1 , 0 , 0 , 0 ]).astype (np .int64 )
140+ self .assertEqual ((out .numpy () == self .real_result ).all (), True )
141+ paddle .enable_static ()
142+
143+ def test_assert (self ):
144+ def test_dynamic_api_string (self ):
145+ if self .op_type == "equal" :
146+ paddle .disable_static ()
147+ x = paddle .to_tensor (self .input_x )
148+ op = eval ("paddle.%s" % (self .op_type ))
149+ out = op (x , "1.0" )
150+ paddle .enable_static ()
151+
152+ self .assertRaises (TypeError , test_dynamic_api_string )
153+
154+ def test_dynamic_api_bool (self ):
155+ if self .op_type == "equal" :
156+ paddle .disable_static ()
157+ x = paddle .to_tensor (self .input_x )
158+ op = eval ("paddle.%s" % (self .op_type ))
159+ out = op (x , True )
160+ self .real_result = np .array ([1 , 0 , 0 , 0 ]).astype (np .int64 )
161+ self .assertEqual ((out .numpy () == self .real_result ).all (), True )
162+ paddle .enable_static ()
163+
108164 def test_broadcast_api_1 (self ):
109165 paddle .enable_static ()
110166 with program_guard (Program (), Program ()):
0 commit comments