Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
013d0a2
add crop layer
wanghaoshuang Jun 16, 2017
90ed200
Refine configure option of crop layer
wanghaoshuang Jun 22, 2017
701827f
Add grad test and python wrapper for crop layer
wanghaoshuang Jul 4, 2017
cbd61c7
fix crop function test
wanghaoshuang Jul 5, 2017
e10040c
add crop layer
wanghaoshuang Jun 16, 2017
d1d70ec
Refine configure option of crop layer
wanghaoshuang Jun 22, 2017
5e6e1f6
Add grad test and python wrapper for crop layer
wanghaoshuang Jul 4, 2017
86bdb2f
fix crop function test
wanghaoshuang Jul 5, 2017
cf86891
fix unittest of crop layer
wanghaoshuang Jul 5, 2017
470af1d
Merge branch 'crop_layer' of https://github.com/wanghaoshuang/Paddle …
wanghaoshuang Jul 5, 2017
acfd2fc
fix cpp format
wanghaoshuang Jul 5, 2017
d378e0a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang Jul 5, 2017
0b788ef
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang Jul 6, 2017
69b1222
fix crop layer python wrapper bug
wanghaoshuang Jul 11, 2017
de5ded6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang Jul 11, 2017
3e7819c
1. Reading image shape from input data instead of image_config
wanghaoshuang Jul 19, 2017
d83bae8
Merge branch 'develop' into crop_layer
wanghaoshuang Jul 19, 2017
60a7889
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
wanghaoshuang Jul 19, 2017
2e58f2c
Merge branch 'crop_layer' of https://github.com/wanghaoshuang/Paddle …
wanghaoshuang Jul 19, 2017
676b76d
fix cmake
wanghaoshuang Jul 19, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ if(WITH_GPU)
add_simple_unittest(MulOpTest)
add_simple_unittest(CosSimOpTest)
add_simple_unittest(RowConvOpTest)
add_simple_unittest(CropOpTest)
endif()

add_simple_unittest(ConvOpTest)
Expand Down
171 changes: 171 additions & 0 deletions paddle/function/CropOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "CropOp.h"
#include "paddle/function/TensorShape.h"
#include "paddle/math/Vector.h"

