Skip to content

Commit 4137cb0

Browse files
authored
Merge pull request #3949 from kuke/reshape_op_dev
Add reshape operator
2 parents 104ed75 + 5915138 commit 4137cb0

6 files changed

Lines changed: 208 additions & 0 deletions

File tree

paddle/operators/reshape_op.cc

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
2+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License. */
15+
16+
#include "paddle/operators/reshape_op.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
class ReshapeOp : public framework::OperatorWithKernel {
22+
public:
23+
ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
24+
const framework::VariableNameMap &outputs,
25+
const framework::AttributeMap &attrs)
26+
: OperatorWithKernel(type, inputs, outputs, attrs) {}
27+
28+
protected:
29+
void InferShape(const framework::InferShapeContext &ctx) const override {
30+
// input check
31+
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null");
32+
auto shape = ctx.Attr<std::vector<int>>("shape");
33+
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
34+
for (auto dim : shape) {
35+
PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive.");
36+
}
37+
// capacity check
38+
int64_t capacity =
39+
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
40+
auto *in = ctx.Input<framework::Tensor>("X");
41+
int64_t in_size = framework::product(in->dims());
42+
PADDLE_ENFORCE_EQ(capacity, in_size,
43+
"The size of Input(X) mismatches with Attr(shape).");
44+
// resize output
45+
std::vector<int64_t> shape_int64(shape.size(), 0);
46+
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
47+
[](int a) { return static_cast<int64_t>(a); });
48+
auto out_dims = framework::make_ddim(shape_int64);
49+
ctx.Output<framework::Tensor>("Out")->Resize(out_dims);
50+
}
51+
};
52+
53+
// Proto maker for reshape: declares the operator's inputs, outputs,
// attributes and user-facing documentation.
class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReshapeOpMaker(framework::OpProto *proto,
                 framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // One tensor in, one tensor out, plus the target-shape attribute.
    AddInput("X", "The input tensor of reshape operator.");
    AddOutput("Out", "The output tensor of reshape operator.");
    AddAttr<std::vector<int>>("shape", "Target shape of reshape operator.");
    AddComment(R"DOC(Reshape operator

Reshape Input(X) into the shape specified by Attr(shape).

An example:
Given a 2-D tensor X with 2 rows and 2 columns

[[1, 2], [3, 4]]

with target shape = [1, 4], the reshape operator will transform
the tensor X into a 1-D tensor:

[1, 2, 3, 4]

)DOC");
  }
};
78+
79+
class ReshapeGradOp : public framework::OperatorWithKernel {
80+
public:
81+
ReshapeGradOp(const std::string &type,
82+
const framework::VariableNameMap &inputs,
83+
const framework::VariableNameMap &outputs,
84+
const framework::AttributeMap &attrs)
85+
: OperatorWithKernel(type, inputs, outputs, attrs) {}
86+
87+
protected:
88+
void InferShape(const framework::InferShapeContext &ctx) const override {
89+
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) shouldn't be null.");
90+
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
91+
"Input(Out@GRAD) shouldn't be null.");
92+
auto dims = ctx.Input<framework::Tensor>("X")->dims();
93+
auto *d_in = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
94+
d_in->Resize(dims);
95+
}
96+
};
97+
98+
} // namespace operators
99+
} // namespace paddle
100+
namespace ops = paddle::operators;
101+
102+
REGISTER_OP(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, reshape_grad,
103+
ops::ReshapeGradOp);
104+
REGISTER_OP_CPU_KERNEL(reshape,
105+
ops::ReshapeKernel<paddle::platform::CPUPlace, float>);
106+
REGISTER_OP_CPU_KERNEL(
107+
reshape_grad, ops::ReshapeGradKernel<paddle::platform::CPUPlace, float>);

paddle/operators/reshape_op.cu

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/reshape_op.h"
16+
17+
// Register the float GPU kernels for reshape and its gradient.
REGISTER_OP_GPU_KERNEL(
    reshape,
    paddle::operators::ReshapeKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    reshape_grad,
    paddle::operators::ReshapeGradKernel<paddle::platform::GPUPlace, float>);

paddle/operators/reshape_op.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License. */
15+
16+
#pragma once
17+
18+
#include "paddle/framework/eigen.h"
19+
#include "paddle/framework/op_registry.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
template <typename Place, typename T>
25+
class ReshapeKernel : public framework::OpKernel {
26+
public:
27+
void Compute(const framework::ExecutionContext& ctx) const {
28+
auto* out = ctx.Output<framework::Tensor>("Out");
29+
auto* in = ctx.Input<framework::Tensor>("X");
30+
out->mutable_data<T>(ctx.GetPlace());
31+
32+
auto shape = ctx.Attr<std::vector<int>>("shape");
33+
std::vector<int64_t> shape_int64(shape.size(), 0);
34+
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
35+
[](int a) { return static_cast<int64_t>(a); });
36+
auto out_dims = framework::make_ddim(shape_int64);
37+
out->CopyFrom<T>(*in, ctx.GetPlace());
38+
out->Resize(out_dims);
39+
}
40+
};
41+
42+
template <typename Place, typename T>
43+
class ReshapeGradKernel : public framework::OpKernel {
44+
public:
45+
void Compute(const framework::ExecutionContext& ctx) const {
46+
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
47+
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
48+
d_x->mutable_data<T>(ctx.GetPlace());
49+
50+
auto in_dims = d_x->dims();
51+
d_x->CopyFrom<T>(*d_out, ctx.GetPlace());
52+
d_x->Resize(in_dims);
53+
}
54+
};
55+
}
56+
}

paddle/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ USE_CPU_ONLY_OP(concat);
5454
USE_OP(top_k);
5555
USE_OP(squared_l2_distance);
5656
USE_OP(sum);
57+
USE_OP(reshape);
5758

5859
namespace paddle {
5960
namespace framework {

python/paddle/v2/framework/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,4 @@ py_test(test_sum_op SRCS test_sum_op.py)
3535
py_test(mnist SRCS mnist.py)
3636
py_test(test_concat_op SRCS test_concat_op.py)
3737
py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
38+
py_test(test_reshape_op SRCS test_reshape_op.py)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
import numpy as np
3+
from op_test import OpTest
4+
5+
6+
class TestReshapeOp(OpTest):
    """Checks the reshape operator's forward output and the gradient of X."""

    def setUp(self):
        self.op_type = "reshape"
        # A 10x20 float32 tensor flattened to a single dimension of 200.
        data = np.random.random((10, 20)).astype("float32")
        self.inputs = {'X': data}
        self.attrs = {'shape': [10 * 20]}
        self.outputs = {'Out': data.reshape(self.attrs['shape'])}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)