# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
1414
from __future__ import print_function

import unittest

import numpy as np
from op_test import OpTest

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import Program, compiler, program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.op import Operator
2727
2828
class TestWhereOp(OpTest):
    """Op-level test for the `where` op.

    Checks the forward output against `np.where` and the gradients of both
    selectable inputs, in both the old dygraph and eager modes.
    """

    def setUp(self):
        self.op_type = 'where'
        self.init_config()
        self.inputs = {'Condition': self.cond, 'X': self.x, 'Y': self.y}
        self.outputs = {'Out': np.where(self.cond, self.x, self.y)}

    def test_check_output(self):
        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.check_grad(['X', 'Y'], 'Out', check_eager=True)

    def init_config(self):
        # 1-D case with an all-False condition: output always comes from Y.
        self.x = np.random.uniform(-3, 5, 100).astype('float64')
        self.y = np.random.uniform(-3, 5, 100).astype('float64')
        self.cond = np.zeros(100).astype('bool')


class TestWhereOp2(TestWhereOp):
    """`where` op test variant: 2-D inputs, all-True condition."""

    def init_config(self):
        # All-True condition: output always comes from X.
        self.x = np.random.uniform(-5, 5, (60, 2)).astype('float64')
        self.y = np.random.uniform(-5, 5, (60, 2)).astype('float64')
        self.cond = np.ones((60, 2)).astype('bool')


class TestWhereOp3(TestWhereOp):
    """`where` op test variant: 3-D inputs, random boolean condition."""

    def init_config(self):
        self.x = np.random.uniform(-3, 5, (20, 2, 4)).astype('float64')
        self.y = np.random.uniform(-3, 5, (20, 2, 4)).astype('float64')
        self.cond = np.array(np.random.randint(2, size=(20, 2, 4)), dtype=bool)