namespace paddle {

/**
 * \brief CPU implementation of the crop forward pass.
 *
 * Copies a window out of a 4D input indexed as (num, channel, height, width).
 * The window start is taken from elements 1..3 of the "crop_corner" config
 * entry and its extent from elements 1..3 of "crop_shape". Element 0 of both
 * vectors is never read, so the first dimension is not cropped.
 *
 * \param[out] outputs  destination buffer; receives
 *                      num * crop_shape[1] * crop_shape[2] * crop_shape[3]
 *                      values.
 * \param[in]  inputs   source buffer described by inShape.
 * \param[in]  inShape  shape of the input tensor.
 * \param[in]  conf     config providing "crop_corner" and "crop_shape".
 */
template <>
void Crop<DEVICE_TYPE_CPU>(real* outputs,
                           const real* inputs,
                           const TensorShape inShape,
                           const FuncConfig& conf) {
  const std::vector<uint32_t> crop_corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  const std::vector<uint32_t> crop_shape =
      conf.get<std::vector<uint32_t>>("crop_shape");
  // Indices 1..3 are read below; fail loudly instead of reading past the end
  // of a vector that is too short.
  CHECK_GE(crop_corner.size(), 4UL) << "crop_corner must have 4 elements";
  CHECK_GE(crop_shape.size(), 4UL) << "crop_shape must have 4 elements";

  const size_t cCrop = crop_corner[1];
  const size_t hCrop = crop_corner[2];
  const size_t wCrop = crop_corner[3];

  const size_t num = inShape[0];
  const size_t inC = inShape[1];
  const size_t inH = inShape[2];
  const size_t inW = inShape[3];

  const size_t outC = crop_shape[1];
  const size_t outH = crop_shape[2];
  const size_t outW = crop_shape[3];

  // The window must lie inside the input, otherwise the row copies below
  // would read memory outside the input buffer.
  CHECK_LE(cCrop + outC, inC) << "crop exceeds the input channels";
  CHECK_LE(hCrop + outH, inH) << "crop exceeds the input height";
  CHECK_LE(wCrop + outW, inW) << "crop exceeds the input width";

  // Rows are contiguous in both buffers, so each output row is one copy.
  // Offsets are computed in size_t so large tensors cannot overflow int.
  const size_t rowBytes = outW * sizeof(real);
  for (size_t n = 0; n < num; n++) {
    for (size_t c = 0; c < outC; c++) {
      for (size_t h = 0; h < outH; h++) {
        const size_t outoff = ((n * outC + c) * outH + h) * outW;
        const size_t inoff =
            ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop;
        memcpy(outputs + outoff, inputs + inoff, rowBytes);
      }
    }
  }
}

/**
 * \brief CPU implementation of the crop backward pass.
 *
 * Adds the gradient of the cropped tensor (inGrad) onto the matching window
 * of the gradient of the uncropped tensor (outGrad). The window start comes
 * from elements 1..3 of "crop_corner" and its extent from elements 1..3 of
 * "crop_shape"; element 0 of both vectors is never read.
 *
 * \param[in]     inGrad    gradient buffer with the cropped extent
 *                          (num, crop_shape[1], crop_shape[2], crop_shape[3]).
 * \param[in,out] outGrad   gradient buffer described by outShape; values are
 *                          accumulated into it, not overwritten.
 * \param[in]     outShape  shape of the tensor outGrad points to.
 * \param[in]     conf      config providing "crop_corner" and "crop_shape".
 */
template <>
void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
                               real* outGrad,
                               const TensorShape outShape,
                               const FuncConfig& conf) {
  const std::vector<uint32_t> crop_corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  const std::vector<uint32_t> crop_shape =
      conf.get<std::vector<uint32_t>>("crop_shape");
  // Indices 1..3 are read below; fail loudly instead of reading past the end
  // of a vector that is too short.
  CHECK_GE(crop_corner.size(), 4UL) << "crop_corner must have 4 elements";
  CHECK_GE(crop_shape.size(), 4UL) << "crop_shape must have 4 elements";

  const size_t cCrop = crop_corner[1];
  const size_t hCrop = crop_corner[2];
  const size_t wCrop = crop_corner[3];

  const size_t num = outShape[0];
  const size_t outC = outShape[1];
  const size_t outH = outShape[2];
  const size_t outW = outShape[3];

  const size_t inC = crop_shape[1];
  const size_t inH = crop_shape[2];
  const size_t inW = crop_shape[3];

  // The window must lie inside outGrad, otherwise the accumulation below
  // would write outside the buffer.
  CHECK_LE(cCrop + inC, outC) << "crop exceeds the gradient channels";
  CHECK_LE(hCrop + inH, outH) << "crop exceeds the gradient height";
  CHECK_LE(wCrop + inW, outW) << "crop exceeds the gradient width";

  // Offsets are computed in size_t so large tensors cannot overflow int.
  for (size_t n = 0; n < num; n++) {
    for (size_t c = 0; c < inC; c++) {
      for (size_t h = 0; h < inH; h++) {
        const size_t outoff =
            ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop;
        const size_t inoff = ((n * inC + c) * inH + h) * inW;
        // CpuVector wraps a non-const pointer; the source row is only read.
        CpuVector inG(inW, const_cast<real*>(inGrad + inoff));
        CpuVector outG(inW, outGrad + outoff);
        outG += inG;
      }
    }
  }
}

/**
 * \brief Crop the input according to a configured corner and shape.
 *        Input and output are 4D tensors; only the 2nd to 4th dimensions
 *        are cropped.
 *
 * Arguments of this Function:
 * \param conf_    The config handed to init(); Crop<Device> reads the
 *                 "crop_corner" and "crop_shape" entries from it.
 * \param inputs   A 4D tensor, exactly one input.
 * \param outputs  A 4D tensor, the values left after cropping (ASSIGN_TO).
 *
 * For example,
 *  Input(2,2,2,3) = [
 *                     [ [[1,2,3], [3,4,5]],
 *                       [[2,3,5], [1,6,7]] ],
 *                     [ [[4,3,1], [1,8,7]],
 *                       [[3,8,9], [2,3,5]] ]
 *                   ]  # the input shape is (2,2,2,3)
 *
 *  If the corner of dimensions 2..4 is (0,1,1) and the cropped shape of
 *  those dimensions is (2,1,2):
 *  Output(2,2,1,2) = [
 *                      [ [[4,5]],
 *                        [[6,7]] ],
 *                      [ [[8,7]],
 *                        [[3,5]] ]
 *                    ]  # the output shape is (2,2,1,2)
 */
