Skip to content

Commit 609b50a

Browse files
authored
[Eager] polish some api logic (#49717)
* [Eager] polish some api logic * fix split * revover
1 parent 0b24d16 commit 609b50a

3 files changed

Lines changed: 12 additions & 22 deletions

File tree

python/paddle/nn/layer/norm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,8 @@ def forward(self, input):
999999
self._use_global_stats,
10001000
self._trainable_statistics,
10011001
)
1002+
if self._act is None:
1003+
return batch_norm_out
10021004
return dygraph_utils._append_activation_in_dygraph(
10031005
batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
10041006
)

python/paddle/tensor/manipulation.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,31 +1916,20 @@ def split(x, num_or_sections, axis=0, name=None):
19161916
input = x
19171917
dim = axis
19181918
if in_dygraph_mode():
1919-
num = None
1920-
attrs = ()
1921-
19221919
if isinstance(dim, Variable):
19231920
dim = dim.numpy()
19241921
dim = dim.item(0)
19251922
assert len(input.shape) + dim >= 0, "(rank(x) + axis) must >= 0"
19261923
dim = (len(input.shape) + dim) if dim < 0 else dim
1927-
attrs += ('axis', dim)
19281924

1929-
if isinstance(num_or_sections, int):
1930-
num = num_or_sections
1931-
attrs += ('num', num_or_sections)
1932-
elif isinstance(num_or_sections, (list, tuple)):
1933-
num = len(num_or_sections)
1925+
if isinstance(num_or_sections, (list, tuple)):
19341926
if utils._contain_var(num_or_sections):
19351927
for index, item in enumerate(num_or_sections):
19361928
if isinstance(item, Variable):
19371929
num_or_sections[index] = num_or_sections[index].numpy()[
19381930
0
19391931
]
1940-
attrs += ('sections', list(num_or_sections))
1941-
else:
1942-
attrs += ('sections', list(num_or_sections))
1943-
else:
1932+
elif not isinstance(num_or_sections, int):
19441933
raise TypeError(
19451934
"The type of 'num_or_sections' in split must be int, list or tuple in imperative mode, but "
19461935
"received %s." % (type(num_or_sections))

python/paddle/tensor/search.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -612,15 +612,6 @@ def where(condition, x=None, y=None, name=None):
612612
if x is None or y is None:
613613
raise ValueError("either both or neither of x and y should be given")
614614

615-
if not paddle.in_dynamic_mode():
616-
check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
617-
check_variable_and_dtype(
618-
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where'
619-
)
620-
check_variable_and_dtype(
621-
y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where'
622-
)
623-
624615
condition_shape = list(condition.shape)
625616
x_shape = list(x.shape)
626617
y_shape = list(y.shape)
@@ -646,6 +637,14 @@ def where(condition, x=None, y=None, name=None):
646637
if in_dygraph_mode():
647638
return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
648639
else:
640+
check_variable_and_dtype(condition, 'condition', ['bool'], 'where')
641+
check_variable_and_dtype(
642+
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'where'
643+
)
644+
check_variable_and_dtype(
645+
y, 'y', ['float32', 'float64', 'int32', 'int64'], 'where'
646+
)
647+
649648
helper = LayerHelper("where", **locals())
650649
out = helper.create_variable_for_type_inference(dtype=x.dtype)
651650

0 commit comments

Comments
 (0)