Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/phi/kernels/flatten_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ void FlattenInferStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out);
DenseTensor* out,
DenseTensor* xshape);

template <typename Context>
void FlattenStridedKernel(const Context& dev_ctx,
Expand Down
10 changes: 6 additions & 4 deletions paddle/phi/kernels/stride/flatten_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ void FlattenInferStridedKernel(const Context& dev_ctx,
const DenseTensor& x,
int start_axis UNUSED,
int stop_axis UNUSED,
DenseTensor* out) {
DenseTensor* out,
DenseTensor* xshape) {
ReshapeStridedKernel<Context>(
dev_ctx,
x,
IntArray(common::vectorize<int64_t>(out->dims())),
out,
nullptr);
xshape);
}

template <typename Context>
Expand All @@ -38,8 +39,9 @@ void FlattenStridedKernel(const Context& dev_ctx,
int start_axis,
int stop_axis,
DenseTensor* out,
DenseTensor* xshape UNUSED) {
FlattenInferStridedKernel<Context>(dev_ctx, x, start_axis, stop_axis, out);
DenseTensor* xshape) {
FlattenInferStridedKernel<Context>(
dev_ctx, x, start_axis, stop_axis, out, xshape);
}

} // namespace phi
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/kernels/stride/reshape_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ void ReshapeStridedKernel(const Context& dev_ctx,
size_t x_offset = x.offset();
if (xshape) {
x_dims = DDim(xshape->dims().Get() + 1, xshape->dims().size() - 1);
x_stride = xshape->strides();
}
MetaTensor meta_out(out);
InferMetaFromVecValue(x, shape.GetData(), &meta_out);
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@
expand,
expand_as,
flatten,
flatten_,
flip,
flip as reverse,
gather,
Expand Down Expand Up @@ -881,6 +882,7 @@
'set_printoptions',
'std',
'flatten',
'flatten_',
'asin',
'multiply',
'multiply_',
Expand Down