-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Convolution operator #4042
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Convolution operator #4042
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
c9d8cb4
Convolution op and forward calculation.
hedaoyuan 40fe0a8
Add backward of convolution.
hedaoyuan 3705de6
Merge branch 'develop' of https://github.com/baidu/Paddle into conv_op
hedaoyuan c671189
Fix test_conv2d_op.py.
hedaoyuan a7c1872
Refine test_conv2d_op.py
hedaoyuan 67db9d3
Refine the GemmConvKernel.
hedaoyuan db33ff1
Refine the GemmConvGradKernel.
hedaoyuan 5860150
Fix Tensor::Slice with dims[0] == 1.
hedaoyuan 8219f20
Refine gemm convolution kernel.
hedaoyuan 14ae805
Merge branch 'develop' of https://github.com/baidu/Paddle into conv_op
hedaoyuan fb46345
Add groups in convolution operator.
hedaoyuan 2340ced
Add groups in convolution GemmConvGradKernel.
hedaoyuan 1dd639e
Bug fix.
hedaoyuan b4ba35c
Add groups test.
hedaoyuan 656f775
Fix the doc.
hedaoyuan 7bf1e76
Merge branch 'develop' of https://github.com/baidu/Paddle into conv_op
hedaoyuan 09c65b6
Follow comments.
hedaoyuan 91afa0d
Some bug fix.
hedaoyuan 5a4138b
Add test with groups=1.
hedaoyuan 64b0b75
Follow comments fix conv2d_op.cc
hedaoyuan f3669ca
Support input_grad = null or filter_grad = null.
hedaoyuan 6c0129a
Refine the GemmConvGrad2DKernel.
hedaoyuan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/operators/gemm_conv2d_op.h" | ||
|
|
||
| namespace paddle { | ||
| namespace operators { | ||
|
|
||
| int outputSize(int input_size, int filter_size, int padding, int stride) { | ||
| int output_size = (input_size - filter_size + 2 * padding) / stride + 1; | ||
| return output_size; | ||
| } | ||
|
|
||
| class Conv2DOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| protected: | ||
| void InferShape(const framework::InferShapeContext &ctx) const override { | ||
| PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), | ||
| "Input(Input) of Conv2DOp should not be null."); | ||
| PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"), | ||
| "Input(Filter) of Conv2DOp should not be null."); | ||
| PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"), | ||
| "Output(Output) of Conv2DOp should not be null."); | ||
|
|
||
| auto in = ctx.Input<Tensor>("Input"); | ||
| auto filter = ctx.Input<Tensor>("Filter"); | ||
| auto out = ctx.Output<framework::LoDTensor>("Output"); | ||
| std::vector<int> strides = Attr<std::vector<int>>("strides"); | ||
| std::vector<int> paddings = Attr<std::vector<int>>("paddings"); | ||
| int groups = Attr<int>("groups"); | ||
| int input_channels = in->dims()[1]; | ||
| int output_channels = filter->dims()[0]; | ||
|
|
||
| PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D."); | ||
| PADDLE_ENFORCE_EQ(filter->dims().size(), 4, | ||
| "Conv2DOp filter should be 4-D."); | ||
| PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups, | ||
| "The number of input channels should be equal to filter " | ||
| "channels * groups."); | ||
| PADDLE_ENFORCE_EQ( | ||
| output_channels % groups, 0, | ||
| "The number of output channels should be divided by groups."); | ||
|
|
||
| auto output_height = | ||
| outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]); | ||
| auto output_width = | ||
| outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]); | ||
| out->Resize( | ||
| {in->dims()[0], filter->dims()[0], output_height, output_width}); | ||
| } | ||
| }; | ||
|
|
||
| class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) | ||
| : OpProtoAndCheckerMaker(proto, op_checker) { | ||
| AddInput( | ||
| "Input", | ||
| "The input tensor of convolution operator. " | ||
| "The format of input tensor is NCHW. Where N is batch size, C is the " | ||
| "number of channels, H and W is the height and width of image."); | ||
| AddInput( | ||
| "Filter", | ||
| "The filter tensor of convolution operator." | ||
| "The format of the filter tensor is MCHW, where M is the number of " | ||
| "output image channels, C is the number of input image channels, " | ||
| "H and W is height and width of filter. " | ||
| "If the groups attribute is greater than 1, C equal the number of " | ||
| "input image channels divided by the groups."); | ||
| AddOutput("Output", | ||
| "The output tensor of convolution operator." | ||
| "The format of output tensor is also NCHW."); | ||
| AddAttr<std::vector<int>>("strides", "strides of convolution operator.") | ||
| .SetDefault({1, 1}); | ||
| AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.") | ||
| .SetDefault({0, 0}); | ||
| AddAttr<int>( | ||
| "groups", | ||
| "group size of convolution operator. " | ||
| "Refer to grouped convolution in Alex Krizhevsky's paper: " | ||
| "when group=2, the first half of the filters are only connected to the " | ||
| "first half of the input channels, and the second half only connected " | ||
| "to the second half.") | ||
| .SetDefault(1); | ||
| AddComment(R"DOC( | ||
| The convolution operation calculates the output based on the input, filter | ||
| and strides, paddings, groups parameters. The size of each dimension of the | ||
| parameters is checked in the infer-shape. | ||
| )DOC"); | ||
| } | ||
| }; | ||
|
|
||
| class Conv2DOpGrad : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
|
||
| protected: | ||
| void InferShape(const framework::InferShapeContext &ctx) const override { | ||
| auto in = ctx.Input<Tensor>("Input"); | ||
| auto filter = ctx.Input<Tensor>("Filter"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add not-null check for Input and Output.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| auto d_in = | ||
| ctx.Output<framework::LoDTensor>(framework::GradVarName("Input")); | ||
| auto d_filter = | ||
| ctx.Output<framework::LoDTensor>(framework::GradVarName("Filter")); | ||
| if (d_in) d_in->Resize(in->dims()); | ||
| if (d_filter) d_filter->Resize(filter->dims()); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace operators | ||
| } // namespace paddle | ||
|
|
||
| namespace ops = paddle::operators; | ||
| REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad, | ||
| ops::Conv2DOpGrad); | ||
|
|
||
| REGISTER_OP_CPU_KERNEL( | ||
| conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>); | ||
| REGISTER_OP_CPU_KERNEL( | ||
| conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>); | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/operators/gemm_conv2d_op.h" | ||
|
|
||
| namespace ops = paddle::operators; | ||
|
|
||
| REGISTER_OP_GPU_KERNEL( | ||
| conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>); | ||
| REGISTER_OP_GPU_KERNEL( | ||
| conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>); |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put
AddAttrbeforeAddComment(R"DOC )DOC").There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.