@@ -140,6 +140,92 @@ def test_api_broadcast(self, use_cuda=False):
140140 fetch_list = [result ])
141141 assert np .array_equal (out [0 ], np .where (x_i > 1 , x_i , y_i ))
142142
143+ def __test_where_with_broadcast_static (self , cond_shape , x_shape , y_shape ):
144+ paddle .enable_static ()
145+
146+ main_program = Program ()
147+ with fluid .program_guard (main_program ):
148+ cond = fluid .layers .data (
149+ name = 'cond' , shape = cond_shape , dtype = 'bool' )
150+ x = fluid .layers .data (name = 'x' , shape = x_shape , dtype = 'float32' )
151+ y = fluid .layers .data (name = 'y' , shape = y_shape , dtype = 'float32' )
152+
153+ cond_data_tmp = np .random .random (size = cond_shape ).astype ("float32" )
154+ cond_data = cond_data_tmp < 0.3
155+ x_data = np .random .random (size = x_shape ).astype ("float32" )
156+ y_data = np .random .random (size = y_shape ).astype ("float32" )
157+
158+ result = paddle .where (condition = cond , x = x , y = y )
159+
160+ for use_cuda in [False , True ]:
161+ if use_cuda and not fluid .core .is_compiled_with_cuda ():
162+ return
163+ place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
164+
165+ exe = fluid .Executor (place )
166+ out = exe .run (
167+ fluid .default_main_program (),
168+ feed = {'cond' : cond_data ,
169+ 'x' : x_data ,
170+ 'y' : y_data },
171+ fetch_list = [result ])
172+
173+ expect = np .where (cond_data , x_data , y_data )
174+
175+ assert np .array_equal (out [0 ], expect )
176+
177+ def test_static_api_broadcast_1 (self ):
178+ cond_shape = [2 , 4 ]
179+ a_shape = [2 , 2 , 4 ]
180+ b_shape = [2 , 2 , 4 ]
181+ self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
182+
183+ def test_static_api_broadcast_2 (self ):
184+ cond_shape = [2 , 1 ]
185+ a_shape = [2 , 2 , 4 ]
186+ b_shape = [2 , 2 , 4 ]
187+ self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
188+
189+ def test_static_api_broadcast_3 (self ):
190+ cond_shape = [2 , 2 , 1 ]
191+ a_shape = [2 , 2 , 4 ]
192+ b_shape = [2 , 2 , 4 ]
193+ self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
194+
195+ def test_static_api_broadcast_4 (self ):
196+ cond_shape = [2 , 1 , 4 ]
197+ a_shape = [2 , 2 , 4 ]
198+ b_shape = [2 , 2 , 4 ]
199+ self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
200+
201+ # @Note Now, maybe not compatibility with old version
202+ def test_static_api_broadcast_5 (self ):
203+ cond_shape = [3 , 2 , 2 , 4 ]
204+ a_shape = [2 , 2 , 4 ]
205+ b_shape = [2 , 2 , 4 ]
206+ self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
207+
208+ # @Note Now, maybe not compatibility with old version
209+ def test_static_api_broadcast_6 (self ):
210+ cond_shape = [2 , 2 , 4 ]
211+ a_shape = [2 , 2 , 1 ]
212+ b_shape = [2 , 2 , 1 ]
213+ self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
214+
215+ # @Note Now, maybe not compatibility with old version
216+ def test_static_api_broadcast_7 (self ):
217+ cond_shape = [2 , 2 , 4 ]
218+ a_shape = [2 , 1 , 4 ]
219+ b_shape = [2 , 1 , 4 ]
220+ self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
221+
222+ # @Note Now, maybe not compatibility with old version
223+ def test_static_api_broadcast_8 (self ):
224+ cond_shape = [3 , 2 , 2 , 4 ]
225+ a_shape = [2 , 2 , 1 ]
226+ b_shape = [2 , 2 , 1 ]
227+ self .__test_where_with_broadcast_static (cond_shape , a_shape , b_shape )
228+
143229
144230class TestWhereDygraphAPI (unittest .TestCase ):
145231 def test_api (self ):
@@ -153,6 +239,72 @@ def test_api(self):
153239 out = paddle .where (cond , x , y )
154240 assert np .array_equal (out .numpy (), np .where (cond_i , x_i , y_i ))
155241
242+ def __test_where_with_broadcast_dygraph (self , cond_shape , a_shape , b_shape ):
243+ with fluid .dygraph .guard ():
244+ cond_tmp = paddle .rand (cond_shape )
245+ cond = cond_tmp < 0.3
246+ a = paddle .rand (a_shape )
247+ b = paddle .rand (b_shape )
248+
249+ result = paddle .where (cond , a , b )
250+ result = result .numpy ()
251+
252+ expect = np .where (cond , a , b )
253+
254+ self .assertTrue (np .array_equal (expect , result ))
255+
256+ def test_dygraph_api_broadcast_1 (self ):
257+ cond_shape = [2 , 4 ]
258+ a_shape = [2 , 2 , 4 ]
259+ b_shape = [2 , 2 , 4 ]
260+ self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
261+
262+ def test_dygraph_api_broadcast_2 (self ):
263+ cond_shape = [2 , 1 ]
264+ a_shape = [2 , 2 , 4 ]
265+ b_shape = [2 , 2 , 4 ]
266+ self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
267+
268+ def test_dygraph_api_broadcast_3 (self ):
269+ cond_shape = [2 , 2 , 1 ]
270+ a_shape = [2 , 2 , 4 ]
271+ b_shape = [2 , 2 , 4 ]
272+ self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
273+
274+ def test_dygraph_api_broadcast_4 (self ):
275+ cond_shape = [2 , 1 , 4 ]
276+ a_shape = [2 , 2 , 4 ]
277+ b_shape = [2 , 2 , 4 ]
278+ self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
279+
280+ # @Note Now, maybe not compatibility with old version
281+ def test_dygraph_api_broadcast_5 (self ):
282+ cond_shape = [3 , 2 , 2 , 4 ]
283+ a_shape = [2 , 2 , 4 ]
284+ b_shape = [2 , 2 , 4 ]
285+ self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
286+
287+ # @Note Now, maybe not compatibility with old version
288+ def test_dygraph_api_broadcast_6 (self ):
289+ cond_shape = [2 , 2 , 4 ]
290+ a_shape = [2 , 2 , 1 ]
291+ b_shape = [2 , 2 , 1 ]
292+ self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
293+
294+ # @Note Now, maybe not compatibility with old version
295+ def test_dygraph_api_broadcast_7 (self ):
296+ cond_shape = [2 , 2 , 4 ]
297+ a_shape = [2 , 1 , 4 ]
298+ b_shape = [2 , 1 , 4 ]
299+ self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
300+
301+ # @Note Now, maybe not compatibility with old version
302+ def test_dygraph_api_broadcast_8 (self ):
303+ cond_shape = [3 , 2 , 2 , 4 ]
304+ a_shape = [2 , 2 , 1 ]
305+ b_shape = [2 , 2 , 1 ]
306+ self .__test_where_with_broadcast_dygraph (cond_shape , a_shape , b_shape )
307+
156308
157309class TestWhereOpError (unittest .TestCase ):
158310 def test_errors (self ):
0 commit comments