@@ -126,14 +126,165 @@ class GemmConvFunction : public ConvFunctionBase {
126126 inputData += inputChannels * inputHeight * inputWidth;
127127 outputData += outputChannels * outputHeight * outputWidth;
128128 }
129+ }
130+ };
131+
#ifdef PADDLE_MOBILE_INFERENCE

/*
 * \brief Forward calculation of convolution, optimized for mobile.
 *
 * Same interface and result as GemmConvFunction, but the im2col matrix is
 * expanded one fixed-size tile at a time, so the scratch buffer stays small,
 * and the buffer is released after every call to keep the resident memory
 * footprint low on mobile devices.
 */
template <DeviceType Device>
class GemmConvMobileFunction : public ConvFunctionBase {
public:
  void init(const FuncConfig& config) override {
    ConvFunctionBase::init(config);
  }

  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();
    checkShape(input, filter, output);
  }

  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
    CHECK_EQ(numInputs_, inputs.size());
    CHECK_EQ(numOutputs_, outputs.size());
    check(inputs, outputs);
    // TODO(hedaoyuan): Need to define some index macros,
    // to avoid using 0 and 1.
    const TensorShape& input = inputs[0].shape();
    const TensorShape& filter = inputs[1].shape();
    const TensorShape& output = outputs[0].shape();

    // ADD_TO accumulates into the existing output (beta = 1);
    // otherwise the output is overwritten (beta = 0).
    real beta;
    if (outputs[0].getArgType() == ADD_TO) {
      beta = 1.0;
    } else {
      beta = 0.0;
    }

    size_t batchSize = input[0];
    size_t inputChannels = input[1];
    size_t inputHeight = input[2];
    size_t inputWidth = input[3];
    size_t filterHeight = getFilterHeight(filter);
    size_t filterWidth = getFilterWidth(filter);
    size_t outputChannels = output[1];
    size_t outputHeight = output[2];
    size_t outputWidth = output[3];

    real* inputData = inputs[0].data<real>();
    real* filterData = inputs[1].data<real>();
    real* outputData = outputs[0].data<real>();
    bool needIm2col = isNeedIm2col(filter);

    TensorShape imShape =
        TensorShape({inputChannels / groups_, inputHeight, inputWidth});

    TensorShape colShape;
    real* colData = NULL;

    // Dimensions of the full (untiled) col matrix: K x N of the GEMM.
    size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth;
    size_t colWidth = outputHeight * outputWidth;
    // Max col matrix tile height 256, max col matrix tile width 2048.
    // (Was documented as 1024 but the code has always used 2048.)
    size_t stepColHeight = std::min(colHeight, static_cast<size_t>(256));
    size_t stepColWidth = std::min(colWidth, static_cast<size_t>(2048));

    if (needIm2col) {
      colShape = TensorShape({inputChannels / groups_,
                              filterHeight,
                              filterWidth,
                              outputHeight,
                              outputWidth});

      // Only one tile of the col matrix is materialized at a time.
      resizeBuffer<Device>(stepColHeight * stepColWidth * sizeof(real));
      colData = reinterpret_cast<real*>(memory_->getBuf());
    }

    Im2ColMobileFunctor<real> im2col;
    size_t inputOffset = imShape.getElements();
    size_t outputOffset =
        (outputChannels / groups_) * outputHeight * outputWidth;
    size_t filterOffset = filter.getElements() / groups_;

    // Leading dimensions of the full (untiled) filter and output matrices.
    int nStride = static_cast<int>(colWidth);
    int kStride = static_cast<int>(colHeight);
    // M is loop-invariant: rows of the output matrix per group.
    const int M = static_cast<int>(outputChannels / groups_);
    for (size_t i = 0; i < batchSize; i++) {
      for (size_t g = 0; g < groups_; g++) {
        if (needIm2col) {
          // beta_ uses the caller-requested beta for the first K-block only,
          // then switches to 1.0 so later K-blocks accumulate their partial
          // products into the output.
          real beta_ = beta;
          for (size_t colHeightStart = 0; colHeightStart < colHeight;
               colHeightStart += stepColHeight) {
            for (size_t colWidthStart = 0; colWidthStart < colWidth;
                 colWidthStart += stepColWidth) {
              int N = static_cast<int>(
                  std::min(colWidth - colWidthStart, stepColWidth));
              int K = static_cast<int>(
                  std::min(colHeight - colHeightStart, stepColHeight));
              // im2col: expand one K x N tile of the col matrix.
              im2col(inputData + g * inputOffset,
                     imShape,
                     colData,
                     colShape,
                     strideH(),
                     strideW(),
                     paddingH(),
                     paddingW(),
                     dilationH(),
                     dilationW(),
                     colHeightStart,
                     K,
                     colWidthStart,
                     N);

              // gemm: C[M x N] (+)= A[M x K] * B[K x N], where A is the
              // filter tile and C is the corresponding output tile.
              BlasGemm<Device, real>::compute(
                  false,
                  false,
                  M,
                  N,
                  K,
                  1.0f,
                  filterData + g * filterOffset + colHeightStart,
                  kStride,
                  colData,
                  N,
                  beta_,
                  outputData + g * outputOffset + colWidthStart,
                  nStride);
            }
            beta_ = 1.0;
          }
        } else {
          // 1x1 / no-im2col case: the input already is the col matrix.
          int N = static_cast<int>(outputHeight * outputWidth);
          int K = static_cast<int>(inputChannels / groups_ * filterHeight *
                                   filterWidth);
          BlasGemm<Device, real>::compute(false,
                                          false,
                                          M,
                                          N,
                                          K,
                                          1.0f,
                                          filterData + g * filterOffset,
                                          K,
                                          inputData + g * inputOffset,
                                          N,
                                          beta,
                                          outputData + g * outputOffset,
                                          N);
        }
      }
      inputData += inputChannels * inputHeight * inputWidth;
      outputData += outputChannels * outputHeight * outputWidth;
    }

    // Release the scratch buffer immediately to minimize resident memory
    // on mobile devices.
    memory_.reset();
  }
};

#endif
287+
137288/*
138289 * \brief Backward input calculation of convolution.
139290 */
@@ -348,7 +499,11 @@ class GemmConvGradFilterFunction : public ConvFunctionBase {
348499 }
349500};
350501
// On mobile-inference builds, register the tile-based low-memory variant as
// the CPU forward convolution; desktop/server builds keep the original
// GemmConvFunction. Gradient functions are shared by both configurations.
#ifdef PADDLE_MOBILE_INFERENCE
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction);
#else
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
#endif
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
354509#ifdef PADDLE_WITH_CUDA
0 commit comments