Skip to content

Commit ea786b8

Browse files
ooooo-createLuckycheng222
authored andcommitted
refine full and full_like for fill_value type check and annotations (PaddlePaddle#74127)
* refine full and full_like for fill_value check and type annotations * refine * refine * refine * pass approve ci * refine code * adapt string numeric values usage * add more comments * add more tests
1 parent f17dd25 commit ea786b8

3 files changed

Lines changed: 135 additions & 8 deletions

File tree

python/paddle/tensor/creation.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import builtins
1818
import math
19+
import numbers
1920
import re
2021
import warnings
2122
from typing import TYPE_CHECKING, overload
@@ -1043,7 +1044,7 @@ def get_slice(
10431044

10441045
def full_like(
10451046
x: paddle.Tensor,
1046-
fill_value: bool | float,
1047+
fill_value: Numeric | str,
10471048
dtype: DTypeLike | None = None,
10481049
name: str | None = None,
10491050
*,
@@ -1057,9 +1058,10 @@ def full_like(
10571058
10581059
Args:
10591060
x(Tensor): The input tensor which specifies shape and data type. The data type can be bool, float16, float32, float64, int32, int64.
1060-
fill_value(bool|float|int): The value to fill the tensor with. Note: this value shouldn't exceed the range of the output data type.
1061+
fill_value(Scalar|Tensor): The value to fill the tensor with. Note: this value shouldn't exceed the range of the output data type.
1062+
If ``fill_value`` is an Tensor, it should be an 0-D Tensor which represents a scalar.
10611063
dtype(np.dtype|str, optional): The data type of output. The data type can be one
1062-
of bool, float16, float32, float64, int32, int64. The default value is None, which means the output
1064+
of bool, float16, float32, float64, int32, int64, complex64, complex128. The default value is None, which means the output
10631065
data type is the same as input.
10641066
name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
10651067
device(PlaceLike|None, optional): The desired device of returned tensor.
@@ -1081,6 +1083,15 @@ def full_like(
10811083
[[2. 2. 2.]
10821084
[2. 2. 2.]]
10831085
"""
1086+
# Include str type check to handle string numeric values like "0.5" that occur in CI tests.
1087+
# The compatible method for fliud operators, may be it can be removed in the future.
1088+
if not isinstance(
1089+
fill_value,
1090+
(numbers.Number, str, core.eager.Tensor, Variable, paddle.pir.Value),
1091+
):
1092+
raise TypeError(
1093+
f"The fill_value should be int, float, bool, complex, np.number, string numeric value or Tensor, but received {type(fill_value)}."
1094+
)
10841095

10851096
if dtype is None:
10861097
dtype = x.dtype
@@ -1635,7 +1646,7 @@ def _check_attr(attr, message):
16351646
@ParamAliasDecorator({"shape": ["size"]})
16361647
def full(
16371648
shape: ShapeLike,
1638-
fill_value: bool | float | paddle.Tensor,
1649+
fill_value: Numeric | str,
16391650
dtype: DTypeLike | None = None,
16401651
name: str | None = None,
16411652
*,
@@ -1656,10 +1667,10 @@ def full(
16561667
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
16571668
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
16581669
Alias: ``size``.
1659-
fill_value(bool|float|int|Tensor): The constant value used to initialize the Tensor to be created.
1670+
fill_value(Scalar|Tensor): The constant value used to initialize the Tensor to be created.
16601671
If ``fill_value`` is an Tensor, it should be an 0-D Tensor which represents a scalar.
16611672
dtype(np.dtype|str, optional): Data type of the output Tensor
1662-
which can be float16, float32, float64, int32, int64, if dtype is `None`, the data
1673+
which can be float16, float32, float64, int32, int64, complex64, complex128. If dtype is `None`, the data
16631674
type of created Tensor is `float32`.
16641675
name(str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
16651676
out(Tensor, optional): The output tensor.
@@ -1707,6 +1718,15 @@ def full(
17071718
[2. 2.]
17081719
[2. 2.]]
17091720
"""
1721+
# Include str type check to handle string numeric values like "0.5" that occur in CI tests.
1722+
# The compatible method for fliud operators, may be it can be removed in the future.
1723+
if not isinstance(
1724+
fill_value,
1725+
(numbers.Number, str, core.eager.Tensor, Variable, paddle.pir.Value),
1726+
):
1727+
raise TypeError(
1728+
f"The fill_value should be int, float, bool, complex, np.number, string numeric values or Tensor, but received {type(fill_value)}."
1729+
)
17101730

17111731
if dtype is None:
17121732
if isinstance(fill_value, (bool)):

test/legacy_test/test_full_like_op.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
from op_test import OpTest, convert_float_to_uint16
19+
from utils import dygraph_guard, static_guard
1920

2021
import paddle
2122
import paddle.framework.dtype as dtypes
@@ -41,7 +42,7 @@ def fill_any_like_wrapper(x, value, out_dtype=None, name=None):
4142
return paddle.full_like(x, value, tmp_dtype, name=name)
4243

4344

44-
class TestFullOp(unittest.TestCase):
45+
class TestFullLikeOp(unittest.TestCase):
4546
"""Test fill_any_like op(whose API is full_like) for attr out."""
4647

4748
def test_attr_tensor_API(self):
@@ -94,7 +95,8 @@ def test_full_like_fill_inf(self):
9495
paddle.enable_static()
9596

9697

97-
class TestFullOpError(unittest.TestCase):
98+
class TestFullLikeOpError(unittest.TestCase):
99+
98100
def test_errors(self):
99101
with paddle.static.program_guard(
100102
paddle.static.Program(), paddle.static.Program()
@@ -114,6 +116,33 @@ def test_errors(self):
114116
dtype='uint4',
115117
)
116118

119+
def test_fill_value_errors(self):
120+
with dygraph_guard():
121+
# The fill_value must be one of [int, float, bool, complex, Tensor, np.number].
122+
self.assertRaises(
123+
TypeError,
124+
paddle.full_like,
125+
x=paddle.to_tensor([1.0, 2.0]),
126+
fill_value=np.array([1.0], dtype=np.float32),
127+
dtype="float32",
128+
)
129+
130+
self.assertRaises(
131+
TypeError,
132+
paddle.full_like,
133+
x=paddle.to_tensor([1.0, 2.0]),
134+
fill_value=[1.0],
135+
dtype="float32",
136+
)
137+
138+
self.assertRaises(
139+
TypeError,
140+
paddle.full_like,
141+
x=paddle.to_tensor([1.0, 2.0]),
142+
fill_value=np.bool_(True),
143+
dtype="bool",
144+
)
145+
117146

118147
class TestFullLikeOp1(OpTest):
119148
# test basic
@@ -198,6 +227,16 @@ def test_skip_data_transform(self):
198227
paddle.enable_static()
199228

200229

230+
class TestFullLikeOp5(TestFullLikeOp1):
231+
def init_data(self):
232+
self.fill_value = True
233+
self.shape = [10, 10]
234+
self.dtype = np.bool
235+
236+
def if_enable_cinn(self):
237+
pass
238+
239+
201240
class TestFullLikeFP16Op(TestFullLikeOp1):
202241
def init_data(self):
203242
self.fill_value = 6666
@@ -268,5 +307,45 @@ def test_full_like_kernel_gpu_zero_size(self):
268307
paddle.enable_static()
269308

270309

310+
class TestFullLikeWithTensorValue(unittest.TestCase):
311+
def test_dygraph_api(self):
312+
with dygraph_guard():
313+
base_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
314+
value_np = np.array([5.0], dtype=np.float32)
315+
base_tensor = paddle.to_tensor(base_np)
316+
value_tensor = paddle.to_tensor(value_np)
317+
result = paddle.full_like(base_tensor, value_tensor)
318+
expected = np.full_like(base_np, value_np)
319+
np.testing.assert_array_equal(result.numpy(), expected)
320+
321+
def test_static_api(self):
322+
with static_guard():
323+
startup_program = paddle.static.Program()
324+
train_program = paddle.static.Program()
325+
with paddle.static.program_guard(train_program, startup_program):
326+
base_tensor = paddle.static.data(
327+
name='base_tensor', dtype='float32', shape=[2, 2]
328+
)
329+
value_tensor = paddle.static.data(
330+
name='value_tensor', dtype='float32', shape=[1]
331+
)
332+
result = paddle.full_like(base_tensor, value_tensor)
333+
334+
place = paddle.CPUPlace()
335+
exe = paddle.static.Executor(place)
336+
337+
base_np = np.array([[1, 2], [3, 4]], dtype=np.float32)
338+
value_np = np.array([5.0], dtype=np.float32)
339+
340+
res = exe.run(
341+
train_program,
342+
feed={'base_tensor': base_np, 'value_tensor': value_np},
343+
fetch_list=[result],
344+
)
345+
346+
expected = np.full_like(base_np, value_np)
347+
np.testing.assert_array_equal(res[0], expected)
348+
349+
271350
if __name__ == "__main__":
272351
unittest.main()

test/legacy_test/test_full_op.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import unittest
1616

1717
import numpy as np
18+
from utils import dygraph_guard
1819

1920
import paddle
2021
from paddle import base
@@ -444,6 +445,33 @@ def test_shape_tensor_list_dtype():
444445
self.assertRaises(TypeError, test_shape_tensor_list_dtype)
445446
paddle.disable_static()
446447

448+
def test_fill_value_errors(self):
449+
with dygraph_guard():
450+
# The fill_value must be one of [int, float, bool, complex, np.number, Tensor].
451+
self.assertRaises(
452+
TypeError,
453+
paddle.full,
454+
shape=[1],
455+
dtype="float32",
456+
fill_value=np.array([1.0], dtype=np.float32),
457+
)
458+
459+
self.assertRaises(
460+
TypeError,
461+
paddle.full,
462+
shape=[1],
463+
dtype="float32",
464+
fill_value=[1.0],
465+
)
466+
467+
self.assertRaises(
468+
TypeError,
469+
paddle.full,
470+
shape=[1],
471+
dtype="bool",
472+
fill_value=np.bool_(True),
473+
)
474+
447475

448476
if __name__ == "__main__":
449477
unittest.main()

0 commit comments

Comments
 (0)