diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index bb563be086aaa1..87143897b3a419 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -655,6 +655,7 @@ def new_init(self, *args, **kwargs): normal_, poisson, rand, + rand_like, randint, randint_like, randn, @@ -1246,6 +1247,7 @@ def __dir__(self): 'geometric_', 'randn', 'randn_like', + 'rand_like', 'strided_slice', 'unique', 'unique_consecutive', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index ec80bb6e6cea38..97dd26c97c3d2b 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -454,6 +454,7 @@ normal_, poisson, rand, + rand_like, randint, randint_like, randn, diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 3c7a4b4beae75c..e956d0fc9bf1b1 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -453,6 +453,8 @@ def multinomial( num_samples: int = 1, replacement: bool = False, name: str | None = None, + *, + out: Tensor | None = None, ) -> Tensor: """ Returns a Tensor filled with random values sampled from a Multinomial @@ -474,6 +476,7 @@ def multinomial( name(str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + out (Tensor|None, optional): The output Tensor. If set, the result will be stored in this Tensor. Default is None. Returns: Tensor, A Tensor filled with sampled category index after ``num_samples`` times samples. @@ -516,7 +519,7 @@ def multinomial( """ if in_dynamic_or_pir_mode(): - return _C_ops.multinomial(x, num_samples, replacement) + return _C_ops.multinomial(x, num_samples, replacement, out=out) else: check_variable_and_dtype( x, "x", ["uint16", "float16", "float32", "float64"], "multinomial" @@ -1150,14 +1153,104 @@ def randn_like( """ if dtype is None: dtype = x.dtype - else: - if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): - dtype = convert_np_dtype_to_dtype_(dtype) shape = paddle.shape(x) return standard_normal(shape, dtype, name) +def rand_like( + input, + name: str | None = None, + *, + dtype: DTypeLike | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, +): + """ + Returns a tensor with the same size as input that is filled with random numbers from a uniform distribution on the interval [0, 1). + + Args: + input (Tensor): The input multi-dimensional tensor which specifies shape. The dtype of ``input`` + can be float16, float64, float8_e4m3fn, float32, bfloat16. + name (str|None, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + dtype (str|np.dtype|paddle.dtype|None, optional): The data type of the + output tensor. Supported data types: float16, float64, float8_e4m3fn, float32, bfloat16. + If ``dtype`` is None, the data type is the same as input's data type. Default is None. + device (str|paddle.Place|None, optional): The device on which to place the created tensor. + If None, the device is the same as input's device. Default is None. + requires_grad (bool, optional): Whether to compute gradients for the created tensor. + Default is False. + + Returns: + Tensor: A Tensor with the same size as input that is filled with random numbers from a uniform distribution on the interval [0, 1). + + Examples: + .. 
code-block:: python + + >>> import paddle + + >>> # example 1: + >>> # dtype is None and the dtype of input is float32 + >>> x = paddle.zeros((2, 3)).astype("float32") + >>> out1 = paddle.rand_like(x) + >>> print(out1) + >>> # doctest: +SKIP("Random output") + Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, + [[0.34962332, 0.82356787, 0.91275704], + [0.12328923, 0.58439839, 0.32735515]]) + >>> # doctest: -SKIP + >>> print(out1.dtype) + paddle.float32 + + >>> # example 2: + >>> # dtype is None and the dtype of input is float64 + >>> x = paddle.zeros((2, 3)).astype("float64") + >>> out2 = paddle.rand_like(x) + >>> print(out2) + >>> # doctest: +SKIP("Random output") + Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True, + [[0.73964721, 0.28413662, 0.91918457], + [0.62838351, 0.39185921, 0.51561823]]) + >>> # doctest: -SKIP + >>> print(out2.dtype) + paddle.float64 + + >>> # example 3: + >>> # dtype is float64 and the dtype of input is float32 + >>> x = paddle.zeros((2, 3)).astype("float32") + >>> out3 = paddle.rand_like(x, dtype="float64") + >>> print(out3) + >>> # doctest: +SKIP("Random output") + Tensor(shape=[2, 3], dtype=float64, place=Place(cpu), stop_gradient=True, + [[0.84492219, 0.11572551, 0.73868765], + [0.90269387, 0.45644298, 0.28739912]]) + >>> # doctest: -SKIP + >>> print(out3.dtype) + paddle.float64 + + >>> # example 4: + >>> # with requires_grad=True + >>> x = paddle.zeros((2, 2)).astype("float32") + >>> out4 = paddle.rand_like(x, requires_grad=True) + >>> print(out4.stop_gradient) + False + """ + if dtype is None: + dtype = input.dtype + + return uniform( + shape=input.shape, + dtype=dtype, + min=0.0, + max=1.0, + name=name, + device=device, + requires_grad=requires_grad, + ) + + def normal( mean: complex | Tensor = 0.0, std: float | Tensor = 1.0, @@ -1370,6 +1463,10 @@ def uniform( max: float = 1.0, seed: int = 0, name: str | None = None, + *, + out: Tensor | None = None, + device: PlaceLike | None = None, + requires_grad: bool = False, ) -> Tensor: """ Returns a Tensor filled with random values sampled from a uniform @@ -1460,14 +1557,23 @@ def uniform( if in_dynamic_mode(): shape = paddle.utils.convert_shape_to_list(shape) - return _C_ops.uniform( + place = ( + _current_expected_place() + if device is None + else _get_paddle_place(device) + ) + tensor = _C_ops.uniform( shape, dtype, float(min), float(max), seed, - _current_expected_place(), + place, + out=out, ) + if requires_grad is True: + tensor.stop_gradient = False + return tensor elif in_pir_mode(): check_type( shape, 'shape', (list, tuple, paddle.pir.Value), 'uniform/rand' @@ -1482,14 +1588,23 @@ def uniform( if isinstance(max, int): max = float(max) - return _C_ops.uniform( + place = ( + _current_expected_place() + if device is None + else _get_paddle_place(device) + ) + tensor = _C_ops.uniform( shape, dtype, min, max, seed, - _current_expected_place(), + place, + out=out, ) + if requires_grad is True: + tensor.stop_gradient = False + return tensor else: check_type(shape, 'shape', (list, tuple, Variable), 'uniform/rand') check_dtype(dtype, 'dtype', supported_dtypes, 'uniform/rand') diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 1344a620dc8e66..83f550a2ec12d6 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -32,7 +32,10 @@ from ..base.data_feeder import check_type, check_variable_and_dtype from ..common_ops_import import Variable -from ..framework import LayerHelper, core +from ..framework import ( + 
LayerHelper, + core, +) from .math import _get_reduce_axis_with_tensor if TYPE_CHECKING: @@ -157,9 +160,12 @@ def mean( def var( x: Tensor, axis: int | Sequence[int] | None = None, - unbiased: bool = True, + unbiased: bool | None = None, keepdim: bool = False, name: str | None = None, + *, + correction: float = 1, + out: Tensor | None = None, ) -> Tensor: """ Computes the variance of ``x`` along ``axis`` . @@ -181,6 +187,9 @@ def var( unbiased (bool, optional): Whether to use the unbiased estimation. If ``unbiased`` is True, the divisor used in the computation is :math:`N - 1`, where :math:`N` represents the number of elements along ``axis`` , otherwise the divisor is :math:`N`. Default is True. keep_dim (bool, optional): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the input unless keep_dim is true. Default is False. name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + correction (int|float, optional): Difference between the sample size and sample degrees of freedom. + Defaults to 1 (Bessel's correction). If unbiased is specified, this parameter is ignored. + out (Tensor|None, optional): Output tensor. Default is None. Returns: Tensor, results of variance along ``axis`` of ``x``, with the same data type as ``x``. @@ -198,6 +207,13 @@ def var( >>> print(out2.numpy()) [1. 4.3333335] """ + if unbiased is not None and correction != 1: + raise ValueError("Only one of unbiased and correction may be given") + + if unbiased is not None: + actual_correction = 1.0 if unbiased else 0.0 + else: + actual_correction = float(correction) if not in_dynamic_mode(): check_variable_and_dtype( x, 'x', ['float16', 'float32', 'float64'], 'var' @@ -205,21 +221,27 @@ def var( u = mean(x, axis, True, name) dtype = paddle.float32 if x.dtype == paddle.float16 else x.dtype - out = paddle.sum( + out_tensor = paddle.sum( paddle.pow((x - u), 2), axis, keepdim=keepdim, name=name, dtype=dtype ) n = paddle.cast(paddle.numel(x), "int64") / paddle.cast( - paddle.numel(out), "int64" + paddle.numel(out_tensor), "int64" ) n = n.astype(dtype) - if unbiased: - one_const = paddle.ones([], x.dtype) - if paddle.in_dynamic_mode() and n <= one_const: + + if actual_correction != 0: + corrected_n = n - actual_correction + corrected_n = paddle.maximum( + corrected_n, paddle.zeros_like(corrected_n) + ) + if paddle.in_dynamic_mode() and paddle.any(corrected_n <= 0): warnings.warn("Degrees of freedom is <= 0.", stacklevel=2) - n = n - 1.0 - n.stop_gradient = True - out /= n + else: + corrected_n = n + + corrected_n.stop_gradient = True + out_tensor /= corrected_n def _replace_nan(out): indices = paddle.arange(out.numel(), dtype='int64') @@ -229,12 +251,20 @@ def _replace_nan(out): return out_nan if 0 in x.shape: - out = _replace_nan(out) - if len(x.shape) == 0 and not unbiased: - out = paddle.to_tensor(0, stop_gradient=out.stop_gradient) - if out.dtype != x.dtype: - return out.astype(x.dtype) - return out + out_tensor = _replace_nan(out_tensor) + if len(x.shape) == 0 and actual_correction == 0: + out_tensor = paddle.to_tensor(0, stop_gradient=out_tensor.stop_gradient) + + if out_tensor.dtype != x.dtype: + result = out_tensor.astype(x.dtype) + else: + result = out_tensor + + if out is not None: + paddle.assign(result, out) + return out + + return result def std( diff --git a/test/legacy_test/test_multinomial_op.py b/test/legacy_test/test_multinomial_op.py index 
8f8bf75be5e3be..5dad7afbe841a2 100644 --- a/test/legacy_test/test_multinomial_op.py +++ b/test/legacy_test/test_multinomial_op.py @@ -340,6 +340,144 @@ def test_static(self): ) +class TestMultinomialOutParameter(unittest.TestCase): + def setUp(self): + paddle.disable_static() + paddle.seed(100) + + def tearDown(self): + paddle.enable_static() + + def test_out_parameter_basic(self): + x_numpy = np.random.rand(4) + x = paddle.to_tensor(x_numpy) + + out = paddle.empty([1000], dtype='int64') + paddle.multinomial(x, num_samples=1000, replacement=True, out=out) + + self.assertEqual(out.shape, [1000]) + self.assertEqual(out.dtype, paddle.int64) + + self.assertTrue(paddle.all(out >= 0)) + self.assertTrue(paddle.all(out < 4)) + + def test_out_parameter_2d(self): + x_numpy = np.random.rand(3, 4) + x = paddle.to_tensor(x_numpy) + + out = paddle.empty([3, 100], dtype='int64') + + paddle.multinomial(x, num_samples=100, replacement=True, out=out) + + self.assertEqual(out.shape, [3, 100]) + self.assertEqual(out.dtype, paddle.int64) + + self.assertTrue(paddle.all(out >= 0)) + self.assertTrue(paddle.all(out < 4)) + + def test_out_parameter_with_alias(self): + x_numpy = np.random.rand(4) + x = paddle.to_tensor(x_numpy) + + out = paddle.empty([1000], dtype='int64') + paddle.multinomial(input=x, num_samples=1000, replacement=True, out=out) + + self.assertEqual(out.shape, [1000]) + self.assertEqual(out.dtype, paddle.int64) + + def test_out_parameter_different_scenarios(self): + x_numpy = np.random.rand(100) + x = paddle.to_tensor(x_numpy) + out = paddle.empty([50], dtype='int64') + + paddle.multinomial(x, num_samples=50, replacement=False, out=out) + + unique_values = paddle.unique(out) + self.assertEqual(len(unique_values), 50) + + out_small = paddle.empty([5], dtype='int64') + paddle.multinomial(x, num_samples=5, replacement=True, out=out_small) + self.assertEqual(out_small.shape, [5]) + + def test_out_parameter_none_default(self): + x_numpy = np.random.rand(4) + x = paddle.to_tensor(x_numpy) + + result1 = paddle.multinomial( + x, num_samples=100, replacement=True, out=None + ) + result2 = paddle.multinomial(x, num_samples=100, replacement=True) + + self.assertEqual(result1.shape, result2.shape) + self.assertEqual(result1.dtype, result2.dtype) + + +class TestMultinomialOutAndAliasDecorator(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def tearDown(self): + paddle.enable_static() + + def do_test(self, test_type): + x_numpy = np.random.rand(4) + x = paddle.to_tensor(x_numpy, stop_gradient=False) + + if test_type == "raw": + result = paddle.multinomial(x, num_samples=1000, replacement=True) + loss = paddle.cast(result, 'float32').mean() + loss.backward() + return result, x.grad + + elif test_type == "alias": + result = paddle.multinomial( + input=x, num_samples=1000, replacement=True + ) + loss = paddle.cast(result, 'float32').mean() + loss.backward() + return result, x.grad + + elif test_type == "out": + out = paddle.empty([1000], dtype='int64') + out.stop_gradient = False + paddle.multinomial(x, num_samples=1000, replacement=True, out=out) + loss = paddle.cast(out, 'float32').mean() + loss.backward() + return out, x.grad + + elif test_type == "out_alias": + out = paddle.empty([1000], dtype='int64') + out.stop_gradient = False + paddle.multinomial( + input=x, num_samples=1000, replacement=True, out=out + ) + loss = paddle.cast(out, 'float32').mean() + loss.backward() + return out, x.grad + + else: + raise ValueError(f"Unknown test type: {test_type}") + + def 
test_multinomial_out_and_alias_combination(self): + test_types = ["raw", "alias", "out", "out_alias"] + + results = {} + grads = {} + + for test_type in test_types: + paddle.seed(42) + result, grad = self.do_test(test_type) + results[test_type] = result + grads[test_type] = grad + + base_shape = results["raw"].shape + base_dtype = results["raw"].dtype + + for test_type in test_types: + self.assertEqual(results[test_type].shape, base_shape) + self.assertEqual(results[test_type].dtype, base_dtype) + + class TestMultinomialAlias(unittest.TestCase): def test_alias(self): paddle.disable_static() diff --git a/test/legacy_test/test_rand_like.py b/test/legacy_test/test_rand_like.py new file mode 100644 index 00000000000000..d5f132245fc720 --- /dev/null +++ b/test/legacy_test/test_rand_like.py @@ -0,0 +1,310 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle import base, core + + +class TestRandLikeAPI(unittest.TestCase): + """ + Test python API for rand_like function. + """ + + def setUp(self): + self.x_float16 = np.zeros((10, 12)).astype("float16") + self.x_float32 = np.zeros((10, 12)).astype("float32") + self.x_float64 = np.zeros((10, 12)).astype("float64") + self.dtype = ["float16", "float32", "float64"] + + def test_static_api_basic(self): + """Test basic static API functionality""" + paddle.enable_static() + try: + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_float32 = paddle.static.data( + name="x_float32", shape=[10, 12], dtype="float32" + ) + + # Test with default parameters + out1 = paddle.rand_like(x_float32) + + # Test with specified name + out2 = paddle.rand_like(x_float32, name="test_rand_like") + + place = base.CPUPlace() + if core.is_compiled_with_cuda(): + place = base.CUDAPlace(0) + + exe = paddle.static.Executor(place) + outs = exe.run( + feed={'x_float32': self.x_float32}, fetch_list=[out1, out2] + ) + + for out in outs: + self.assertEqual(out.shape, (10, 12)) + self.assertEqual(out.dtype, np.float32) + self.assertTrue(((out >= 0.0) & (out <= 1.0)).all()) + finally: + paddle.disable_static() + + def test_static_api_with_dtype(self): + """Test static API with different dtype specifications""" + paddle.enable_static() + try: + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_float32 = paddle.static.data( + name="x_float32", shape=[10, 12], dtype="float32" + ) + + place = base.CPUPlace() + if core.is_compiled_with_cuda(): + place = base.CUDAPlace(0) + + exe = paddle.static.Executor(place) + + # Test with different dtypes + for dtype in self.dtype: + if dtype == "float16" and not core.is_compiled_with_cuda(): + continue + + out = paddle.rand_like(x_float32, dtype=dtype) + result = exe.run( + feed={'x_float32': self.x_float32}, fetch_list=[out] + )[0] + + self.assertEqual(result.shape, (10, 12)) + self.assertEqual(result.dtype, np.dtype(dtype)) + 
self.assertTrue(((result >= 0.0) & (result <= 1.0)).all()) + finally: + paddle.disable_static() + + def test_static_api_with_device(self): + """Test static API with device specification""" + paddle.enable_static() + try: + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + x_float32 = paddle.static.data( + name="x_float32", shape=[10, 12], dtype="float32" + ) + + # Test with CPU device + out1 = paddle.rand_like(x_float32, device=base.CPUPlace()) + + place = base.CPUPlace() + exe = paddle.static.Executor(place) + result = exe.run( + feed={'x_float32': self.x_float32}, fetch_list=[out1] + )[0] + + self.assertEqual(result.shape, (10, 12)) + self.assertTrue(((result >= 0.0) & (result <= 1.0)).all()) + + # Test with CUDA device if available + if core.is_compiled_with_cuda(): + out2 = paddle.rand_like(x_float32, device=base.CUDAPlace(0)) + place_cuda = base.CUDAPlace(0) + exe_cuda = paddle.static.Executor(place_cuda) + result_cuda = exe_cuda.run( + feed={'x_float32': self.x_float32}, fetch_list=[out2] + )[0] + + self.assertEqual(result_cuda.shape, (10, 12)) + self.assertTrue( + ((result_cuda >= 0.0) & (result_cuda <= 1.0)).all() + ) + finally: + paddle.disable_static() + + def test_dygraph_api_basic(self): + """Test basic dygraph API functionality""" + for x_np in [self.x_float32, self.x_float64]: + x = paddle.to_tensor(x_np) + + # Test with default parameters + out1 = paddle.rand_like(x) + self.assertEqual(out1.shape, x.shape) + self.assertEqual(out1.dtype, x.dtype) + self.assertTrue( + ((out1.numpy() >= 0.0) & (out1.numpy() <= 1.0)).all() + ) + + # Test with name parameter + out2 = paddle.rand_like(x, name="test_rand_like") + self.assertEqual(out2.shape, x.shape) + self.assertEqual(out2.dtype, x.dtype) + self.assertTrue( + ((out2.numpy() >= 0.0) & (out2.numpy() <= 1.0)).all() + ) + + # Test with float16 if CUDA is available + if core.is_compiled_with_cuda(): + x = paddle.to_tensor(self.x_float16) + out = paddle.rand_like(x) + self.assertEqual(out.shape, x.shape) + self.assertEqual(out.dtype, x.dtype) + self.assertTrue(((out.numpy() >= 0.0) & (out.numpy() <= 1.0)).all()) + + def test_dygraph_api_with_dtype(self): + """Test dygraph API with different dtype specifications""" + x = paddle.to_tensor(self.x_float32) + + for dtype in self.dtype: + if dtype == "float16" and not core.is_compiled_with_cuda(): + continue + + out = paddle.rand_like(x, dtype=dtype) + self.assertEqual(out.shape, x.shape) + self.assertEqual(out.dtype, getattr(paddle, dtype)) + self.assertTrue(((out.numpy() >= 0.0) & (out.numpy() <= 1.0)).all()) + + def test_dygraph_api_with_requires_grad(self): + """Test dygraph API with requires_grad parameter""" + x = paddle.to_tensor(self.x_float32) + + # Test requires_grad=True + out1 = paddle.rand_like(x, requires_grad=True) + self.assertEqual(out1.shape, x.shape) + self.assertFalse(out1.stop_gradient) + self.assertTrue(((out1.numpy() >= 0.0) & (out1.numpy() <= 1.0)).all()) + + # Test requires_grad=False + out2 = paddle.rand_like(x, requires_grad=False) + self.assertEqual(out2.shape, x.shape) + self.assertTrue(out2.stop_gradient) + self.assertTrue(((out2.numpy() >= 0.0) & (out2.numpy() <= 1.0)).all()) + + def test_dygraph_api_with_device(self): + """Test dygraph API with device specification""" + x = paddle.to_tensor(self.x_float32) + + # Test with CPU device + out1 = paddle.rand_like(x, device=paddle.CPUPlace()) + self.assertEqual(out1.shape, x.shape) + self.assertEqual(out1.dtype, x.dtype) + self.assertTrue(out1.place.is_cpu_place()) + 
self.assertTrue(((out1.numpy() >= 0.0) & (out1.numpy() <= 1.0)).all()) + + # Test with CUDA device if available + if core.is_compiled_with_cuda(): + out2 = paddle.rand_like(x, device=paddle.CUDAPlace(0)) + self.assertEqual(out2.shape, x.shape) + self.assertEqual(out2.dtype, x.dtype) + self.assertTrue(out2.place.is_gpu_place()) + self.assertTrue( + ((out2.numpy() >= 0.0) & (out2.numpy() <= 1.0)).all() + ) + + def test_dygraph_api_combined_params(self): + """Test dygraph API with combined parameters""" + x = paddle.to_tensor(self.x_float32) + + # Test dtype + requires_grad + out1 = paddle.rand_like(x, dtype="float64", requires_grad=True) + self.assertEqual(out1.shape, x.shape) + self.assertEqual(out1.dtype, paddle.float64) + self.assertFalse(out1.stop_gradient) + self.assertTrue(((out1.numpy() >= 0.0) & (out1.numpy() <= 1.0)).all()) + + # Test all parameters together + out2 = paddle.rand_like( + x, name="combined_test", dtype="float64", requires_grad=False + ) + self.assertEqual(out2.shape, x.shape) + self.assertEqual(out2.dtype, paddle.float64) + self.assertTrue(out2.stop_gradient) + self.assertTrue(((out2.numpy() >= 0.0) & (out2.numpy() <= 1.0)).all()) + + def test_different_shapes(self): + """Test with different input shapes""" + shapes = [ + [ + 1, + ], + [5, 3], + [2, 4, 6], + [1, 2, 3, 4], + ] + + for shape in shapes: + x = paddle.zeros(shape, dtype='float32') + out = paddle.rand_like(x) + self.assertEqual(out.shape, shape) + self.assertTrue(((out.numpy() >= 0.0) & (out.numpy() <= 1.0)).all()) + + def test_default_dtype_behavior(self): + """Test default dtype behavior""" + # Test that output dtype matches input dtype when dtype=None + dtypes_to_test = ['float32', 'float64'] + if core.is_compiled_with_cuda(): + dtypes_to_test.append('float16') + + for dtype_str in dtypes_to_test: + x = paddle.zeros((3, 4), dtype=dtype_str) + out = paddle.rand_like(x) # dtype=None (default) + self.assertEqual(out.dtype, x.dtype) + self.assertTrue(((out.numpy() >= 0.0) & (out.numpy() <= 1.0)).all()) + + +class TestRandLikeOpForDygraph(unittest.TestCase): + """ + Test rand_like operation in dygraph mode with different scenarios. 
+ """ + + def run_net(self, use_cuda=False): + place = base.CUDAPlace(0) if use_cuda else base.CPUPlace() + with base.dygraph.guard(place): + # Test basic functionality + x1 = paddle.zeros([3, 4], dtype='float32') + out1 = paddle.rand_like(x1) + + # Test with different dtype + x2 = paddle.zeros([3, 4], dtype='float32') + out2 = paddle.rand_like(x2, dtype='float64') + + # Test with requires_grad + x3 = paddle.zeros([2, 5], dtype='float32') + out3 = paddle.rand_like(x3, requires_grad=True) + + # Test with device specification + x4 = paddle.zeros([4, 3], dtype='float32') + out4 = paddle.rand_like(x4, device=place) + + # Test with all parameters including device + x5 = paddle.zeros([2, 3], dtype='float32') + out5 = paddle.rand_like( + x5, + name="test_all_params", + dtype='float64', + device=place, + requires_grad=False, + ) + + def test_run(self): + self.run_net(False) + if core.is_compiled_with_cuda(): + self.run_net(True) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/legacy_test/test_variance_layer.py b/test/legacy_test/test_variance_layer.py index cd1a3842660567..5db15535e8e3c7 100644 --- a/test/legacy_test/test_variance_layer.py +++ b/test/legacy_test/test_variance_layer.py @@ -184,5 +184,328 @@ def test_api(self): paddle.enable_static() +def ref_var_with_correction(x, axis=None, correction=1, keepdim=False): + if isinstance(axis, int): + axis = (axis,) + if axis is not None: + axis = tuple(axis) + return np.var(x, axis=axis, ddof=correction, keepdims=keepdim) + + +class TestVarAPI_Correction(TestVarAPI): + def set_attrs(self): + self.correction = 0 + self.use_correction = True + + def static(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.shape, self.dtype) + if self.use_correction: + out = paddle.var( + x, + self.axis, + keepdim=self.keepdim, + correction=self.correction, + ) + else: + out = paddle.var(x, self.axis, self.unbiased, self.keepdim) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + return res[0] + + def dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + if self.use_correction: + out = paddle.var( + x, self.axis, keepdim=self.keepdim, correction=self.correction + ) + else: + out = paddle.var(x, self.axis, self.unbiased, self.keepdim) + paddle.enable_static() + return out.numpy() + + def test_api(self): + if self.use_correction: + out_ref = ref_var_with_correction( + self.x, self.axis, self.correction, self.keepdim + ) + else: + out_ref = ref_var(self.x, self.axis, self.unbiased, self.keepdim) + out_dygraph = self.dygraph() + + np.testing.assert_allclose(out_ref, out_dygraph, rtol=1e-05) + self.assertTrue(np.equal(out_ref.shape, out_dygraph.shape).all()) + + def test_static_or_pir_mode(): + out_static = self.static() + np.testing.assert_allclose(out_ref, out_static, rtol=1e-05) + self.assertTrue(np.equal(out_ref.shape, out_static.shape).all()) + + test_static_or_pir_mode() + + +class TestVarAPI_Correction2(TestVarAPI_Correction): + def set_attrs(self): + self.correction = 2 + self.use_correction = True + + +class TestVarAPI_CorrectionFloat(TestVarAPI_Correction): + def set_attrs(self): + self.correction = 1.5 + self.use_correction = True + + +class TestVarAPI_CorrectionWithAxis(TestVarAPI_Correction): + def set_attrs(self): + self.correction = 0 + self.axis = [1, 2] + self.use_correction = True + + +class TestVarAPI_OutParameter(unittest.TestCase): + def setUp(self): + self.dtype = 'float64' + self.shape = [2, 3, 4] + self.x = 
np.random.uniform(-1, 1, self.shape).astype(self.dtype) + self.place = get_device_place() + + def test_out_parameter_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + + out = paddle.empty(self.shape, dtype=self.dtype) + result = paddle.var(x, out=out) + + self.assertTrue(paddle.equal_all(result, out)) + + expected = paddle.var(x) + np.testing.assert_allclose(result.numpy(), expected.numpy(), rtol=1e-05) + + paddle.enable_static() + + def test_out_parameter_with_axis(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + axis = 1 + + expected_shape = list(self.shape) + expected_shape.pop(axis) + + out = paddle.empty(expected_shape, dtype=self.dtype) + result = paddle.var(x, axis=axis, out=out) + + self.assertTrue(paddle.equal_all(result, out)) + + expected = paddle.var(x, axis=axis) + np.testing.assert_allclose(result.numpy(), expected.numpy(), rtol=1e-05) + + paddle.enable_static() + + def test_out_parameter_with_keepdim(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + axis = 1 + + expected_shape = list(self.shape) + expected_shape[axis] = 1 + + out = paddle.empty(expected_shape, dtype=self.dtype) + result = paddle.var(x, axis=axis, keepdim=True, out=out) + + self.assertTrue(paddle.equal_all(result, out)) + + expected = paddle.var(x, axis=axis, keepdim=True) + np.testing.assert_allclose(result.numpy(), expected.numpy(), rtol=1e-05) + + paddle.enable_static() + + def test_out_parameter_none(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + + result1 = paddle.var(x, out=None) + result2 = paddle.var(x) + + np.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-05) + + paddle.enable_static() + + +class TestVarAPI_CorrectionAndOut(unittest.TestCase): + def setUp(self): + self.dtype = 'float64' + self.shape = [2, 3, 4] + self.x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + + def test_correction_and_out_combination(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + correction = 0 + + out = paddle.empty([], dtype=self.dtype) + result = paddle.var(x, correction=correction, out=out) + + self.assertTrue(paddle.equal_all(result, out)) + + expected = paddle.var(x, correction=correction) + np.testing.assert_allclose(result.numpy(), expected.numpy(), rtol=1e-05) + + expected_np = np.var(self.x, ddof=correction) + np.testing.assert_allclose(result.numpy(), expected_np, rtol=1e-05) + + paddle.enable_static() + + def test_correction_and_out_with_axis(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + correction = 2 + axis = 1 + + expected_shape = list(self.shape) + expected_shape.pop(axis) + + out = paddle.empty(expected_shape, dtype=self.dtype) + result = paddle.var(x, axis=axis, correction=correction, out=out) + + self.assertTrue(paddle.equal_all(result, out)) + + expected = paddle.var(x, axis=axis, correction=correction) + np.testing.assert_allclose(result.numpy(), expected.numpy(), rtol=1e-05) + + expected_np = np.var(self.x, axis=axis, ddof=correction) + np.testing.assert_allclose(result.numpy(), expected_np, rtol=1e-05) + + paddle.enable_static() + + +class TestVarAPI_ParamAlias(unittest.TestCase): + def setUp(self): + self.dtype = 'float64' + self.shape = [2, 3, 4] + self.x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + + def test_input_alias(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + + result1 = paddle.var(x=x) + result2 = paddle.var(input=x) + + np.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-05) + + 
paddle.enable_static() + + def test_dim_alias(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + axis_val = 1 + + result1 = paddle.var(x, axis=axis_val) + result2 = paddle.var(x, dim=axis_val) + + np.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-05) + + paddle.enable_static() + + def test_all_aliases_combination(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + axis_val = [1, 2] + + result1 = paddle.var(x=x, axis=axis_val, unbiased=False, keepdim=True) + result2 = paddle.var( + input=x, dim=axis_val, unbiased=False, keepdim=True + ) + + np.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-05) + + paddle.enable_static() + + def test_alias_with_new_params(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + correction = 0 + + expected_shape = [] + out = paddle.empty(expected_shape, dtype=self.dtype) + + result = paddle.var(input=x, correction=correction, out=out) + + expected = paddle.var(x, correction=correction) + np.testing.assert_allclose(result.numpy(), expected.numpy(), rtol=1e-05) + + paddle.enable_static() + + def test_static_mode_aliases(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.shape, self.dtype) + + out = paddle.var(input=x, dim=1) + + exe = paddle.static.Executor(get_device_place()) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + + expected = np.var(self.x, axis=1, ddof=1) + np.testing.assert_allclose(res[0], expected, rtol=1e-05) + + +class TestVarAPI_CorrectionEdgeCases(unittest.TestCase): + def setUp(self): + paddle.disable_static() + + def tearDown(self): + paddle.enable_static() + + def test_correction_larger_than_sample_size(self): + x = paddle.to_tensor([1.0, 2.0, 3.0]) + + result = paddle.var(x, correction=3) + self.assertTrue(paddle.isinf(result) or paddle.isnan(result)) + + result = paddle.var(x, correction=4) + self.assertTrue(paddle.isinf(result) or paddle.isnan(result)) + + def test_correction_negative(self): + x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0]) + + result = paddle.var(x, correction=-1) + expected_np = np.var(x.numpy(), ddof=-1) + np.testing.assert_allclose(result.numpy(), expected_np, rtol=1e-05) + + def test_correction_zero(self): + x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0]) + + result1 = paddle.var(x, correction=0) + result2 = paddle.var(x, unbiased=False) + + np.testing.assert_allclose(result1.numpy(), result2.numpy(), rtol=1e-05) + + +class TestVarAPI_NewParamsAlias(TestVarAPI_alias): + def test_alias_with_new_parameters(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([1, 2, 3, 4], 'float32')) + + out1 = paddle.var(x, correction=0).numpy() + out2 = paddle.tensor.var(x, correction=0).numpy() + out3 = paddle.tensor.stat.var(x, correction=0).numpy() + np.testing.assert_allclose(out1, out2, rtol=1e-05) + np.testing.assert_allclose(out1, out3, rtol=1e-05) + + out_tensor = paddle.empty([], dtype='float32') + paddle.var(x, out=out_tensor) + result1 = out_tensor.numpy() + + out_tensor2 = paddle.empty([], dtype='float32') + paddle.tensor.var(x, out=out_tensor2) + result2 = out_tensor2.numpy() + + np.testing.assert_allclose(result1, result2, rtol=1e-05) + + paddle.enable_static() + + if __name__ == '__main__': unittest.main()
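
Usage sketch (not part of the patch): the hunks above introduce paddle.rand_like, a keyword-only out= on multinomial/uniform/var, and a correction= argument on var. The short example below is written against those new signatures as shown in the diff and assumes a Paddle build that already contains these changes; the numerical output is random and the snippet is only a minimal illustration, not test or library code.

    import paddle

    paddle.seed(2024)

    # rand_like: uniform [0, 1) samples with the same shape (and, by default,
    # the same dtype) as the reference tensor.
    x = paddle.zeros((2, 3), dtype="float32")
    u = paddle.rand_like(x)                                   # float32, shape [2, 3]
    u64 = paddle.rand_like(x, dtype="float64", requires_grad=True)

    # multinomial with a preallocated output tensor via the new out= argument.
    probs = paddle.to_tensor([0.1, 0.2, 0.3, 0.4])
    samples = paddle.empty([100], dtype="int64")
    paddle.multinomial(probs, num_samples=100, replacement=True, out=samples)

    # var with an explicit correction (ddof-style) and an output tensor.
    data = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    v = paddle.empty([], dtype="float32")
    paddle.var(data, correction=0, out=v)                     # population variance

    print(u.shape, u64.stop_gradient, samples.shape, float(v))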