Skip to content

Commit ffc3d36

Browse files
authored
[WIP] paddle.where api add broadcast, when x_shape == y_shape, and x_shape != cond_shape (#35092)
* where op add broadcast, when x_shape == y_shape, and x_shape != cond_shape * add static api tests, and delete debug codes
1 parent e877248 commit ffc3d36

File tree

2 files changed

+154
-1
lines changed

2 files changed

+154
-1
lines changed

python/paddle/fluid/tests/unittests/test_where_op.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,92 @@ def test_api_broadcast(self, use_cuda=False):
140140
fetch_list=[result])
141141
assert np.array_equal(out[0], np.where(x_i > 1, x_i, y_i))
142142

143+
def __test_where_with_broadcast_static(self, cond_shape, x_shape, y_shape):
144+
paddle.enable_static()
145+
146+
main_program = Program()
147+
with fluid.program_guard(main_program):
148+
cond = fluid.layers.data(
149+
name='cond', shape=cond_shape, dtype='bool')
150+
x = fluid.layers.data(name='x', shape=x_shape, dtype='float32')
151+
y = fluid.layers.data(name='y', shape=y_shape, dtype='float32')
152+
153+
cond_data_tmp = np.random.random(size=cond_shape).astype("float32")
154+
cond_data = cond_data_tmp < 0.3
155+
x_data = np.random.random(size=x_shape).astype("float32")
156+
y_data = np.random.random(size=y_shape).astype("float32")
157+
158+
result = paddle.where(condition=cond, x=x, y=y)
159+
160+
for use_cuda in [False, True]:
161+
if use_cuda and not fluid.core.is_compiled_with_cuda():
162+
return
163+
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
164+
165+
exe = fluid.Executor(place)
166+
out = exe.run(
167+
fluid.default_main_program(),
168+
feed={'cond': cond_data,
169+
'x': x_data,
170+
'y': y_data},
171+
fetch_list=[result])
172+
173+
expect = np.where(cond_data, x_data, y_data)
174+
175+
assert np.array_equal(out[0], expect)
176+
177+
def test_static_api_broadcast_1(self):
178+
cond_shape = [2, 4]
179+
a_shape = [2, 2, 4]
180+
b_shape = [2, 2, 4]
181+
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
182+
183+
def test_static_api_broadcast_2(self):
184+
cond_shape = [2, 1]
185+
a_shape = [2, 2, 4]
186+
b_shape = [2, 2, 4]
187+
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
188+
189+
def test_static_api_broadcast_3(self):
190+
cond_shape = [2, 2, 1]
191+
a_shape = [2, 2, 4]
192+
b_shape = [2, 2, 4]
193+
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
194+
195+
def test_static_api_broadcast_4(self):
196+
cond_shape = [2, 1, 4]
197+
a_shape = [2, 2, 4]
198+
b_shape = [2, 2, 4]
199+
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
200+
201+
# @Note Now, maybe not compatibility with old version
202+
def test_static_api_broadcast_5(self):
203+
cond_shape = [3, 2, 2, 4]
204+
a_shape = [2, 2, 4]
205+
b_shape = [2, 2, 4]
206+
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
207+
208+
# @Note Now, maybe not compatibility with old version
209+
def test_static_api_broadcast_6(self):
210+
cond_shape = [2, 2, 4]
211+
a_shape = [2, 2, 1]
212+
b_shape = [2, 2, 1]
213+
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
214+
215+
# @Note Now, maybe not compatibility with old version
216+
def test_static_api_broadcast_7(self):
217+
cond_shape = [2, 2, 4]
218+
a_shape = [2, 1, 4]
219+
b_shape = [2, 1, 4]
220+
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
221+
222+
# @Note Now, maybe not compatibility with old version
223+
def test_static_api_broadcast_8(self):
224+
cond_shape = [3, 2, 2, 4]
225+
a_shape = [2, 2, 1]
226+
b_shape = [2, 2, 1]
227+
self.__test_where_with_broadcast_static(cond_shape, a_shape, b_shape)
228+
143229

144230
class TestWhereDygraphAPI(unittest.TestCase):
145231
def test_api(self):
@@ -153,6 +239,72 @@ def test_api(self):
153239
out = paddle.where(cond, x, y)
154240
assert np.array_equal(out.numpy(), np.where(cond_i, x_i, y_i))
155241

242+
def __test_where_with_broadcast_dygraph(self, cond_shape, a_shape, b_shape):
243+
with fluid.dygraph.guard():
244+
cond_tmp = paddle.rand(cond_shape)
245+
cond = cond_tmp < 0.3
246+
a = paddle.rand(a_shape)
247+
b = paddle.rand(b_shape)
248+
249+
result = paddle.where(cond, a, b)
250+
result = result.numpy()
251+
252+
expect = np.where(cond, a, b)
253+
254+
self.assertTrue(np.array_equal(expect, result))
255+
256+
def test_dygraph_api_broadcast_1(self):
257+
cond_shape = [2, 4]
258+
a_shape = [2, 2, 4]
259+
b_shape = [2, 2, 4]
260+
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
261+
262+
def test_dygraph_api_broadcast_2(self):
263+
cond_shape = [2, 1]
264+
a_shape = [2, 2, 4]
265+
b_shape = [2, 2, 4]
266+
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
267+
268+
def test_dygraph_api_broadcast_3(self):
269+
cond_shape = [2, 2, 1]
270+
a_shape = [2, 2, 4]
271+
b_shape = [2, 2, 4]
272+
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
273+
274+
def test_dygraph_api_broadcast_4(self):
275+
cond_shape = [2, 1, 4]
276+
a_shape = [2, 2, 4]
277+
b_shape = [2, 2, 4]
278+
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
279+
280+
# @Note Now, maybe not compatibility with old version
281+
def test_dygraph_api_broadcast_5(self):
282+
cond_shape = [3, 2, 2, 4]
283+
a_shape = [2, 2, 4]
284+
b_shape = [2, 2, 4]
285+
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
286+
287+
# @Note Now, maybe not compatibility with old version
288+
def test_dygraph_api_broadcast_6(self):
289+
cond_shape = [2, 2, 4]
290+
a_shape = [2, 2, 1]
291+
b_shape = [2, 2, 1]
292+
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
293+
294+
# @Note Now, maybe not compatibility with old version
295+
def test_dygraph_api_broadcast_7(self):
296+
cond_shape = [2, 2, 4]
297+
a_shape = [2, 1, 4]
298+
b_shape = [2, 1, 4]
299+
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
300+
301+
# @Note Now, maybe not compatibility with old version
302+
def test_dygraph_api_broadcast_8(self):
303+
cond_shape = [3, 2, 2, 4]
304+
a_shape = [2, 2, 1]
305+
b_shape = [2, 2, 1]
306+
self.__test_where_with_broadcast_dygraph(cond_shape, a_shape, b_shape)
307+
156308

157309
class TestWhereOpError(unittest.TestCase):
158310
def test_errors(self):

python/paddle/tensor/search.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,9 +514,10 @@ def where(condition, x, y, name=None):
514514
check_variable_and_dtype(
515515
y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where')
516516

517+
condition_shape = list(condition.shape)
517518
x_shape = list(x.shape)
518519
y_shape = list(y.shape)
519-
if x_shape == y_shape:
520+
if x_shape == y_shape and condition_shape == x_shape:
520521
if in_dygraph_mode():
521522
return _C_ops.where(condition, x, y)
522523
else:

0 commit comments

Comments
 (0)