Skip to content

Commit 8b1813c

Browse files
committed
Merge commit 'refs/pull/59798/head' of github.com:PaddlePaddle/Paddle into sink
2 parents 96fd50e + 3f21040 commit 8b1813c

4 files changed

Lines changed: 29 additions & 33 deletions

File tree

paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"sqrt",
3535
"squeeze",
3636
"stack",
37+
"unsqueeze",
3738
]
3839

3940
# come into effect in generated file op_decomp.cc
@@ -52,7 +53,9 @@
5253
"softmax",
5354
"sqrt",
5455
"squeeze",
56+
"unsqueeze",
5557
"stack",
58+
"unsqueeze",
5659
]
5760

5861

paddle/fluid/primitive/composite/composite.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -246,14 +246,8 @@ Tensor softmax_decomp(const Tensor& x, const int& axis) {
246246

247247
template <typename T>
248248
Tensor stack_decomp(const std::vector<Tensor>& x, const int& axis) {
249-
auto tensor_dims = x[0].dims();
250-
int tmp_axis = axis;
251-
if (tmp_axis < 0) {
252-
tmp_axis += tensor_dims.size() + 1;
253-
}
254-
255-
auto out_shape = phi::vectorize(tensor_dims);
256-
out_shape.insert(out_shape.begin() + tmp_axis, 1);
249+
std::vector<int64_t> axis_tmp = {axis};
250+
auto out_shape = get_expand_dims(x[0], axis_tmp);
257251

258252
std::vector<Tensor> concat_x;
259253
for (size_t i = 0; i < x.size(); ++i) {
@@ -318,6 +312,15 @@ std::tuple<Tensor, Tensor> squeeze_decomp(const Tensor& x,
318312
return std::make_tuple(out, xshape);
319313
}
320314

315+
template <typename T>
316+
std::tuple<Tensor, Tensor> unsqueeze_decomp(const Tensor& x,
317+
const IntArray& axis) {
318+
auto out_shape = get_expand_dims(x, axis.GetData());
319+
Tensor out = reshape<T>(x, out_shape);
320+
Tensor xshape;
321+
return std::make_tuple(out, xshape);
322+
}
323+
321324
template <typename T>
322325
Tensor add_n_decomp(const std::vector<Tensor>& x) {
323326
Tensor res = x[0];

paddle/fluid/primitive/utils/utils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ static bool is_half_dtype(const DataType& dtype) {
3939
}
4040
}
4141

42+
// This function expands the dimension of origin Tensor based on the value of
43+
// axis
44+
static std::vector<int64_t> get_expand_dims(const Tensor& origin,
45+
const std::vector<int64_t>& axis) {
46+
std::vector<int64_t> result(origin.shape());
47+
for (size_t i = 0; i < axis.size(); ++i) {
48+
int64_t offset = axis[i];
49+
if (offset < 0) {
50+
offset += result.size() + 1;
51+
}
52+
result.insert(result.begin() + offset, 1);
53+
}
54+
return result;
55+
}
56+
4257
// This fucction compute unsqueeze dims for reshape to replace unsqueeze.
4358
static std::vector<int64_t> get_unsqueeze_dims(
4459
const Tensor& origin, const std::vector<int64_t>& axis) {

python/paddle/decomposition/rules.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
16-
from .primitives import * # noqa: F403
17-
from .register import register_decomp
18-
19-
20-
@register_decomp('pd_op.unsqueeze')
21-
def unsqueeze(x, axis):
22-
"""define composite rule of op unsqueeze"""
23-
"""using reshape to implement unsqueeze op"""
24-
axis = axis.get_defining_op().attrs()["value"]
25-
x_shape = list(x.shape)
26-
axis_list = list(axis)
27-
for i in axis_list:
28-
if i < 0:
29-
i += len(x_shape) + 1
30-
x_shape = (
31-
x_shape[:i]
32-
+ [
33-
1,
34-
]
35-
+ x_shape[i:]
36-
)
37-
out = reshape(x, x_shape)
38-
return [out, None]

0 commit comments

Comments
 (0)