Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,10 @@
'scatter_grad',
'scatter_nd_add_grad',
'slice_grad',
'squeeze_grad',
'tile_grad',
'topk_grad',
'unsqueeze_grad',
]

# whole vjp list of primitive op vjp
Expand Down
45 changes: 45 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,51 @@ void softmax_grad(const Tensor& out,
}
}

template <typename T>
void squeeze_grad(const Tensor& xshape,
const Tensor& out_grad,
const IntArray& axis,
Tensor* x_grad) {
if (x_grad) {
auto x_grad_out = unsqueeze<T>(out_grad, axis);
set_output<T>(x_grad_out, x_grad);
}
}

template <typename T>
void unsqueeze_grad(const Tensor& xshape,
const Tensor& out_grad,
const IntArray& axis,
Tensor* x_grad) {
// for xshape = [10, 2, 5], axis = [3, 1, 1], out_grad.shape = [10, 1, 1, 2,
// 5, 1], it outputs squeeze axis = [5, 2, 1]
const auto& IncreaseAxis = [](std::vector<int64_t>* axis_data,
int64_t pivot) {
for (size_t i = 0; i < axis_data->size(); ++i) {
if ((*axis_data)[i] >= pivot) (*axis_data)[i] += 1;
}
};
const auto& GetRealAxis = [&](const IntArray& axis) -> decltype(auto) {
// for axis = [0, 3, 3], it outputs [0, 3, 3+1], because unsqueeze support
// duplicated axis.
std::vector<int64_t> output_axis;
const int64_t x_rank = xshape.dims().size() - 1;
const std::vector<int64_t> axis_data = axis.GetData();
for (size_t i = 0; i < axis_data.size(); ++i) {
int64_t value = axis_data[i];
if (value < 0) value += (x_rank + i + 1);
IncreaseAxis(&output_axis, value);
output_axis.push_back(value);
}
return output_axis;
};

if (x_grad) {
auto x_grad_out = squeeze<T>(out_grad, GetRealAxis(axis));
set_output<T>(x_grad_out, x_grad);
}
}

template <typename T>
void matmul_grad(const Tensor& x,
const Tensor& y,
Expand Down
10 changes: 0 additions & 10 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4286,16 +4286,6 @@ void SqueezeInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config) {
const auto& x_dims = x.dims();
// Check input tensor dims (<6) Eigen limit.
PADDLE_ENFORCE_LE(x_dims.size(),
6,
phi::errors::InvalidArgument(
"The dimensions of Input(X) "
"should be in the range of [1, 6] (Eigen limit)."
"But received X's dimensions = %d, X's shape = [%s].",
x_dims.size(),
x_dims));

if (!config.is_runtime && axes.FromTensor()) {
// compile time infershape, set all elements to -1.
int output_size = static_cast<int>(x.dims().size() - axes.GetData().size());
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/decomposition/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def decompose(
blacklist (frozenset): The Operators that will be exclude when decomposed into primitives.
whitelist (frozenset): Only the operators in whitelist will be decomposed into primitives.
start_index (int): The start index of decomposed operator in global block, default 0;
end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed.
end_index (int): The end index of decomposed operator in global block, default -1 means all ops will be composed. start_index and end_index follow the principle of left closed and right open, that is [start_index, end_index).

Returns:
dst_vars (list): A list contains all vars which replace origin ones in src_vars.
Expand Down
3 changes: 1 addition & 2 deletions test/deprecated/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@ set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 120)
set_tests_properties(test_model PROPERTIES TIMEOUT 300)
set_tests_properties(test_pretrained_model PROPERTIES TIMEOUT 600)
set_tests_properties(test_squeeze2_op_rename PROPERTIES TIMEOUT 120)

if(APPLE)
set_tests_properties(test_callback_early_stop PROPERTIES TIMEOUT 300)
Expand All @@ -772,11 +773,9 @@ set(TEST_CINN_OPS
test_top_k_v2_op
test_elementwise_mul_op
test_gather_nd_op
test_squeeze2_op
test_elementwise_pow_op
test_transpose_op
test_reshape_op
test_unsqueeze2_op
test_meshgrid_op
test_scale_op
test_scatter_op
Expand Down
102 changes: 102 additions & 0 deletions test/deprecated/legacy_test/test_squeeze2_op_rename.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

from test_attribute_var import UnittestBase

import paddle
from paddle.base.framework import Program, program_guard

paddle.enable_static()


class TestSqueeze2AxesTensor(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor')

def test_static(self):
main_prog = Program()
startup_prog = Program()
with program_guard(main_prog, startup_prog):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]
feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 3, 1, 10]
# axes is a Variable
axes = paddle.assign([0, 2])
out = paddle.squeeze(feat, axes)
out2 = paddle.squeeze(feat, axes)

sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
self.assertTrue("Var[" in str(main_prog))

exe = paddle.static.Executor()
exe.run(startup_prog)
res = exe.run(fetch_list=[feat, out, out2])
self.assertEqual(res[0].shape, (1, 2, 1, 3, 10))
self.assertEqual(res[1].shape, (2, 3, 10))
self.assertEqual(res[2].shape, (2, 3, 10))

paddle.static.save_inference_model(self.save_path, [x], [out], exe)
# Test for Inference Predictor
infer_out = self.infer_prog()
self.assertEqual(infer_out.shape, (2, 3, 10))


class TestSqueeze2AxesTensorList(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor')

def test_static(self):
main_prog = Program()
startup_prog = Program()
with program_guard(main_prog, startup_prog):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]
feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 3, 1, 10]
# axes is a list[Variable]
axes = [
paddle.full([1], 0, dtype='int32'),
paddle.full([1], 2, dtype='int32'),
]
out = paddle.squeeze(feat, axes)
out2 = paddle.squeeze(feat, axes)

sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
self.assertTrue("Vars[" in str(main_prog))

exe = paddle.static.Executor()
exe.run(startup_prog)
res = exe.run(fetch_list=[feat, out, out2])
self.assertEqual(res[0].shape, (1, 2, 1, 3, 10))
self.assertEqual(res[1].shape, (2, 3, 10))
self.assertEqual(res[2].shape, (2, 3, 10))

paddle.static.save_inference_model(self.save_path, [x], [out], exe)
# Test for Inference Predictor
infer_out = self.infer_prog()
self.assertEqual(infer_out.shape, (2, 3, 10))


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/ir/pir/cinn/sub_graphs/test_sub_graph_59.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_ast_prim_cinn(self):
for st, cinn in zip(
paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
):
np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-8)
np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6)


