2 changes: 2 additions & 0 deletions python/paddle/__init__.py
@@ -655,6 +655,7 @@ def new_init(self, *args, **kwargs):
normal_,
poisson,
rand,
rand_like,
randint,
randint_like,
randn,
@@ -1246,6 +1247,7 @@ def __dir__(self):
'geometric_',
'randn',
'randn_like',
'rand_like',
'strided_slice',
'unique',
'unique_consecutive',
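What these export-list changes amount to, as a quick hypothetical REPL check:

import paddle

# rand_like is now importable from the top-level package, and
# __dir__ advertises it alongside randn_like.
assert hasattr(paddle, "rand_like")
assert "rand_like" in dir(paddle)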
1 change: 1 addition & 0 deletions python/paddle/tensor/__init__.py
@@ -454,6 +454,7 @@
normal_,
poisson,
rand,
rand_like,
randint,
randint_like,
randn,
131 changes: 123 additions & 8 deletions python/paddle/tensor/random.py
@@ -453,6 +453,8 @@ def multinomial(
num_samples: int = 1,
replacement: bool = False,
name: str | None = None,
*,
out: Tensor | None = None,
) -> Tensor:
"""
Returns a Tensor filled with random values sampled from a Multinomial
@@ -474,6 +476,7 @@
name(str|None, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
out (Tensor|None, optional): The output Tensor. If set, the result will be stored in this Tensor. Default is None.
Returns:
Tensor, A Tensor filled with sampled category index after ``num_samples`` times samples.

@@ -516,7 +519,7 @@
"""

if in_dynamic_or_pir_mode():
return _C_ops.multinomial(x, num_samples, replacement)
return _C_ops.multinomial(x, num_samples, replacement, out=out)
else:
check_variable_and_dtype(
x, "x", ["uint16", "float16", "float32", "float64"], "multinomial"
@@ -1150,14 +1153,104 @@ def randn_like(
"""
if dtype is None:
dtype = x.dtype
else:
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)
shape = paddle.shape(x)

return standard_normal(shape, dtype, name)


def rand_like(
input: Tensor,
name: str | None = None,
*,
dtype: DTypeLike | None = None,
device: PlaceLike | None = None,
requires_grad: bool = False,
) -> Tensor:
"""
Returns a tensor with the same size as input that is filled with random numbers from a uniform distribution on the interval [0, 1).

Args:
input (Tensor): The input multi-dimensional tensor which specifies shape. The dtype of ``input``
can be float16, float64, float8_e4m3fn, float32, bfloat16.
name (str|None, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
dtype (str|np.dtype|paddle.dtype|None, optional): The data type of the
output tensor. Supported data types: float16, float64, float8_e4m3fn, float32, bfloat16.
If ``dtype`` is None, the data type is the same as input's data type. Default is None.
device (str|paddle.Place|None, optional): The device on which to place the created tensor.
If None, the device is the same as input's device. Default is None.
requires_grad (bool, optional): Whether to compute gradients for the created tensor.
Default is False.

Returns:
Tensor: A Tensor with the same size as input that is filled with random numbers from a uniform distribution on the interval [0, 1).

Examples:
.. code-block:: python

>>> import paddle

>>> # example 1:
>>> # dtype is None and the dtype of input is float32
>>> x = paddle.zeros((2, 3)).astype("float32")
>>> out1 = paddle.rand_like(x)
>>> print(out1)
>>> # doctest: +SKIP("Random output")
Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
[[0.34962332, 0.82356787, 0.91275704],
[0.12328923, 0.58439839, 0.32735515]])
>>> # doctest: -SKIP
>>> print(out1.dtype)
paddle.float32

>>> # example 2:
>>> # dtype is None and the dtype of input is float64
>>> x = paddle.zeros((2, 3)).astype("float64")
>>> out2 = paddle.rand_like(x)
>>> print(out2)
>>> # doctest: +SKIP("Random output")
Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
[[0.73964721, 0.28413662, 0.91918457],
[0.62838351, 0.39185921, 0.51561823]])
>>> # doctest: -SKIP
>>> print(out2.dtype)
paddle.float64

>>> # example 3:
>>> # dtype is float64 and the dtype of input is float32
>>> x = paddle.zeros((2, 3)).astype("float32")
>>> out3 = paddle.rand_like(x, dtype="float64")
>>> print(out3)
>>> # doctest: +SKIP("Random output")
Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
[[0.84492219, 0.11572551, 0.73868765],
[0.90269387, 0.45644298, 0.28739912]])
>>> # doctest: -SKIP
>>> print(out3.dtype)
paddle.float64

>>> # example 4:
>>> # with requires_grad=True
>>> x = paddle.zeros((2, 2)).astype("float32")
>>> out4 = paddle.rand_like(x, requires_grad=True)
>>> print(out4.stop_gradient)
False
"""
if dtype is None:
Review comment (Contributor): These parameters all already exist on randn; no special handling should be needed here.

Reply (LLSGYN, Contributor Author, Aug 27, 2025): We need to call the uniform-distribution rand here, not randn, and rand does not support these parameters.
dtype = input.dtype
Review comment (Contributor): No need to handle this here; the downstream API handles the dtype=None case itself, so resolving it at this level adds a little overhead.

Reply (zhwesky2010, Contributor, Aug 29, 2025): Handle dtype=None only in the innermost API; the APIs above it should just pass None through rather than resolving the dtype at every level. If other APIs do this redundantly, please clean them up as well.

Reply (Contributor Author): rand_like/randn_like must resolve dtype=None up front, because the goal is a dtype that matches the input Tensor.

return uniform(
shape=input.shape,
dtype=dtype,
min=0.0,
max=1.0,
name=name,
device=device,
requires_grad=requires_grad,
)
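
A minimal sketch of the point settled in the review thread above: ``rand_like`` resolves ``dtype=None`` up front so the result matches the input's dtype, whereas ``uniform`` given no explicit dtype falls back to the global default (assumed to be float32 here):

import paddle

x = paddle.zeros((2, 3), dtype="float64")

# rand_like adopts the input's dtype when dtype is None.
y = paddle.rand_like(x)
assert y.dtype == paddle.float64

# uniform with no explicit dtype uses the global default instead,
# which is why the "like" API cannot simply pass None through.
z = paddle.uniform(shape=[2, 3])
assert z.dtype == paddle.float32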


def normal(
mean: complex | Tensor = 0.0,
std: float | Tensor = 1.0,
@@ -1370,6 +1463,10 @@ def uniform(
max: float = 1.0,
seed: int = 0,
name: str | None = None,
*,
out: Tensor | None = None,
device: PlaceLike | None = None,
requires_grad: bool = False,
Review comment (Contributor): Shouldn't there also be a pin_memory parameter here? See the latest rand code for reference.
) -> Tensor:
"""
Returns a Tensor filled with random values sampled from a uniform
@@ -1460,14 +1557,23 @@

if in_dynamic_mode():
shape = paddle.utils.convert_shape_to_list(shape)
return _C_ops.uniform(
place = (
_current_expected_place()
if device is None
else _get_paddle_place(device)
)
tensor = _C_ops.uniform(
shape,
dtype,
float(min),
float(max),
seed,
_current_expected_place(),
place,
out=out,
)
if requires_grad is True:
tensor.stop_gradient = False
return tensor
elif in_pir_mode():
check_type(
shape, 'shape', (list, tuple, paddle.pir.Value), 'uniform/rand'
@@ -1482,14 +1588,23 @@
if isinstance(max, int):
max = float(max)

return _C_ops.uniform(
place = (
_current_expected_place()
if device is None
else _get_paddle_place(device)
)
tensor = _C_ops.uniform(
shape,
dtype,
min,
max,
seed,
_current_expected_place(),
place,
out=out,
)
if requires_grad is True:
tensor.stop_gradient = False
return tensor
else:
check_type(shape, 'shape', (list, tuple, Variable), 'uniform/rand')
check_dtype(dtype, 'dtype', supported_dtypes, 'uniform/rand')
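A short usage sketch of the extended ``uniform`` signature from this diff (the pin_memory parameter raised in the review is not included; the device string follows paddle's usual place syntax and is an assumption here):

import paddle

# Explicit device placement plus a trainable result.
u = paddle.uniform(
    shape=[2, 3],
    dtype="float32",
    min=0.0,
    max=1.0,
    device="cpu",  # e.g. "gpu:0" on a CUDA build
    requires_grad=True,
)
assert u.stop_gradient is False

# Writing into a preallocated tensor via the new out= keyword.
buf = paddle.empty([2, 3], dtype="float32")
paddle.uniform(shape=[2, 3], dtype="float32", out=buf)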
62 changes: 46 additions & 16 deletions python/paddle/tensor/stat.py
@@ -32,7 +32,10 @@

from ..base.data_feeder import check_type, check_variable_and_dtype
from ..common_ops_import import Variable
from ..framework import LayerHelper, core
from ..framework import (
LayerHelper,
core,
)
from .math import _get_reduce_axis_with_tensor

if TYPE_CHECKING:
@@ -157,9 +160,12 @@ def mean(
def var(
x: Tensor,
axis: int | Sequence[int] | None = None,
unbiased: bool = True,
unbiased: bool | None = None,
keepdim: bool = False,
name: str | None = None,
*,
correction: float = 1,
out: Tensor | None = None,
) -> Tensor:
"""
Computes the variance of ``x`` along ``axis`` .
@@ -181,6 +187,9 @@ def var(
unbiased (bool|None, optional): Whether to use the unbiased estimation. If ``unbiased`` is True, the divisor used in the computation is :math:`N - 1`, where :math:`N` represents the number of elements along ``axis``, otherwise the divisor is :math:`N`. Default is None, which behaves the same as True.
keepdim (bool, optional): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the input unless ``keepdim`` is True. Default is False.
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
correction (int|float, optional): Difference between the sample size and the sample degrees of freedom.
Defaults to 1 (Bessel's correction). Only one of ``unbiased`` and ``correction`` may be given: if ``unbiased`` is set, ``correction`` must keep its default value.
out (Tensor|None, optional): The output Tensor. If set, the result is stored in this Tensor. Default is None.

Returns:
Tensor, results of variance along ``axis`` of ``x``, with the same data type as ``x``.
@@ -198,28 +207,41 @@
>>> print(out2.numpy())
[1. 4.3333335]
"""
if unbiased is not None and correction != 1:
raise ValueError("Only one of unbiased and correction may be given")

if unbiased is not None:
actual_correction = 1.0 if unbiased else 0.0
else:
actual_correction = float(correction)
if not in_dynamic_mode():
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'var'
)

u = mean(x, axis, True, name)
dtype = paddle.float32 if x.dtype == paddle.float16 else x.dtype
out = paddle.sum(
out_tensor = paddle.sum(
paddle.pow((x - u), 2), axis, keepdim=keepdim, name=name, dtype=dtype
)

n = paddle.cast(paddle.numel(x), "int64") / paddle.cast(
paddle.numel(out), "int64"
paddle.numel(out_tensor), "int64"
)
n = n.astype(dtype)
if unbiased:
one_const = paddle.ones([], x.dtype)
if paddle.in_dynamic_mode() and n <= one_const:

if actual_correction != 0:
corrected_n = n - actual_correction
corrected_n = paddle.maximum(
corrected_n, paddle.zeros_like(corrected_n)
)
if paddle.in_dynamic_mode() and paddle.any(corrected_n <= 0):
warnings.warn("Degrees of freedom is <= 0.", stacklevel=2)
n = n - 1.0
n.stop_gradient = True
out /= n
else:
corrected_n = n

corrected_n.stop_gradient = True
out_tensor /= corrected_n

def _replace_nan(out):
indices = paddle.arange(out.numel(), dtype='int64')
@@ -229,12 +251,20 @@ def _replace_nan(out):
return out_nan

if 0 in x.shape:
out = _replace_nan(out)
if len(x.shape) == 0 and not unbiased:
out = paddle.to_tensor(0, stop_gradient=out.stop_gradient)
if out.dtype != x.dtype:
return out.astype(x.dtype)
return out
out_tensor = _replace_nan(out_tensor)
if len(x.shape) == 0 and actual_correction == 0:
out_tensor = paddle.to_tensor(0, stop_gradient=out_tensor.stop_gradient)

if out_tensor.dtype != x.dtype:
result = out_tensor.astype(x.dtype)
else:
result = out_tensor

if out is not None:
paddle.assign(result, out)
return out

return result
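
A hand-checked numeric sketch of the unbiased/correction relationship implemented above (for x = [1, 2, 3, 4]: mean 2.5, sum of squared deviations 5.0):

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])

v1 = paddle.var(x, unbiased=True)  # divisor N - 1 = 3 -> 5/3 ≈ 1.6667
v2 = paddle.var(x, correction=1)   # same divisor, correction-style spelling
v3 = paddle.var(x, correction=0)   # divisor N = 4 -> 5/4 = 1.25

# Passing unbiased together with a non-default correction raises ValueError.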


def std(