@@ -210,7 +210,8 @@ __global__ void KeAvgPoolForward(const int nthreads,
210210 const int padH,
211211 const int padW,
212212 real* tgtData,
213- const int tgtStride) {
213+ const int tgtStride,
214+ const bool excludeMode) {
214215 int index = blockIdx .x * blockDim .x + threadIdx .x ;
215216 if (index < nthreads) {
216217 int pw = index % pooledW;
@@ -224,7 +225,8 @@ __global__ void KeAvgPoolForward(const int nthreads,
224225 int wend = min (wstart + sizeX, width);
225226 hstart = max (hstart, 0 );
226227 wstart = max (wstart, 0 );
227- int pool_size = (hend - hstart) * (wend - wstart);
228+ int poolSize =
229+ excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
228230
229231 real aveval = 0 ;
230232 inputData += (frameNum * channels + c) * height * width;
@@ -235,7 +237,7 @@ __global__ void KeAvgPoolForward(const int nthreads,
235237 }
236238 int tgtIndex =
237239 index % (pooledW * pooledH * channels) + frameNum * tgtStride;
238- tgtData[tgtIndex] = aveval / pool_size ;
240+ tgtData[tgtIndex] = aveval / poolSize ;
239241 }
240242}
241243
@@ -253,7 +255,8 @@ void hl_avgpool_forward(const int frameCnt,
253255 const int paddingH,
254256 const int paddingW,
255257 real* tgtData,
256- const int tgtStride) {
258+ const int tgtStride,
259+ const bool excludeMode) {
257260 int num_kernels = pooledH * pooledW * channels * frameCnt;
258261 int blocks = (num_kernels + 1024 - 1 ) / 1024 ;
259262 KeAvgPoolForward<<<blocks, 1024 , 0 , STREAM_DEFAULT>>> (num_kernels,
@@ -270,7 +273,8 @@ void hl_avgpool_forward(const int frameCnt,
270273 paddingH,
271274 paddingW,
272275 tgtData,
273- tgtStride);
276+ tgtStride,
277+ excludeMode);
274278 CHECK_SYNC (" hl_avgpool_forward failed" );
275279}
276280
@@ -290,7 +294,8 @@ __global__ void KeAvgPoolBackward(const int nthreads,
290294 real scaleA,
291295 real scaleB,
292296 real* tgtGrad,
293- const int outStride) {
297+ const int outStride,
298+ const bool excludeMode) {
294299 int index = blockIdx .x * blockDim .x + threadIdx .x ;
295300 if (index < nthreads) {
296301 int offsetW = index % width + padW;
@@ -314,8 +319,9 @@ __global__ void KeAvgPoolBackward(const int nthreads,
314319 int wstart = pw * strideW - padW;
315320 int wend = min (wstart + sizeX, width);
316321 wstart = max (wstart, 0 );
317- int poolsize = (hend - hstart) * (wend - wstart);
318- gradient += outGrad[ph * pooledW + pw] / poolsize;
322+ int poolSize =
323+ excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
324+ gradient += outGrad[ph * pooledW + pw] / poolSize;
319325 }
320326 }
321327 tgtGrad[index] = scaleB * tgtGrad[index] + scaleA * gradient;
@@ -338,7 +344,8 @@ void hl_avgpool_backward(const int frameCnt,
338344 real scaleA,
339345 real scaleB,
340346 real* backGrad,
341- const int outStride) {
347+ const int outStride,
348+ const bool excludeMode) {
342349 int num_kernels = height * width * channels * frameCnt;
343350 int blocks = (num_kernels + 1024 - 1 ) / 1024 ;
344351
@@ -358,7 +365,8 @@ void hl_avgpool_backward(const int frameCnt,
358365 scaleA,
359366 scaleB,
360367 backGrad,
361- outStride);
368+ outStride,
369+ excludeMode);
362370 CHECK_SYNC (" hl_avgpool_backward failed" );
363371}
364372
0 commit comments