Skip to content

Commit a096c58

Browse files
authored
Merge pull request #7034 from hedaoyuan/convolution
GemmConvMobileFunction (optimized for mobile)
2 parents d00e1ed + b7c4b58 commit a096c58

File tree

3 files changed

+290
-3
lines changed

3 files changed

+290
-3
lines changed

paddle/function/GemmConvOp.cpp

Lines changed: 158 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,165 @@ class GemmConvFunction : public ConvFunctionBase {
126126
inputData += inputChannels * inputHeight * inputWidth;
127127
outputData += outputChannels * outputHeight * outputWidth;
128128
}
129+
}
130+
};
131+
129132
#ifdef PADDLE_MOBILE_INFERENCE

/*
 * \brief Forward calculation of convolution, optimized for mobile.
 *
 * Instead of materializing the whole im2col column matrix, the matrix is
 * produced and consumed tile by tile (at most 256 x 2048 elements at a
 * time), which bounds the scratch-memory footprint on memory-constrained
 * mobile devices.
 */
template <DeviceType Device>
class GemmConvMobileFunction : public ConvFunctionBase {
public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  // Validate that the input/filter/output tensor shapes are consistent.
  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();
    checkShape(input, filter, output);
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    check(inputs, outputs);
    // TODO(hedaoyuan): Need to define some index macros,
    // to avoid using 0 and 1.
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();

    // ADD_TO accumulates into the existing output; otherwise overwrite it.
    real beta;
    if (outputs[0].getArgType() == ADD_TO) {
      beta = 1.0;
    } else {
      beta = 0.0;
    }

    size_t batchSize = input[0];
    size_t inputChannels = input[1];
    size_t inputHeight = input[2];
    size_t inputWidth = input[3];
    size_t filterHeight = getFilterHeight(filter);
    size_t filterWidth = getFilterWidth(filter);
    size_t outputChannels = output[1];
    size_t outputHeight = output[2];
    size_t outputWidth = output[3];

    real* inputData = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* outputData = outputs[0].data<real>();
    bool needIm2col = isNeedIm2col(filter);

    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});

    TensorShape colShape;
    real* colData = NULL;

    size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth;
    size_t colWidth = outputHeight * outputWidth;
    // Max col matrix height 256, max col matrix width 2048; the im2col
    // scratch buffer therefore never exceeds 256 * 2048 elements.
    size_t stepColHeight = std::min(colHeight, static_cast<size_t>(256));
    size_t stepColWidth = std::min(colWidth, static_cast<size_t>(2048));

    if (needIm2col) {
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});

      resizeBuffer<Device>(stepColHeight * stepColWidth * sizeof(real));
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }

    Im2ColMobileFunctor<real> im2col;
    size_t inputOffset = imShape.getElements();
    size_t outputOffset =
        (outputChannels / groups_) * outputHeight * outputWidth;
    size_t filterOffset = filter.getElements() / groups_;

    int nStride = colWidth;
    int kStride = colHeight;
    for (size_t i = 0; i < batchSize; i++) {
      for (size_t g = 0; g < groups_; g++) {
        if (needIm2col) {
          // The column matrix is processed in K x N tiles. Each K tile
          // contributes a partial product to the same output rows, so
          // after the first row of tiles beta_ switches to 1.0 to make
          // subsequent GEMMs accumulate instead of overwrite.
          real beta_ = beta;
          for (size_t colHeightStart = 0; colHeightStart < colHeight;
               colHeightStart += stepColHeight) {
            for (size_t colWidthStart = 0; colWidthStart < colWidth;
                 colWidthStart += stepColWidth) {
              int N = std::min(colWidth - colWidthStart, stepColWidth);
              int K = std::min(colHeight - colHeightStart, stepColHeight);
              // im2col: fill one K x N tile of the column matrix.
              im2col(inputData + g * inputOffset,
                     imShape,
                     colData,
                     colShape,
                     strideH(),
                     strideW(),
                     paddingH(),
                     paddingW(),
                     dilationH(),
                     dilationW(),
                     colHeightStart,
                     K,
                     colWidthStart,
                     N);

              // gemm: output(M x N) += filter(M x K) * colTile(K x N).
              int M = outputChannels / groups_;
              BlasGemm<Device, real>::compute(
                  false,
                  false,
                  M,
                  N,
                  K,
                  1.0f,
                  filterData + g * filterOffset + colHeightStart,
                  kStride,
                  colData,
                  N,
                  beta_,
                  outputData + g * outputOffset + colWidthStart,
                  nStride);
            }
            beta_ = 1.0;
          }
        } else {
          // No im2col needed (e.g. 1x1 filter, unit stride, no padding):
          // the input already has column layout, so do a single GEMM.
          int M = outputChannels / groups_;
          int N = outputHeight * outputWidth;
          int K = inputChannels / groups_ * filterHeight * filterWidth;
          BlasGemm<Device, real>::compute(false,
                                          false,
                                          M,
                                          N,
                                          K,
                                          1.0f,
                                          filterData + g * filterOffset,
                                          K,
                                          inputData + g * inputOffset,
                                          N,
                                          beta,
                                          outputData + g * outputOffset,
                                          N);
        }
      }
      inputData += inputChannels * inputHeight * inputWidth;
      outputData += outputChannels * outputHeight * outputWidth;
    }

    // Release the im2col scratch buffer eagerly to keep peak memory low
    // on mobile devices.
    memory_.reset();
  }
};

