Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,16 +763,16 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
if shape is not None:
check_shape(shape, 'normal')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

744 行改回 if not in_dynamic_mode(): 可以么?这样 PIR 下少了很多检查,当然 Variable 相关位置需要加一下 Value


if isinstance(mean, Variable):
if isinstance(std, Variable):
if isinstance(mean, (Variable, paddle.pir.Value)):
if isinstance(std, (Variable, paddle.pir.Value)):
if std.dtype != mean.dtype:
std = paddle.cast(std, mean.dtype)
mean_shape = paddle.shape(mean)
std = paddle.reshape(std, mean_shape)
else:
std = float(std)
out = standard_normal(paddle.shape(mean), mean.dtype, name)
elif isinstance(std, Variable):
elif isinstance(std, (Variable, paddle.pir.Value)):
mean = float(mean)
out = standard_normal(paddle.shape(std), std.dtype, name)
else:
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_zero_dim_no_backward_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def test_arange(self):
)[0]
np.testing.assert_array_equal(res, [1.0, 2.0, 3.0, 4.0, 5.0])

@test_with_pir_api
def test_normal(self):
mean = paddle.full([], 0.0)
std = paddle.full([], 0.0)
Expand Down