template <DeviceType Device>
class CropFunc : public FunctionBase {
public:
  // Keep a copy of the config; it is forwarded to Crop<Device> on every call.
  void init(const FuncConfig& config) override { conf_ = config; }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    // Exactly one source and one destination, and the destination is
    // overwritten rather than accumulated into.
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);

    const real* source = inputs[0].data<real>();
    real* destination = outputs[0].data<real>();
    Crop<Device>(destination, source, inputs[0].shape(), conf_);
  }

private:
  FuncConfig conf_;
};

/**
 * \brief The backward propagation of the cropping Function.
 *
 * Arguments of this Function:
 * \param conf_    The config handed to init(); same meaning as in CropFunc.
 * \param inputs   The gradient with respect to the output value of CropFunc.
 * \param outputs  The gradient with respect to the input value of CropFunc;
 *                 it must be an ADD_TO argument.
 */
template <DeviceType Device>
class CropGradFunc : public FunctionBase {
public:
  // Keep a copy of the config; it is forwarded to CropGrad<Device> on every
  // call.
  void init(const FuncConfig& config) override { conf_ = config; }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    // Exactly one incoming gradient and one outgoing gradient, and the
    // outgoing one is accumulated into.
    CHECK_EQ(1UL, inputs.size());
    CHECK_EQ(1UL, outputs.size());
    CHECK_EQ(outputs[0].getArgType(), ADD_TO);

    const real* croppedGrad = inputs[0].data<real>();
    real* fullGrad = outputs[0].data<real>();
    CropGrad<Device>(croppedGrad, fullGrad, outputs[0].shape(), conf_);
  }

private:
  FuncConfig conf_;
};

// Register the forward (CropFunc) and backward (CropGradFunc) functions for
// the CPU device type.
// NOTE(review): presumably these become reachable under the names "Crop" and
// "CropGrad" -- confirm against the REGISTER_TYPED_FUNC definition.
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
// The GPU variants are registered only when the build is not CPU-only.
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
#endif

} // namespace paddle
49 changes: 49 additions & 0 deletions paddle/function/CropOp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Function.h"

namespace paddle {

/**
 * \brief This function crops the input according to the specified start
 *        point and shape.
 *
 * \param[out] outputs  buffer the cropped result is saved to.
 * \param[in]  inputs   input data.
 * \param[in]  inShape  the shape of the input tensor.
 * \param[in]  conf     the cropping config.
 */
template <DeviceType Device>
void Crop(real* outputs,
const real* inputs,
const TensorShape inShape,
const FuncConfig& conf);

/**
 * \brief Backward pass of the cropping operation.
 *
 * \param[in]     inGrad    gradient with respect to the cropped output; it
 *                          is only read.
 * \param[in,out] outGrad   gradient with respect to the uncropped input; the
 *                          values of inGrad are added onto it.
 * \param[in]     outShape  the shape of the tensor outGrad points to.
 * \param[in]     conf      the cropping config.
 */
template <DeviceType Device>
void CropGrad(const real* inGrad,
              real* outGrad,
              const TensorShape outShape,
              const FuncConfig& conf);
} // namespace paddle
113 changes: 113 additions & 0 deletions paddle/function/CropOpGpu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
#include "CropOp.h"

