Skip to content

Commit 957bf79

Browse files
authored
[PHI] Migrate reshape kernel (PaddlePaddle#48749)
* reshape * typo * remove header
1 parent 0b53147 commit 957bf79

File tree

2 files changed

+179
-22
lines changed

2 files changed

+179
-22
lines changed

paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
9999
case ReshapeKernelOpName::reshape:
100100
InferShapeReshapeOp(ctx, x_dims, out_dims);
101101
break;
102-
case ReshapeKernelOpName::reshape2:
103-
InferShapeReshape2Op(ctx, x_dims, out_dims);
104-
break;
105102
case ReshapeKernelOpName::squeeze:
106103
InferShapeSqueezeOp(ctx, x_dims, out_dims);
107104
break;
@@ -127,17 +124,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
127124
ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims);
128125
}
129126

130-
void InferShapeReshape2Op(const framework::ExecutionContext& ctx,
131-
framework::DDim& x_dims, // NOLINT
132-
framework::DDim& out_dims) const { // NOLINT
133-
auto* out = ctx.Output<phi::DenseTensor>("Out");
134-
auto* xshape = ctx.Output<phi::DenseTensor>("XShape");
135-
auto xshape_dims = xshape->dims();
136-
x_dims = phi::slice_ddim(xshape_dims, 1, xshape_dims.size());
137-
out_dims = out->dims();
138-
ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims);
139-
}
140-
141127
// in reshape1/2 ops "ShapeTensor" has highest priority and "Shape" has
142128
// second highest priority
143129
void ChangeReshapeOutDimsIfNeeded(
@@ -400,14 +386,6 @@ REGISTER_OP_KERNEL(
400386
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
401387
ReshapeKernelOpName::reshape>);
402388

403-
REGISTER_OP_KERNEL(
404-
reshape2,
405-
MKLDNN,
406-
paddle::platform::CPUPlace,
407-
ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::reshape2>,
408-
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
409-
ReshapeKernelOpName::reshape2>);
410-
411389
REGISTER_OP_KERNEL(
412390
reshape2_grad,
413391
MKLDNN,
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/phi/backends/onednn/onednn_reuse.h"
13+
#include "paddle/phi/core/kernel_registry.h"
14+
15+
namespace phi {
16+
17+
// Resolves a user-provided reshape `shape` against the input dims `in_dims`.
// Semantics (matching the reshape op contract):
//   * at most one entry may be -1: its extent is inferred from the remaining
//     capacity of the input,
//   * an entry of 0 copies the corresponding input dimension,
//   * every other entry must be strictly positive and is taken verbatim.
// Raises InvalidArgument on any inconsistent specification and returns the
// fully resolved output DDim.
static DDim ValidateShape(const std::vector<int64_t>& shape,
                          const DDim& in_dims) {
  const int64_t in_size = product(in_dims);
  auto in_dims_vec = vectorize(in_dims);
  // If any input dim is non-positive (e.g. -1 at compile time) the total
  // input size is unknown, so the capacity checks below must be skipped.
  bool all_positive = std::all_of(in_dims_vec.cbegin(),
                                  in_dims_vec.cend(),
                                  [](int64_t i) { return i > 0; });
  // only one dimension can be set to -1, whose size will be automatically
  // inferred
  const int64_t unk_dim_val = -1;
  const int64_t copy_dim_val = 0;

  std::vector<int64_t> output_shape(shape.size(), 0);
  int64_t capacity = 1;
  int unk_dim_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == unk_dim_val) {
      PADDLE_ENFORCE_EQ(
          unk_dim_idx,
          -1,
          errors::InvalidArgument(
              "Only one dimension value of 'shape' in ReshapeOp can "
              "be -1. But received shape = [%s], shape[%d] is also -1.",
              make_ddim(shape),
              i));
      unk_dim_idx = static_cast<int>(i);
    } else if (shape[i] == copy_dim_val) {
      PADDLE_ENFORCE_LT(
          static_cast<int>(i),
          in_dims.size(),
          errors::InvalidArgument(
              "The index of 0 in `shape` must be less than "
              "the input tensor X's dimensions. "
              "But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
              "X's dimensions = %d.",
              make_ddim(shape),
              i,
              in_dims,
              in_dims.size()));
    } else {
      PADDLE_ENFORCE_GT(
          shape[i],
          0,
          errors::InvalidArgument(
              "Each dimension value of 'shape' in ReshapeOp must not "
              "be negative except one unknown dimension. "
              "But received shape = [%s], shape[%d] = %d.",
              make_ddim(shape),
              i,
              shape[i]));
    }

    // 0 means "copy this dim from the input"; otherwise take it verbatim.
    // Note: a -1 entry is folded into `capacity`, making it negative.
    capacity *= (shape[i] ? shape[i] : in_dims[i]);
    output_shape[i] = (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
  }

  if (unk_dim_idx != -1) {
    if (all_positive) {
      // in_size < 0 and is un-determinate in compile time, skip the check,
      // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
      // capacity = -24, in_size = -8, output_shape[0] = 0
      // the following check will fail.
      output_shape[unk_dim_idx] = -in_size / capacity;
      PADDLE_ENFORCE_EQ(
          output_shape[unk_dim_idx] * capacity,
          -in_size,
          errors::InvalidArgument(
              "The 'shape' attribute in ReshapeOp is invalid. "
              "The input tensor X'size must be divisible by known "
              "capacity of 'shape'. "
              "But received X's shape = [%s], X's size = %d, "
              "'shape' is [%s], known capacity of 'shape' is %d.",
              in_dims,
              in_size,
              make_ddim(shape),
              capacity));
    } else {
      // Input size unknown at this point: leave the dim symbolic (-1).
      output_shape[unk_dim_idx] = -1;
    }
  } else {
    if (all_positive) {
      PADDLE_ENFORCE_EQ(
          capacity,
          in_size,
          errors::InvalidArgument(
              "The 'shape' in ReshapeOp is invalid. "
              "The input tensor X'size must be equal to the capacity of "
              "'shape'. "
              "But received X's shape = [%s], X's size = %d, 'shape' is "
              "[%s], the capacity of 'shape' is %d.",
              in_dims,
              in_size,
              make_ddim(shape),
              capacity));
    }
  }
  return make_ddim(output_shape);
}
115+
116+
// Performs the oneDNN reshape: the input is first reordered into a plain
// (contiguous) memory format, after which the new dims can be applied by
// reshaping the memory descriptor instead of moving data again.
//
// x_dims is passed separately from x so callers can supply the effective
// input dims (e.g. recovered from XShape) rather than x.dims().
template <typename T, typename Context>
void ExecuteReshape(const Context& dev_ctx,
                    const DenseTensor& x,
                    const IntArray& shape,
                    const DDim& x_dims,
                    DenseTensor* out) {
  // Resolve -1 / 0 entries of `shape` into concrete output dims.
  auto out_dims = ValidateShape(shape.GetData(), x_dims);
  auto x_vec_dims = vectorize(x_dims);

  funcs::ReorderOneDNNHandler reorder_handler(
      x_vec_dims,
      x.dtype(),
      funcs::ToOneDNNDataType(x.dtype()),
      dev_ctx.GetEngine());

  auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
      x.mem_desc(), funcs::to_void_cast(x.data<T>()));
  out->Resize(x_dims);  // to match x numel, format is changed later
  // reorder is done into a plain tag to allow usage with blocked formats
  auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
      out, funcs::GetPlainOneDNNFormat(x_dims.size()), dev_ctx.GetPlace());
  auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
                                                  reorder_src_memory_p);

  auto& astream = OneDNNContext::tls().get_stream();
  reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
  // Wait for the reorder to finish before touching the destination memory.
  astream.wait();

  // Data is now contiguous: apply the final dims and reshape the memory
  // descriptor to match them.
  out->Resize(out_dims);
  out->set_mem_desc(
      reorder_dst_memory_p->get_desc().reshape(vectorize(out_dims)));
}
149+
150+
template <typename T, typename Context>
151+
void ReshapeKernel(const Context& dev_ctx,
152+
const DenseTensor& x,
153+
const IntArray& shape,
154+
DenseTensor* out) {
155+
auto x_dims = x.dims();
156+
ExecuteReshape<T, Context>(dev_ctx, x, shape, x_dims, out);
157+
}
158+
159+
template <typename T, typename Context>
160+
void ReshapeWithXShape(const Context& dev_ctx,
161+
const DenseTensor& x,
162+
const IntArray& shape,
163+
DenseTensor* out,
164+
DenseTensor* xshape) {
165+
auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
166+
ExecuteReshape<T, Context>(dev_ctx, x, shape, x_dims, out);
167+
}
168+
169+
} // namespace phi
170+
171+
// Register the oneDNN reshape kernels for float and bfloat16 — the data
// types handled by the reorder path above.
PD_REGISTER_KERNEL(
    reshape, OneDNN, ONEDNN, phi::ReshapeKernel, float, phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(reshape_with_xshape,
                   OneDNN,
                   ONEDNN,
                   phi::ReshapeWithXShape,
                   float,
                   phi::dtype::bfloat16) {}

0 commit comments

Comments
 (0)