Skip to content

Commit e37c894

Browse files
committed
[API-Compat] Add paddle.compat.Unfold that supports tensor inputs.
1 parent c466f94 commit e37c894

6 files changed

Lines changed: 291 additions & 0 deletions

File tree

python/paddle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
_pir_ops as _pir_ops,
123123
_typing as _typing,
124124
callbacks as callbacks,
125+
compat as compat,
125126
fft as fft,
126127
hub as hub,
127128
linalg as linalg,

python/paddle/compat.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .tensor.compat import (
16+
Unfold,
17+
)
18+
19+
__all__ = ['Unfold']

python/paddle/nn/layer/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545

4646
_T_Padding = TypeVar("_T_Padding", Tensor, Sequence[int])
4747

48+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
49+
4850
__all__ = []
4951

5052

@@ -1908,6 +1910,11 @@ class Unfold(Layer):
19081910
strides: Size2
19091911
name: str | None
19101912

1913+
@ForbidKeywordsDecorator(
1914+
illegal_keys={"kernel_size", "dilation", "padding", "stride"},
1915+
func_name="paddle.nn.Unfold",
1916+
correct_name="paddle.compat.Unfold",
1917+
)
19111918
def __init__(
19121919
self,
19131920
kernel_sizes: Size2,

python/paddle/tensor/compat.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
19+
import paddle
20+
from paddle import nn
21+
22+
if TYPE_CHECKING:
23+
24+
from paddle import Tensor
25+
from paddle._typing import (
26+
Size2,
27+
)
28+
29+
from paddle.framework import in_dynamic_mode
30+
from paddle.utils.decorator_utils import ForbidKeywordsDecorator
31+
32+
__all__ = []
33+
34+
35+
class Unfold(nn.Unfold):
    """
    A compatible version of paddle.nn.Unfold:

    - The keyword arguments are in non-plural forms, example: `kernel_size` instead of `kernel_sizes`.
    - `padding` restricts the size of the input to be 1 (int) or 2; Size4 is not allowed. To use a more
      input-flexible version of Unfold, please refer to `paddle.nn.Unfold`.
    - All the input parameters allow `Tensor` or `pir.Value` as inputs, and will be converted to list
      (dynamic graph mode only).

    Other aspects are the same. See ``paddle.nn.Unfold`` for more details.

    Parameters:
        kernel_size(int|list|tuple|Tensor): The size of convolution kernel, should be [k_h, k_w]
            or an integer k treated as [k, k].
        dilation(int|list|tuple|Tensor, optional): The dilations of convolution kernel, should be
            [dilation_h, dilation_w], or an integer dilation treated as [dilation, dilation].
            By default, it will be [1, 1].
        padding(int|list|tuple|Tensor, optional): The paddings of each dimension, should be
            a single integer or [padding_h, padding_w]. If [padding_h, padding_w] was given, it will
            be expanded to [padding_h, padding_w, padding_h, padding_w]. If an integer padding was
            given, [padding, padding, padding, padding] will be used. By default, paddings will be 0.
        stride(int|list|tuple|Tensor, optional): The strides, should be [stride_h, stride_w]
            or an integer stride treated as [stride, stride]. By default, strides will be [1, 1].

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> x = paddle.randn((100, 3, 224, 224))
            >>> unfold = paddle.compat.Unfold(kernel_size=[3, 3])
            >>> result = unfold(x)
            >>> print(result.shape)
            [100, 27, 49284]

    """

    kernel_sizes: Size2
    dilations: Size2
    paddings: Size2
    strides: Size2

    @ForbidKeywordsDecorator(
        illegal_keys={"kernel_sizes", "dilations", "paddings", "strides"},
        func_name="paddle.compat.Unfold",
        correct_name="paddle.nn.Unfold",
    )
    def __init__(
        self,
        kernel_size: Size2,
        dilation: Size2 = 1,
        padding: Size2 = 0,
        stride: Size2 = 1,
    ) -> None:
        # The parent class stores these under the plural attribute names
        # (kernel_sizes, dilations, paddings, strides), which forward() reads.
        super().__init__(kernel_size, dilation, padding, stride)

    def forward(self, input: Tensor) -> Tensor:
        """Extract sliding local blocks from a 4-D `input` via nn.functional.unfold.

        Raises:
            TypeError: if a hyper-parameter is a Tensor/pir.Value in static
                graph mode, or is of an unsupported type.
            ValueError: if `padding` resolves to more than 2 elements.
        """

        def to_list_if_necessary(x, size_check=False):
            # Tensor/Value hyper-parameters can only be materialized into a
            # Python list in dynamic graph mode (tolist needs concrete values).
            if isinstance(x, (paddle.pir.Value, paddle.Tensor)):
                if not in_dynamic_mode():
                    raise TypeError(
                        "paddle.compat.Unfold does not allow paddle.Tensor or pir.Value as inputs in static graph mode."
                    )
                res = x.tolist()
            elif isinstance(x, (list, tuple, int)):
                res = x
            else:
                # NOTE: previously this path raised the static-graph message
                # even in dynamic mode, which was misleading for plain bad types.
                raise TypeError(
                    f"paddle.compat.Unfold expects int, list, tuple, Tensor or pir.Value hyper-parameters, got {type(x).__name__}."
                )
            if size_check and isinstance(res, (list, tuple)) and len(res) > 2:
                raise ValueError(
                    f"The `padding` field of paddle.compat.Unfold can only have size 1 or 2, now len={len(res)}. \nDid you mean to use paddle.nn.Unfold() instead?"
                )
            return res

        return nn.functional.unfold(
            input,
            kernel_sizes=to_list_if_necessary(self.kernel_sizes),
            strides=to_list_if_necessary(self.strides),
            paddings=to_list_if_necessary(self.paddings, size_check=True),
            dilations=to_list_if_necessary(self.dilations),
            name=self.name,
        )

python/paddle/utils/decorator_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,30 @@ def process(
131131
args = ()
132132

133133
return args, kwargs
134+
135+
136+
class ForbidKeywordsDecorator(DecoratorBase):
    """A decorator that hints users to use the correct `compat` functions when
    erroneous keyword arguments are detected.

    Args:
        illegal_keys: keyword names the wrapped callable must not receive.
        func_name: display name of the wrapped callable, e.g. "paddle.nn.Unfold".
        correct_name: display name of the callable users probably meant.
    """

    def __init__(
        self, illegal_keys: set[str], func_name: str, correct_name: str
    ) -> None:
        super().__init__()
        self.illegal_keys = illegal_keys
        self.func_name = func_name
        self.correct_name = correct_name

    def process(
        self, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[tuple[Any, ...], dict[str, Any]]:
        # Iterate `kwargs` (insertion-ordered) rather than `illegal_keys`:
        # set iteration order is hash-randomized across interpreter runs, so
        # building the message from the set would make the listed key order —
        # and any test that matches the message exactly — non-deterministic.
        found_keys = [key for key in kwargs if key in self.illegal_keys]

        if found_keys:
            keys_str = ", ".join(f"'{key}'" for key in found_keys)
            plural = "s" if len(found_keys) > 1 else ""

            raise TypeError(
                f"{self.func_name}() received unexpected keyword argument{plural} {keys_str}. "
                f"\nDid you mean to use {self.correct_name}() instead?"
            )
        return args, kwargs
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
19+
import paddle
20+
21+
22+
class TestCompatUnfold(unittest.TestCase):
    """Checks that paddle.compat.Unfold matches paddle.nn.Unfold numerically,
    and that misused keywords / static-graph tensor inputs raise helpful errors."""

    def _compare_with_origin(
        self, input_tensor, kernel_size, dilation, padding, stride
    ):
        """Assert compat.Unfold == nn.Unfold for both plain and Tensor hyper-parameters."""
        unfold_compat = paddle.compat.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride,
        )
        unfold_origin = paddle.nn.Unfold(
            kernel_sizes=kernel_size,
            dilations=dilation,
            paddings=padding,
            strides=stride,
        )
        expected_res = unfold_origin(input_tensor).numpy()
        np.testing.assert_allclose(
            unfold_compat(input_tensor).numpy(), expected_res
        )

        # Re-run with every non-int hyper-parameter wrapped in a Tensor: the
        # compat layer should convert them back to lists internally.
        to_tensor = lambda x: x if isinstance(x, int) else paddle.to_tensor(x)
        unfold_compat = paddle.compat.Unfold(
            kernel_size=to_tensor(kernel_size),
            dilation=to_tensor(dilation),
            padding=to_tensor(padding),
            stride=to_tensor(stride),
        )
        np.testing.assert_allclose(
            unfold_compat(input_tensor).numpy(), expected_res
        )

    def test_compare_with_origin(self):
        # Mixed list/tuple hyper-parameters, non-trivial padding.
        input_tensor = paddle.arange(360, dtype=paddle.float32).reshape(
            (3, 4, 5, 6)
        )
        self._compare_with_origin(input_tensor, [3, 3], [1, 1], (1, 2), [1, 1])

        # Dilated kernel with scalar padding.
        input_tensor = paddle.ones((5, 10, 13, 13), dtype=paddle.float64)
        self._compare_with_origin(input_tensor, [4, 4], [2, 2], 1, (1, 2))

        # All-scalar hyper-parameters.
        input_tensor = paddle.ones((12, 4, 10, 10), dtype=paddle.float64)
        self._compare_with_origin(input_tensor, 3, 2, 1, (1, 1))

    def test_error_handling(self):
        """Passing paddle.nn.Unfold kwargs to paddle.compat.Unfold (and vice
        versa), oversized `padding`, or Tensor hyper-parameters in static
        graph mode must raise with the expected messages."""
        msg_gt_1 = "paddle.nn.Unfold() received unexpected keyword arguments 'dilation', 'stride'. \nDid you mean to use paddle.compat.Unfold() instead?"
        msg_gt_2 = "paddle.compat.Unfold() received unexpected keyword argument 'paddings'. \nDid you mean to use paddle.nn.Unfold() instead?"
        msg_gt_3 = "The `padding` field of paddle.compat.Unfold can only have size 1 or 2, now len=4. \nDid you mean to use paddle.nn.Unfold() instead?"
        msg_gt_4 = "paddle.compat.Unfold does not allow paddle.Tensor or pir.Value as inputs in static graph mode."

        with self.assertRaises(TypeError) as cm:
            paddle.nn.Unfold([3, 3], dilation=[2, 2], stride=[1, 1])
        self.assertEqual(str(cm.exception), msg_gt_1)

        with self.assertRaises(TypeError) as cm:
            paddle.compat.Unfold([3, 3], paddings=[2, 1])
        self.assertEqual(str(cm.exception), msg_gt_2)

        with self.assertRaises(ValueError) as cm:
            unfold = paddle.compat.Unfold([3, 3], padding=[2, 1, 2, 2])
            unfold(paddle.ones([2, 2, 5, 5]))
        self.assertEqual(str(cm.exception), msg_gt_3)

        # try/finally guard: the original enabled static mode inside the
        # assertRaises block and only disabled it afterwards, so the expected
        # exception left static mode enabled, leaking state into other tests.
        paddle.enable_static()
        try:
            with self.assertRaises(TypeError) as cm:
                input_data = np.random.randn(2, 4, 8, 8).astype(np.float32)
                with paddle.static.program_guard(paddle.static.Program()):
                    x = paddle.static.data(
                        name='x', shape=[None, None, 8, 8], dtype='float32'
                    )
                    place = (
                        paddle.CUDAPlace(0)
                        if paddle.is_compiled_with_cuda()
                        else paddle.CPUPlace()
                    )
                    unfold_pass = paddle.compat.Unfold(
                        kernel_size=paddle.to_tensor([3, 3]),
                        padding=paddle.to_tensor([1, 2]),
                    )
                    unfold_pass(x)  # raises: Tensor hyper-params in static mode
                    exe = paddle.static.Executor(place)
                    exe.run(feed={'x': input_data})
            self.assertEqual(str(cm.exception), msg_gt_4)
        finally:
            paddle.disable_static()
120+
# Allow running this test module directly: `python <this_file>.py`.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)