namespace paddle {

/**
 * Crop kernel: one thread copies one output element.
 *
 * The flat thread id is read as an (n, c, h, w) position in the output of
 * extent (outC, outH, outW) per sample. The matching input element lies at
 * that position shifted by (cropC, cropH, cropW) inside an input of extent
 * (inC, inH, inW) per sample. Threads with an id of nthreads or more do
 * nothing.
 */
__global__ void KeCrop(real* outputs, const real* inputs,
                       int inC, int inH, int inW,
                       int cropC, int cropH, int cropW,
                       int outC, int outH, int outW, int nthreads) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= nthreads) {
    return;
  }

  // Peel the output coordinates off the flat id, fastest dimension first.
  int rest = tid;
  const int col = rest % outW;
  rest /= outW;
  const int row = rest % outH;
  rest /= outH;
  const int chan = rest % outC;
  const int sample = rest / outC;

  const int src =
      ((sample * inC + chan + cropC) * inH + row + cropH) * inW + cropW + col;
  outputs[tid] = inputs[src];
}

/**
 * \brief GPU implementation of the crop forward pass.
 *
 * Reads the window start from elements 1..3 of "crop_corner" and its extent
 * from elements 1..3 of "crop_shape", then launches KeCrop with one thread
 * per output element.
 *
 * \param[out] outputs  destination buffer handed to the kernel.
 * \param[in]  inputs   source buffer handed to the kernel.
 * \param[in]  inShape  shape of the input tensor.
 * \param[in]  conf     config providing "crop_corner" and "crop_shape".
 */
template <>
void Crop<DEVICE_TYPE_GPU>(real* outputs,
                           const real* inputs,
                           const TensorShape inShape,
                           const FuncConfig& conf) {
  const std::vector<uint32_t> crop_corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  const std::vector<uint32_t> crop_shape =
      conf.get<std::vector<uint32_t>>("crop_shape");
  // Indices 1..3 are read below; fail loudly instead of reading past the end
  // of a vector that is too short.
  CHECK_GE(crop_corner.size(), 4UL) << "crop_corner must have 4 elements";
  CHECK_GE(crop_shape.size(), 4UL) << "crop_shape must have 4 elements";

  int cropC = crop_corner[1];
  int cropH = crop_corner[2];
  int cropW = crop_corner[3];

  int num = inShape[0];
  int inC = inShape[1];
  int inH = inShape[2];
  int inW = inShape[3];

  int outC = crop_shape[1];
  int outH = crop_shape[2];
  int outW = crop_shape[3];

  // The window must lie inside the input, otherwise the kernel would read
  // outside the input buffer.
  CHECK_LE(cropC + outC, inC) << "crop exceeds the input channels";
  CHECK_LE(cropH + outH, inH) << "crop exceeds the input height";
  CHECK_LE(cropW + outW, inW) << "crop exceeds the input width";

  // Widen before multiplying so the product is not formed in int.
  const size_t nth = static_cast<size_t>(num) * outC * outH * outW;
  if (nth == 0) {
    // Nothing to copy. A launch with zero blocks is an invalid configuration
    // in CUDA, so skip it.
    return;
  }
  int blockSize = 1024;
  int gridSize = (nth + blockSize - 1) / blockSize;

  KeCrop<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
    (outputs, inputs, inC, inH, inW, cropC, cropH, cropW,
     outC, outH, outW, nth);
  CHECK_SYNC("Crop");
}

/**
 * Crop backward kernel: one thread handles one element of inGrad.
 *
 * The flat thread id is read as an (n, c, h, w) position in inGrad of extent
 * (inC, inH, inW) per sample. Its value is added onto the outGrad element at
 * that position shifted by (cropC, cropH, cropW), where outGrad has extent
 * (outC, outH, outW) per sample. Threads with an id of nthreads or more do
 * nothing.
 */
__global__ void KeCropDiff(const real* inGrad, real* outGrad,
                           int inC, int inH, int inW,
                           int cropC, int cropH, int cropW,
                           int outC, int outH, int outW, int nthreads) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= nthreads) {
    return;
  }

  // Peel the inGrad coordinates off the flat id, fastest dimension first.
  int rest = tid;
  const int col = rest % inW;
  rest /= inW;
  const int row = rest % inH;
  rest /= inH;
  const int chan = rest % inC;
  const int sample = rest / inC;

  const int dst =
      ((sample * outC + chan + cropC) * outH + row + cropH) * outW + cropW +
      col;
  outGrad[dst] += inGrad[tid];
}

