Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,18 +741,22 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
[0.48646951, 0.00815189, 3.74022293])
>>> # doctest: -SKIP
"""
if not in_dynamic_or_pir_mode():
check_type(mean, 'mean', (int, float, Variable), 'normal')
check_type(std, 'std', (int, float, Variable), 'normal')
if isinstance(mean, Variable):
if not in_dynamic_mode():
check_type(
mean, 'mean', (int, float, Variable, paddle.pir.Value), 'normal'
)
check_type(
std, 'std', (int, float, Variable, paddle.pir.Value), 'normal'
)
if isinstance(mean, (Variable, paddle.pir.Value)):
check_dtype(
mean.dtype,
'mean',
['float32', 'float64'],
'normal',
"If mean is Tensor, it's data type only support float32, float64.",
)
if isinstance(std, Variable):
if isinstance(std, (Variable, paddle.pir.Value)):
check_dtype(
std.dtype,
'std',
Expand All @@ -763,16 +767,16 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
if shape is not None:
check_shape(shape, 'normal')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

744 行改回 if not in_dynamic_mode(): 可以么?这样 PIR 下少了很多检查,当然 Variable 相关位置需要加一下 Value


if isinstance(mean, Variable):
if isinstance(std, Variable):
if isinstance(mean, (Variable, paddle.pir.Value)):
if isinstance(std, (Variable, paddle.pir.Value)):
if std.dtype != mean.dtype:
std = paddle.cast(std, mean.dtype)
mean_shape = paddle.shape(mean)
std = paddle.reshape(std, mean_shape)
else:
std = float(std)
out = standard_normal(paddle.shape(mean), mean.dtype, name)
elif isinstance(std, Variable):
elif isinstance(std, (Variable, paddle.pir.Value)):
mean = float(mean)
out = standard_normal(paddle.shape(std), std.dtype, name)
else:
Expand Down
2 changes: 1 addition & 1 deletion test/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ set_tests_properties(test_imperative_optimizer_v2 PROPERTIES TIMEOUT 250)
set_tests_properties(test_pool2d_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_transpose_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_activation_op PROPERTIES TIMEOUT 270)
set_tests_properties(test_normal PROPERTIES TIMEOUT 120)
set_tests_properties(test_normal1 PROPERTIES TIMEOUT 120)
set_tests_properties(test_bilinear_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_decoupled_py_reader PROPERTIES TIMEOUT 120)
set_tests_properties(test_fuse_bn_act_pass PROPERTIES TIMEOUT 120)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

import paddle
from paddle.pir_utils import test_with_pir_api

np.random.seed(10)
paddle.seed(10)
Expand Down Expand Up @@ -62,10 +63,11 @@ def static_api(self):
ret_all_shape = copy.deepcopy(shape)
ret_all_shape.insert(0, self.repeat_num)
ret_all = np.zeros(ret_all_shape, self.dtype)
main_program = paddle.static.Program()
if isinstance(self.mean, np.ndarray) and isinstance(
self.std, np.ndarray
):
with paddle.static.program_guard(paddle.static.Program()):
with paddle.static.program_guard(main_program):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为啥改成同一个 Program 了啊

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里只可能跑到一种

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

喔,不过人家原来也没错啊,为啥要改

Copy link
Member Author

@gouzil gouzil Mar 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我记得这玩意之前有点问题来着,不会切换还是啥的

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mean = paddle.static.data(
'Mean', self.mean.shape, self.mean.dtype
)
Expand All @@ -84,7 +86,7 @@ def static_api(self):
ret_all[i] = ret[0]
return ret_all
elif isinstance(self.mean, np.ndarray):
with paddle.static.program_guard(paddle.static.Program()):
with paddle.static.program_guard(main_program):
mean = paddle.static.data(
'Mean', self.mean.shape, self.mean.dtype
)
Expand All @@ -96,7 +98,7 @@ def static_api(self):
ret_all[i] = ret[0]
return ret_all
elif isinstance(self.std, np.ndarray):
with paddle.static.program_guard(paddle.static.Program()):
with paddle.static.program_guard(main_program):
std = paddle.static.data('Std', self.std.shape, self.std.dtype)
out = paddle.normal(self.mean, std, self.shape)

Expand All @@ -106,7 +108,7 @@ def static_api(self):
ret_all[i] = ret[0]
return ret_all
else:
with paddle.static.program_guard(paddle.static.Program()):
with paddle.static.program_guard(main_program):
out = paddle.normal(self.mean, self.std, self.shape)

exe = paddle.static.Executor(self.place)
Expand Down Expand Up @@ -138,6 +140,7 @@ def dygraph_api(self):
paddle.enable_static()
return ret_all

@test_with_pir_api
def test_api(self):
ret_static = self.static_api()
ret_dygraph = self.dygraph_api()
Expand Down Expand Up @@ -185,6 +188,7 @@ def set_attrs(self):


class TestNormalAlias(unittest.TestCase):
@test_with_pir_api
def test_alias(self):
paddle.disable_static()
shape = [1, 2, 3]
Expand All @@ -195,8 +199,10 @@ def test_alias(self):


class TestNormalErrors(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里现在是不能加嘛?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补上了

with paddle.static.program_guard(paddle.static.Program()):
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
mean = [1, 2, 3]
self.assertRaises(TypeError, paddle.normal, mean)

Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_zero_dim_no_backward_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def test_arange(self):
)[0]
np.testing.assert_array_equal(res, [1.0, 2.0, 3.0, 4.0, 5.0])

@test_with_pir_api
def test_normal(self):
mean = paddle.full([], 0.0)
std = paddle.full([], 0.0)
Expand Down
4 changes: 2 additions & 2 deletions tools/parallel_UT_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,7 @@
'test_elementwise_sub_op',
'test_compare_op',
'test_simnet',
'test_normal',
'test_normal1',
'test_tensor_scalar_type_promotion_static',
'test_trt_group_norm_op',
'test_learning_rate_scheduler',
Expand Down Expand Up @@ -2695,7 +2695,7 @@
'test_grid_sample_function',
'test_huber_loss_op',
'test_one_hot_op',
'test_normal',
'test_normal1',
'test_imperative_auto_prune',
'test_nn_grad',
'test_nearest_interp_op',
Expand Down
2 changes: 1 addition & 1 deletion tools/static_mode_white_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@
'test_norm_all',
'test_norm_nn_grad',
'test_norm_op',
'test_normal',
'test_normal1',
'test_normalization_wrapper',
'test_npair_loss_op',
'test_numel_op',
Expand Down
2 changes: 1 addition & 1 deletion tools/windows/run_unittests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ long_time_test="^test_gru_op$|\
^test_nearest_interp_v2_op$|\
^test_nn_grad$|\
^test_norm_nn_grad$|\
^test_normal$|\
^test_normal1$|\
^test_pool3d_op$|\
^test_static_save_load$|\
^test_trilinear_interp_op$|\
Expand Down