@@ -66,15 +66,15 @@ def setUp(self):
6666 def init_data (self ):
6767 self .shape = [10 , 15 ]
6868 self .cond = np .array (np .random .randint (2 , size = self .shape ), dtype = bool )
69- self .x = np .random .uniform (- 2 , 3 , self .shape ).astype (np .float32 )
70- self .y = np .random .uniform (- 2 , 3 , self .shape ).astype (np .float32 )
69+ self .x = np .random .uniform (( - 2 ) , 3 , self .shape ).astype (np .float32 )
70+ self .y = np .random .uniform (( - 2 ) , 3 , self .shape ).astype (np .float32 )
7171 self .out = np .where (self .cond , self .x , self .y )
7272
7373 def ref_x_backward (self , dout ):
74- return np .where (self .cond == True , dout , 0 )
74+ return np .where (( self .cond == True ) , dout , 0 )
7575
7676 def ref_y_backward (self , dout ):
77- return np .where (self .cond == False , dout , 0 )
77+ return np .where (( self .cond == False ) , dout , 0 )
7878
7979 def test_api (self , use_cuda = False ):
8080 for x_stop_gradient in [False , True ]:
@@ -90,17 +90,17 @@ def test_api(self, use_cuda=False):
9090 y .stop_gradient = y_stop_gradient
9191 result = paddle .where (cond , x , y )
9292 append_backward (layers .mean (result ))
93-
9493 for use_cuda in [False , True ]:
95- if use_cuda and not fluid .core .is_compiled_with_cuda ():
94+ if (use_cuda and
95+ (not fluid .core .is_compiled_with_cuda ())):
9696 break
97- place = fluid .CUDAPlace (
98- 0 ) if use_cuda else fluid .CPUPlace ()
97+ place = ( fluid .CUDAPlace (0 )
98+ if use_cuda else fluid .CPUPlace () )
9999 exe = fluid .Executor (place )
100100 fetch_list = [result , result .grad_name ]
101- if x_stop_gradient is False :
101+ if ( x_stop_gradient is False ) :
102102 fetch_list .append (x .grad_name )
103- if y_stop_gradient is False :
103+ if ( y_stop_gradient is False ) :
104104 fetch_list .append (y .grad_name )
105105 out = exe .run (
106106 fluid .default_main_program (),
@@ -109,13 +109,13 @@ def test_api(self, use_cuda=False):
109109 'y' : self .y },
110110 fetch_list = fetch_list )
111111 assert np .array_equal (out [0 ], self .out )
112- if x_stop_gradient is False :
112+ if ( x_stop_gradient is False ) :
113113 assert np .array_equal (out [2 ],
114114 self .ref_x_backward (out [1 ]))
115- if y .stop_gradient is False :
115+ if ( y .stop_gradient is False ) :
116116 assert np .array_equal (
117117 out [3 ], self .ref_y_backward (out [1 ]))
118- elif y .stop_gradient is False :
118+ elif ( y .stop_gradient is False ) :
119119 assert np .array_equal (out [2 ],
120120 self .ref_y_backward (out [1 ]))
121121
@@ -124,54 +124,46 @@ def test_api_broadcast(self, use_cuda=False):
124124 with fluid .program_guard (main_program ):
125125 x = fluid .layers .data (name = 'x' , shape = [4 , 1 ], dtype = 'float32' )
126126 y = fluid .layers .data (name = 'y' , shape = [4 , 2 ], dtype = 'float32' )
127- x_i = np .array ([[0.9383 , 0.1983 , 3.2 , 1.2 ]]).astype ("float32" )
128- y_i = np .array ([[1.0 , 1.0 , 1.0 , 1.0 ],
129- [1.0 , 1.0 , 1.0 , 1.0 ]]).astype ("float32" )
130- result = paddle .where (x > 1 , x = x , y = y )
131-
127+ x_i = np .array ([[0.9383 , 0.1983 , 3.2 , 1.2 ]]).astype ('float32' )
128+ y_i = np .array (
129+ [[1.0 , 1.0 , 1.0 , 1.0 ], [1.0 , 1.0 , 1.0 , 1.0 ]]).astype ('float32' )
130+ result = paddle .where ((x > 1 ), x = x , y = y )
132131 for use_cuda in [False , True ]:
133- if use_cuda and not fluid .core .is_compiled_with_cuda ():
132+ if ( use_cuda and ( not fluid .core .is_compiled_with_cuda ()) ):
134133 return
135- place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
134+ place = ( fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace () )
136135 exe = fluid .Executor (place )
137136 out = exe .run (fluid .default_main_program (),
138137 feed = {'x' : x_i ,
139138 'y' : y_i },
140139 fetch_list = [result ])
141- assert np .array_equal (out [0 ], np .where (x_i > 1 , x_i , y_i ))
140+ assert np .array_equal (out [0 ], np .where (( x_i > 1 ) , x_i , y_i ))
142141
143142 def __test_where_with_broadcast_static (self , cond_shape , x_shape , y_shape ):
144143 paddle .enable_static ()
145-
146144 main_program = Program ()
147145 with fluid .program_guard (main_program ):
148146 cond = fluid .layers .data (
149147 name = 'cond' , shape = cond_shape , dtype = 'bool' )
150148 x = fluid .layers .data (name = 'x' , shape = x_shape , dtype = 'float32' )
151149 y = fluid .layers .data (name = 'y' , shape = y_shape , dtype = 'float32' )
152-
153- cond_data_tmp = np .random .random (size = cond_shape ).astype ("float32" )
154- cond_data = cond_data_tmp < 0.3
155- x_data = np .random .random (size = x_shape ).astype ("float32" )
156- y_data = np .random .random (size = y_shape ).astype ("float32" )
157-
150+ cond_data_tmp = np .random .random (size = cond_shape ).astype ('float32' )
151+ cond_data = (cond_data_tmp < 0.3 )
152+ x_data = np .random .random (size = x_shape ).astype ('float32' )
153+ y_data = np .random .random (size = y_shape ).astype ('float32' )
158154 result = paddle .where (condition = cond , x = x , y = y )
159-
160155 for use_cuda in [False , True ]:
161- if use_cuda and not fluid .core .is_compiled_with_cuda ():
156+ if ( use_cuda and ( not fluid .core .is_compiled_with_cuda ()) ):
162157 return
163- place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
164-
158+ place = (fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ())
165159 exe = fluid .Executor (place )
166160 out = exe .run (
167161 fluid .default_main_program (),
168162 feed = {'cond' : cond_data ,
169163 'x' : x_data ,
170164 'y' : y_data },
171165 fetch_list = [result ])
172-
173166 expect = np .where (cond_data , x_data , y_data )
174-
175167 assert np .array_equal (out [0 ], expect )
176168
177169 def test_static_api_broadcast_1 (self ):
@@ -198,28 +190,24 @@ def test_static_api_broadcast_4(self):
198190 b_shape = [2 , 2 , 4 ]
199191 self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
200192
201- # @Note Now, maybe not compatibility with old version
202193 def test_static_api_broadcast_5 (self ):
203194 cond_shape = [3 , 2 , 2 , 4 ]
204195 a_shape = [2 , 2 , 4 ]
205196 b_shape = [2 , 2 , 4 ]
206197 self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
207198
208- # @Note Now, maybe not compatibility with old version
209199 def test_static_api_broadcast_6 (self ):
210200 cond_shape = [2 , 2 , 4 ]
211201 a_shape = [2 , 2 , 1 ]
212202 b_shape = [2 , 2 , 1 ]
213203 self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
214204
215- # @Note Now, maybe not compatibility with old version
216205 def test_static_api_broadcast_7 (self ):
217206 cond_shape = [2 , 2 , 4 ]
218207 a_shape = [2 , 1 , 4 ]
219208 b_shape = [2 , 1 , 4 ]
220209 self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
221210
222- # @Note Now, maybe not compatibility with old version
223211 def test_static_api_broadcast_8 (self ):
224212 cond_shape = [3 , 2 , 2 , 4 ]
225213 a_shape = [2 , 2 , 1 ]
@@ -230,9 +218,9 @@ def test_static_api_broadcast_8(self):
230218class TestWhereDygraphAPI (unittest .TestCase ):
231219 def test_api (self ):
232220 with fluid .dygraph .guard ():
233- x_i = np .array ([0.9383 , 0.1983 , 3.2 , 1.2 ]).astype (" float64" )
234- y_i = np .array ([1.0 , 1.0 , 1.0 , 1.0 ]).astype (" float64" )
235- cond_i = np .array ([False , False , True , True ]).astype (" bool" )
221+ x_i = np .array ([0.9383 , 0.1983 , 3.2 , 1.2 ]).astype (' float64' )
222+ y_i = np .array ([1.0 , 1.0 , 1.0 , 1.0 ]).astype (' float64' )
223+ cond_i = np .array ([False , False , True , True ]).astype (' bool' )
236224 x = fluid .dygraph .to_variable (x_i )
237225 y = fluid .dygraph .to_variable (y_i )
238226 cond = fluid .dygraph .to_variable (cond_i )
@@ -242,15 +230,12 @@ def test_api(self):
242230 def __test_where_with_broadcast_dygraph (self , cond_shape , a_shape , b_shape ):
243231 with fluid .dygraph .guard ():
244232 cond_tmp = paddle .rand (cond_shape )
245- cond = cond_tmp < 0.3
233+ cond = ( cond_tmp < 0.3 )
246234 a = paddle .rand (a_shape )
247235 b = paddle .rand (b_shape )
248-
249236 result = paddle .where (cond , a , b )
250237 result = result .numpy ()
251-
252238 expect = np .where (cond , a , b )
253-
254239 self .assertTrue (np .array_equal (expect , result ))
255240
256241 def test_dygraph_api_broadcast_1 (self ):
@@ -277,28 +262,24 @@ def test_dygraph_api_broadcast_4(self):
277262 b_shape = [2 , 2 , 4 ]
278263 self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
279264
280- # @Note Now, maybe not compatibility with old version
281265 def test_dygraph_api_broadcast_5 (self ):
282266 cond_shape = [3 , 2 , 2 , 4 ]
283267 a_shape = [2 , 2 , 4 ]
284268 b_shape = [2 , 2 , 4 ]
285269 self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
286270
287- # @Note Now, maybe not compatibility with old version
288271 def test_dygraph_api_broadcast_6 (self ):
289272 cond_shape = [2 , 2 , 4 ]
290273 a_shape = [2 , 2 , 1 ]
291274 b_shape = [2 , 2 , 1 ]
292275 self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
293276
294- # @Note Now, maybe not compatibility with old version
295277 def test_dygraph_api_broadcast_7 (self ):
296278 cond_shape = [2 , 2 , 4 ]
297279 a_shape = [2 , 1 , 4 ]
298280 b_shape = [2 , 1 , 4 ]
299281 self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
300282
301- # @Note Now, maybe not compatibility with old version
302283 def test_dygraph_api_broadcast_8 (self ):
303284 cond_shape = [3 , 2 , 2 , 4 ]
304285 a_shape = [2 , 2 , 1 ]
@@ -308,40 +289,50 @@ def test_dygraph_api_broadcast_8(self):
308289 def test_where_condition (self ):
309290 data = np .array ([[True , False ], [False , True ]])
310291 with program_guard (Program (), Program ()):
311- x = fluid .layers .data (name = 'x' , shape = [- 1 , 2 ])
292+ x = fluid .layers .data (name = 'x' , shape = [( - 1 ) , 2 ])
312293 y = paddle .where (x )
313294 self .assertEqual (type (y ), tuple )
314295 self .assertEqual (len (y ), 2 )
315296 z = fluid .layers .concat (list (y ), axis = 1 )
316297 exe = fluid .Executor (fluid .CPUPlace ())
317-
318- res , = exe .run (feed = {'x' : data },
319- fetch_list = [z .name ],
320- return_numpy = False )
298+ (res , ) = exe .run (feed = {'x' : data },
299+ fetch_list = [z .name ],
300+ return_numpy = False )
321301 expect_out = np .array ([[0 , 0 ], [1 , 1 ]])
322302 self .assertTrue (np .allclose (expect_out , np .array (res )))
323-
324303 data = np .array ([True , True , False ])
325304 with program_guard (Program (), Program ()):
326- x = fluid .layers .data (name = 'x' , shape = [- 1 ])
305+ x = fluid .layers .data (name = 'x' , shape = [( - 1 ) ])
327306 y = paddle .where (x )
328307 self .assertEqual (type (y ), tuple )
329308 self .assertEqual (len (y ), 1 )
330309 z = fluid .layers .concat (list (y ), axis = 1 )
331310 exe = fluid .Executor (fluid .CPUPlace ())
332- res , = exe .run (feed = {'x' : data },
333- fetch_list = [z .name ],
334- return_numpy = False )
311+ ( res , ) = exe .run (feed = {'x' : data },
312+ fetch_list = [z .name ],
313+ return_numpy = False )
335314 expect_out = np .array ([[0 ], [1 ]])
336315 self .assertTrue (np .allclose (expect_out , np .array (res )))
337316
317+ def test_eager (self ):
318+ with _test_eager_guard ():
319+ self .test_api ()
320+ self .test_dygraph_api_broadcast_1 ()
321+ self .test_dygraph_api_broadcast_2 ()
322+ self .test_dygraph_api_broadcast_3 ()
323+ self .test_dygraph_api_broadcast_4 ()
324+ self .test_dygraph_api_broadcast_5 ()
325+ self .test_dygraph_api_broadcast_6 ()
326+ self .test_dygraph_api_broadcast_7 ()
327+ self .test_dygraph_api_broadcast_8 ()
328+
338329
339330class TestWhereOpError (unittest .TestCase ):
340331 def test_errors (self ):
341332 with program_guard (Program (), Program ()):
342- x_i = np .array ([0.9383 , 0.1983 , 3.2 , 1.2 ]).astype (" float64" )
343- y_i = np .array ([1.0 , 1.0 , 1.0 , 1.0 ]).astype (" float64" )
344- cond_i = np .array ([False , False , True , True ]).astype (" bool" )
333+ x_i = np .array ([0.9383 , 0.1983 , 3.2 , 1.2 ]).astype (' float64' )
334+ y_i = np .array ([1.0 , 1.0 , 1.0 , 1.0 ]).astype (' float64' )
335+ cond_i = np .array ([False , False , True , True ]).astype (' bool' )
345336
346337 def test_Variable ():
347338 paddle .where (cond_i , x_i , y_i )
@@ -360,10 +351,14 @@ def test_value_error(self):
360351 with fluid .dygraph .guard ():
361352 cond_shape = [2 , 2 , 4 ]
362353 cond_tmp = paddle .rand (cond_shape )
363- cond = cond_tmp < 0.3
354+ cond = ( cond_tmp < 0.3 )
364355 a = paddle .rand (cond_shape )
365356 self .assertRaises (ValueError , paddle .where , cond , a )
366357
358+ def test_eager (self ):
359+ with _test_eager_guard ():
360+ self .test_value_error ()
361+
367362
if __name__ == '__main__':
    unittest.main()
0 commit comments