/**
 * \brief GPU implementation of the crop backward pass.
 *
 * Reads the window start from elements 1..3 of "crop_corner" and its extent
 * from elements 1..3 of "crop_shape", then launches KeCropDiff with one
 * thread per element of inGrad.
 *
 * \param[in]     inGrad    gradient buffer with the cropped extent.
 * \param[in,out] outGrad   gradient buffer described by outShape; the kernel
 *                          adds onto it.
 * \param[in]     outShape  shape of the tensor outGrad points to.
 * \param[in]     conf      config providing "crop_corner" and "crop_shape".
 */
template <>
void CropGrad<DEVICE_TYPE_GPU>(const real* inGrad,
                               real* outGrad,
                               const TensorShape outShape,
                               const FuncConfig& conf) {
  const std::vector<uint32_t> crop_corner =
      conf.get<std::vector<uint32_t>>("crop_corner");
  const std::vector<uint32_t> crop_shape =
      conf.get<std::vector<uint32_t>>("crop_shape");
  // Indices 1..3 are read below; fail loudly instead of reading past the end
  // of a vector that is too short.
  CHECK_GE(crop_corner.size(), 4UL) << "crop_corner must have 4 elements";
  CHECK_GE(crop_shape.size(), 4UL) << "crop_shape must have 4 elements";

  int cropC = crop_corner[1];
  int cropH = crop_corner[2];
  int cropW = crop_corner[3];

  int num = outShape[0];
  int outC = outShape[1];
  int outH = outShape[2];
  int outW = outShape[3];

  int inC = crop_shape[1];
  int inH = crop_shape[2];
  int inW = crop_shape[3];

  // The window must lie inside outGrad, otherwise the kernel would write
  // outside the buffer.
  CHECK_LE(cropC + inC, outC) << "crop exceeds the gradient channels";
  CHECK_LE(cropH + inH, outH) << "crop exceeds the gradient height";
  CHECK_LE(cropW + inW, outW) << "crop exceeds the gradient width";

  // Widen before multiplying so the product is not formed in int.
  const size_t nth = static_cast<size_t>(num) * inC * inH * inW;
  if (nth == 0) {
    // Nothing to accumulate. A launch with zero blocks is an invalid
    // configuration in CUDA, so skip it.
    return;
  }
  int blockSize = 1024;
  int gridSize = (nth + blockSize - 1) / blockSize;

  KeCropDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
    (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW,
     outC, outH, outW, nth);
  CHECK_SYNC("CropGrad");
}

} // namespace paddle
49 changes: 49 additions & 0 deletions paddle/function/CropOpTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "FunctionTest.h"

namespace paddle {

// Runs the "Crop" and "CropGrad" functions through CpuGpuFuncCompare over a
// grid of batch sizes, channel counts and image sizes. Every case crops a
// (2, 3, 3) window starting at (1, 1, 1) in the channel/height/width
// dimensions.
// NOTE(review): CpuGpuFuncCompare presumably checks that the CPU and GPU
// results agree -- confirm in FunctionTest.h.
TEST(Crop, real) {
  for (size_t numSamples : {5, 32}) {
    // Each channel count is listed once; a repeated entry would only rerun
    // an identical case.
    for (size_t channels : {5, 32}) {
      for (size_t imgSizeH : {5, 33, 100}) {
        for (size_t imgSizeW : {5, 32, 96}) {
          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
          for (bool test_grad : {false, true}) {
            CpuGpuFuncCompare compare(
                test_grad ? "CropGrad" : "Crop",
                FuncConfig()
                    .set<std::vector<uint32_t>>("crop_corner", {0, 1, 1, 1})
                    .set<std::vector<uint32_t>>("crop_shape", {0, 2, 3, 3}));
            TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
            TensorShape outDims{numSamples, 2, 3, 3};
            // Forward: full tensor in, cropped tensor out (ASSIGN_TO).
            // Backward: cropped gradient in, full gradient out (ADD_TO).
            compare.addInputs(
                BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
            compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT,
                                         test_grad ? inDims : outDims,
                                         test_grad ? ADD_TO : ASSIGN_TO),
                               test_grad ? ADD_TO : ASSIGN_TO);
            compare.run();
          }
        }
      }
    }
  }
}

} // namespace paddle
Loading