From abd069cb28b7cfe90709253cf3dbfce38428ce5f Mon Sep 17 00:00:00 2001 From: GGBond8488 <857631483@qq.com> Date: Thu, 28 Mar 2024 10:31:13 +0000 Subject: [PATCH 1/5] fix flatten_ stride calcaulate error & add paddle.flatten_ --- paddle/phi/kernels/flatten_kernel.h | 3 ++- paddle/phi/kernels/stride/flatten_kernel.cc | 11 +++++++---- python/paddle/__init__.py | 2 ++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/flatten_kernel.h b/paddle/phi/kernels/flatten_kernel.h index b941a1fbb96910..ac53c5b82c6cb9 100644 --- a/paddle/phi/kernels/flatten_kernel.h +++ b/paddle/phi/kernels/flatten_kernel.h @@ -40,7 +40,8 @@ void FlattenInferStridedKernel(const Context& dev_ctx, const DenseTensor& x, int start_axis, int stop_axis, - DenseTensor* out); + DenseTensor* out, + DenseTensor* xshape); template void FlattenStridedKernel(const Context& dev_ctx, diff --git a/paddle/phi/kernels/stride/flatten_kernel.cc b/paddle/phi/kernels/stride/flatten_kernel.cc index f2240aa9bff877..717e6ccdaec0bb 100644 --- a/paddle/phi/kernels/stride/flatten_kernel.cc +++ b/paddle/phi/kernels/stride/flatten_kernel.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/flatten_kernel.h" +#include "glog/logging.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/reshape_kernel.h" @@ -23,13 +24,14 @@ void FlattenInferStridedKernel(const Context& dev_ctx, const DenseTensor& x, int start_axis UNUSED, int stop_axis UNUSED, - DenseTensor* out) { + DenseTensor* out, + DenseTensor* xshape) { ReshapeStridedKernel( dev_ctx, x, IntArray(common::vectorize(out->dims())), out, - nullptr); + xshape); } template @@ -38,8 +40,9 @@ void FlattenStridedKernel(const Context& dev_ctx, int start_axis, int stop_axis, DenseTensor* out, - DenseTensor* xshape UNUSED) { - FlattenInferStridedKernel(dev_ctx, x, start_axis, stop_axis, out); + DenseTensor* xshape) { + FlattenInferStridedKernel( + dev_ctx, x, start_axis, stop_axis, out, xshape); } } // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 05cff990c18379..ab4d932278093f 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -278,6 +278,7 @@ expand, expand_as, flatten, + flatten_, flip, flip as reverse, gather, @@ -881,6 +882,7 @@ 'set_printoptions', 'std', 'flatten', + 'flatten_', 'asin', 'multiply', 'multiply_', From ee313354dd8e1d29d83e7c192812c25ec79d3197 Mon Sep 17 00:00:00 2001 From: GGBond8488 <857631483@qq.com> Date: Thu, 28 Mar 2024 10:34:25 +0000 Subject: [PATCH 2/5] remove log --- paddle/phi/kernels/stride/flatten_kernel.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/phi/kernels/stride/flatten_kernel.cc b/paddle/phi/kernels/stride/flatten_kernel.cc index 717e6ccdaec0bb..074b4bbf613233 100644 --- a/paddle/phi/kernels/stride/flatten_kernel.cc +++ b/paddle/phi/kernels/stride/flatten_kernel.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/flatten_kernel.h" -#include "glog/logging.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/reshape_kernel.h" From 22367f400058516f8995bc66633c13f47093d0fc Mon Sep 17 00:00:00 2001 From: GGBond8488 <857631483@qq.com> Date: Fri, 29 Mar 2024 07:01:03 +0000 Subject: [PATCH 3/5] fix stride error --- paddle/phi/kernels/stride/reshape_kernel.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/stride/reshape_kernel.cc b/paddle/phi/kernels/stride/reshape_kernel.cc index 02d36d825c36aa..6db444662fcff0 100644 --- a/paddle/phi/kernels/stride/reshape_kernel.cc +++ b/paddle/phi/kernels/stride/reshape_kernel.cc @@ -14,6 +14,7 @@ #include "paddle/phi/kernels/reshape_kernel.h" #include +#include "glog/logging.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/strided_reshape_utils.h" @@ -31,7 +32,6 @@ void ReshapeStridedKernel(const Context& dev_ctx, size_t x_offset = x.offset(); if (xshape) { x_dims = DDim(xshape->dims().Get() + 1, xshape->dims().size() - 1); - x_stride = xshape->strides(); } MetaTensor meta_out(out); InferMetaFromVecValue(x, shape.GetData(), &meta_out); From 4bf63028bf617f96399c7b909d7da50780ff7d2d Mon Sep 17 00:00:00 2001 From: GGBond8488 <857631483@qq.com> Date: Fri, 29 Mar 2024 08:09:56 +0000 Subject: [PATCH 4/5] rm glog --- paddle/phi/kernels/stride/reshape_kernel.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/phi/kernels/stride/reshape_kernel.cc b/paddle/phi/kernels/stride/reshape_kernel.cc index 6db444662fcff0..7cfb20430beb59 100644 --- a/paddle/phi/kernels/stride/reshape_kernel.cc +++ b/paddle/phi/kernels/stride/reshape_kernel.cc @@ -14,7 +14,6 @@ #include "paddle/phi/kernels/reshape_kernel.h" #include -#include "glog/logging.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/strided_reshape_utils.h" From f3f09c11b2d6f5e5f820d866c24ea03f99cda04e Mon Sep 17 00:00:00 2001 From: GGBond8488 <857631483@qq.com> Date: Sun, 7 Apr 2024 11:12:49 +0000 Subject: [PATCH 5/5] add stride error testcase --- test/legacy_test/test_inplace.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index 83652bfa2f2884..7655afb82cc042 100755 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -438,6 +438,18 @@ def inplace_api_processing(self, var): return var.flatten_() +class TestDygraphInplaceFlattenStride(TestDygraphInplace): + def init_data(self): + self.input_var_numpy = np.random.randn(2, 3, 2) + self.dtype = "float32" + + def non_inplace_api_processing(self, var): + return var.flatten(0, 1) + + def inplace_api_processing(self, var): + return var.flatten_(0, 1) + + class TestDygraphInplaceScatter(TestDygraphInplace): def init_data(self): self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]])