Skip to content

Commit 24aaa86

Browse files
committed
fix rand_like impl
1 parent ec32a29 commit 24aaa86

File tree

1 file changed

+5
-18
lines changed

1 file changed

+5
-18
lines changed

python/paddle/tensor/random.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,9 +1220,9 @@ def rand_like(
12201220
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
12211221
dtype = convert_np_dtype_to_dtype_(dtype)
12221222

1223-
tensor = uniform(
1224-
input.shape, dtype=dtype, min=0.0, max=1.0, name=name, place=device
1225-
)
1223+
tensor = paddle.rand(input.shape, dtype=dtype, name=name)
1224+
if device is not None:
1225+
tensor = tensor.to(device)
12261226
if requires_grad:
12271227
tensor.stop_gradient = False
12281228
return tensor
@@ -1440,8 +1440,6 @@ def uniform(
14401440
max: float = 1.0,
14411441
seed: int = 0,
14421442
name: str | None = None,
1443-
*,
1444-
place: PlaceLike | None = None,
14451443
) -> Tensor:
14461444
"""
14471445
Returns a Tensor filled with random values sampled from a uniform
@@ -1474,9 +1472,6 @@ def uniform(
14741472
time. Default is 0.
14751473
name(str|None, optional): Name for the operation (optional, default is None).
14761474
For more information, please refer to :ref:`api_guide_Name`.
1477-
place(PlaceLike|None, optional): The desired device of returned tensor.
1478-
if None, uses the current device for the default tensor type (see paddle.device.set_device()).
1479-
device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types. Default: None.
14801475
14811476
Returns:
14821477
Tensor, A Tensor filled with random values sampled from a uniform
@@ -1541,11 +1536,7 @@ def uniform(
15411536
float(min),
15421537
float(max),
15431538
seed,
1544-
(
1545-
_get_paddle_place(place)
1546-
if place is not None
1547-
else _current_expected_place()
1548-
),
1539+
_current_expected_place(),
15491540
)
15501541
elif in_pir_mode():
15511542
check_type(
@@ -1567,11 +1558,7 @@ def uniform(
15671558
min,
15681559
max,
15691560
seed,
1570-
(
1571-
_get_paddle_place(place)
1572-
if place is not None
1573-
else _current_expected_place()
1574-
),
1561+
_current_expected_place(),
15751562
)
15761563
else:
15771564
check_type(shape, 'shape', (list, tuple, Variable), 'uniform/rand')

0 commit comments

Comments
 (0)