#endif
287+
137288
/*
138289
* \brief Backward input calculation of convolution.
139290
*/
@@ -348,7 +499,11 @@ class GemmConvGradFilterFunction : public ConvFunctionBase {
348499
}
349500
};
350501

502+
#ifdef PADDLE_MOBILE_INFERENCE
503+
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction);
504+
#else
351505
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
506+
#endif
352507
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
353508
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
354509
#ifdef PADDLE_WITH_CUDA

paddle/function/Im2Col.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,54 @@ class Col2ImFunctor {
9898
int dilationWidth = 1);
9999
};
100100

101+
template <class T>
102+
class Im2ColMobileFunctor {
103+
public:
104+
void operator()(const T* imData,
105+
const TensorShape& imShape,
106+
T* colData,
107+
const TensorShape& colShape,
108+
int strideHeight,
109+
int strideWidth,
110+
int paddingHeight,
111+
int paddingWidth,
112+
int dilationHeight,
113+
int dilationWidth,
114+
int colHeightStart,
115+
int colHeightSize,
116+
int colWidthStart,
117+
int colWidthSize) {
118+
int inputHeight = imShape[1];
119+
int inputWidth = imShape[2];
120+
int filterHeight = colShape[1];
121+
int filterWidth = colShape[2];
122+
int outputWidth = colShape[4];
123+
124+
for (int colh = 0; colh < colHeightSize; colh++) {
125+
int wOffset = (colHeightStart + colh) % filterWidth;
126+
int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight;
127+
int c_im = (colHeightStart + colh) / filterWidth / filterHeight;
128+
129+
for (int colw = 0; colw < colWidthSize; colw++) {
130+
int h = (colWidthStart + colw) / outputWidth;
131+
int w = (colWidthStart + colw) % outputWidth;
132+
133+
int imRowIdx = h * strideHeight + hOffset * dilationHeight;
134+
int imColIdx = w * strideWidth + wOffset * dilationWidth;
135+
if ((imRowIdx - paddingHeight) < 0 ||
136+
(imRowIdx - paddingHeight) >= inputHeight ||
137+
(imColIdx - paddingWidth) < 0 ||
138+
(imColIdx - paddingWidth) >= inputWidth) {
139+
colData[colh * colWidthSize + colw] = static_cast<T>(0);
140+
} else {
141+
imRowIdx += c_im * inputHeight - paddingHeight;
142+
imColIdx -= paddingWidth;
143+
colData[colh * colWidthSize + colw] =
144+
imData[imRowIdx * inputWidth + imColIdx];
145+
}
146+
}
147+
}
148+
}
149+
};
150+
101151
} // namespace paddle

paddle/function/Im2ColTest.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,86 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); }
138138

139139
#endif
140140

141+
template <class T>
142+
void TestIm2ColMobileFunctor() {
143+
for (size_t channels : {32}) {
144+
for (size_t inputHeight : {33, 100}) {
145+
for (size_t inputWidth : {32, 96}) {
146+
for (size_t filterHeight : {5}) {
147+
for (size_t filterWidth : {7}) {
148+
for (size_t stride : {2}) {
149+
for (size_t padding : {1}) {
150+
for (size_t dilation : {1, 3}) {
151+
size_t filterSizeH = (filterHeight - 1) * dilation + 1;
152+
size_t filterSizeW = (filterWidth - 1) * dilation + 1;
153+
if (inputHeight + 2 * padding < filterSizeH ||
154+
inputWidth + 2 * padding < filterSizeW)
155+
break;
156+
if (padding >= filterSizeH || padding >= filterSizeW) break;
157+
size_t outputHeight =
158+
(inputHeight - filterSizeH + 2 * padding) / stride + 1;
159+
size_t outputWidth =
160+
(inputWidth - filterSizeW + 2 * padding) / stride + 1;
161+
162+
TensorShape imShape =
163+
TensorShape({channels, inputHeight, inputWidth});
164+
TensorShape colShape1 = TensorShape({channels,
165+
filterHeight,
166+
filterWidth,
167+
outputHeight,
168+
outputWidth});
169+
170+
size_t height = channels * filterHeight * filterWidth;
171+
size_t width = outputHeight * outputWidth;
172+
VectorPtr input1 =
173+
Vector::create(imShape.getElements(), false);
174+
VectorPtr input2 =
175+
Vector::create(imShape.getElements(), false);
176+
MatrixPtr output1 =
177+
Matrix::create(height, width, false, false);
178+
MatrixPtr output2 =
179+
Matrix::create(height, width, false, false);
180+
input1->uniform(0.001, 1);
181+
input2->copyFrom(*input1);
182+
183+
Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> im2Col1;
184+
Im2ColMobileFunctor<T> im2Col2;
185+
im2Col1(input1->getData(),
186+
imShape,
187+
output1->getData(),
188+
colShape1,
189+
stride,
190+
stride,
191+
padding,
192+
padding,
193+
dilation,
194+
dilation);
195+
im2Col2(input2->getData(),
196+
imShape,
197+
output2->getData(),
198+
colShape1,
199+
stride,
200+
stride,
201+
padding,
202+
padding,
203+
dilation,
204+
dilation,
205+
0,
206+
height,
207+
0,
208+
width);
209+
210+
autotest::TensorCheckEqual(*output1, *output2);
211+
}
212+
}
213+
}
214+
}
215+
}
216+
}
217+
}
218+
}
219+
}
220+
221+
TEST(Im2ColFunctor, Mobile) { TestIm2ColMobileFunctor<float>(); }
222+
141223
} // namespace paddle

0 commit comments

Comments
 (0)