if __name__ == '__main__':
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,8 @@ set(TEST_CINN_OPS
test_group_norm_op
test_tile_op
test_sum_op
test_squeeze2_op
test_unsqueeze2_op
test_elementwise_min_op
test_take_along_axis_op
test_strided_slice_op
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest

import numpy as np
from op_test import OpTest, convert_float_to_uint16
from test_attribute_var import UnittestBase

import paddle
from paddle.base import core
from paddle.base.framework import Program, program_guard
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()
Expand All @@ -31,7 +28,7 @@
class TestSqueezeOp(OpTest):
def setUp(self):
self.op_type = "squeeze2"
self.prim_op_type = "comp"
self.prim_op_type = "prim"
self.python_api = paddle.squeeze
self.public_python_api = paddle.squeeze
self.python_out_sig = [
Expand All @@ -58,7 +55,6 @@ def if_enable_cinn(self):
def test_check_output(self):
self.check_output(
no_check_set=['XShape'],
check_prim=True,
check_pir=True,
check_prim_pir=True,
)
Expand All @@ -67,7 +63,6 @@ def test_check_grad(self):
self.check_grad(
["X"],
"Out",
check_prim=True,
check_pir=True,
check_prim_pir=True,
)
Expand Down Expand Up @@ -191,81 +186,6 @@ def init_dtype(self):
self.dtype = np.uint16


class TestSqueeze2AxesTensor(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor')

def test_static(self):
main_prog = Program()
startup_prog = Program()
with program_guard(main_prog, startup_prog):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]
feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 3, 1, 10]
# axes is a Variable
axes = paddle.assign([0, 2])
out = paddle.squeeze(feat, axes)
out2 = paddle.squeeze(feat, axes)

sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
self.assertTrue("Var[" in str(main_prog))

exe = paddle.static.Executor()
exe.run(startup_prog)
res = exe.run(fetch_list=[feat, out, out2])
self.assertEqual(res[0].shape, (1, 2, 1, 3, 10))
self.assertEqual(res[1].shape, (2, 3, 10))
self.assertEqual(res[2].shape, (2, 3, 10))

paddle.static.save_inference_model(self.save_path, [x], [out], exe)
# Test for Inference Predictor
infer_out = self.infer_prog()
self.assertEqual(infer_out.shape, (2, 3, 10))


class TestSqueeze2AxesTensorList(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, 'squeeze_tensor')

def test_static(self):
main_prog = Program()
startup_prog = Program()
with program_guard(main_prog, startup_prog):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]
feat = paddle.unsqueeze(feat, [0, 2]) # [1, 2, 3, 1, 10]
# axes is a list[Variable]
axes = [
paddle.full([1], 0, dtype='int32'),
paddle.full([1], 2, dtype='int32'),
]
out = paddle.squeeze(feat, axes)
out2 = paddle.squeeze(feat, axes)

sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
self.assertTrue("Vars[" in str(main_prog))

exe = paddle.static.Executor()
exe.run(startup_prog)
res = exe.run(fetch_list=[feat, out, out2])
self.assertEqual(res[0].shape, (1, 2, 1, 3, 10))
self.assertEqual(res[1].shape, (2, 3, 10))
self.assertEqual(res[2].shape, (2, 3, 10))

paddle.static.save_inference_model(self.save_path, [x], [out], exe)
# Test for Inference Predictor
infer_out = self.infer_prog()
self.assertEqual(infer_out.shape, (2, 3, 10))


# test api
class TestSqueezeAPI(unittest.TestCase):
def setUp(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,7 +37,7 @@ def setUp(self):
"Out": self.inputs["X"].reshape(self.new_shape),
"XShape": np.random.random(self.ori_shape).astype("float64"),
}
self.prim_op_type = "comp"
self.prim_op_type = "prim"
self.if_enable_cinn()

def if_enable_cinn(self):
Expand All @@ -46,7 +46,6 @@ def if_enable_cinn(self):
def test_check_output(self):
self.check_output(
no_check_set=["XShape"],
check_prim=True,
check_pir=True,
check_prim_pir=True,
)
Expand All @@ -55,7 +54,6 @@ def test_check_grad(self):
self.check_grad(
["X"],
"Out",
check_prim=True,
check_pir=True,
check_prim_pir=True,
)
Expand Down