@@ -381,164 +381,6 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
381381 CHECK_SYNC (" hl_avgpool_backward failed" );
382382}
383383
384- __global__ void KeCMRNormFillScale (size_t nthreads, const real* in,
385- real* scale, size_t channels,
386- size_t height, size_t width, size_t size,
387- real alpha) {
388- size_t index = threadIdx .x + blockIdx .x * blockDim .x ;
389- if (index < nthreads) {
390- // find out the local offset
391- size_t w = index % width;
392- size_t h = (index / width) % height;
393- size_t n = index / width / height;
394- size_t offset = (n * channels * height + h) * width + w;
395- size_t step = height * width;
396- in += offset;
397- scale += offset;
398- size_t head = 0 ;
399- size_t pre_pad = (size - 1 ) / 2 ;
400- size_t post_pad = size - pre_pad - 1 ;
401- real accum_scale = 0 ;
402- // fill the scale at [n, :, h, w]
403- // accumulate values
404- while (head < post_pad) {
405- accum_scale += in[head * step] * in[head * step];
406- ++head;
407- }
408- // until we reach size, nothing needs to be subtracted
409- while (head < size) {
410- accum_scale += in[head * step] * in[head * step];
411- scale[(head - post_pad) * step] = 1 . + accum_scale * alpha;
412- ++head;
413- }
414- // both add and subtract
415- while (head < channels) {
416- accum_scale += in[head * step] * in[head * step];
417- accum_scale -= in[(head - size) * step] * in[(head - size) * step];
418- scale[(head - post_pad) * step] = 1 . + accum_scale * alpha;
419- ++head;
420- }
421- // subtract only
422- while (head < channels + post_pad) {
423- accum_scale -= in[(head - size) * step] * in[(head - size) * step];
424- scale[(head - post_pad) * step] = 1 . + accum_scale * alpha;
425- ++head;
426- }
427- }
428- }
429-
430- __global__ void KeCMRNormOutput (size_t nthreads, const real* in,
431- const real* scale, real negative_beta,
432- real* out) {
433- size_t index = threadIdx .x + blockIdx .x * blockDim .x ;
434- if (index < nthreads) {
435- out[index] = in[index] * pow (scale[index], negative_beta);
436- }
437- }
438-
439- void hl_CMRNorm_forward (size_t frameCnt, const real* in, real* scale,
440- real* out, size_t channels,
441- size_t height, size_t width, size_t sizeX,
442- real alpha, real beta) {
443- size_t threadsNum = frameCnt * height * width;
444- size_t blocksX = (threadsNum + 1024 - 1 ) / 1024 ;
445- size_t blocksY = 1 ;
446- dim3 threads (1024 , 1 );
447- dim3 grid (blocksX, blocksY);
448-
449- KeCMRNormFillScale<<<grid, threads, 0 , STREAM_DEFAULT>>>
450- (threadsNum, in, scale, channels, height, width, sizeX, alpha);
451-
452- threadsNum = frameCnt * height * width *channels;
453- blocksX = (threadsNum + 1024 -1 ) / 1024 ;
454- dim3 threads2 (1024 , 1 );
455- dim3 grid2 (blocksX, blocksY);
456- KeCMRNormOutput<<<grid2, threads2, 0 , STREAM_DEFAULT>>>
457- (threadsNum, in, scale, beta, out);
458- CHECK_SYNC (" hl_CMRNorm_forward" );
459- }
460-
461- __global__ void KeCMRNormDiff (size_t nthreads, const real* bottom_data,
462- const real* top_data, const real* scale,
463- const real* top_diff, size_t channels,
464- size_t height, size_t width, size_t size,
465- real negative_beta, real cache_ratio,
466- real* bottom_diff ) {
467- int index = threadIdx .x + blockIdx .x * blockDim .x ;
468- if (index < nthreads) {
469- // find out the local offset
470- size_t w = index % width;
471- size_t h = (index / width) % height;
472- size_t n = index / width / height;
473- size_t offset = (n * channels * height + h) * width + w;
474- size_t step = height * width;
475- bottom_data += offset;
476- top_data += offset;
477- scale += offset;
478- top_diff += offset;
479- bottom_diff += offset;
480- int head = 0 ;
481- int pre_pad = size - (size + 1 ) / 2 ;
482- int post_pad = size - pre_pad - 1 ;
483- real accum_ratio = 0 ;
484- // accumulate values
485- while (head < post_pad) {
486- accum_ratio += top_diff[head * step] *
487- top_data[head * step] / scale[head * step];
488- ++head;
489- }
490- // until we reach size, nothing needs to be subtracted
491- while (head < size) {
492- accum_ratio += top_diff[head * step] *
493- top_data[head * step] / scale[head * step];
494- bottom_diff[(head - post_pad) * step] +=
495- top_diff[(head - post_pad) * step] *
496- pow (scale[(head - post_pad) * step], negative_beta) - cache_ratio *
497- bottom_data[(head - post_pad) * step] * accum_ratio;
498- ++head;
499- }
500- // both add and subtract
501- while (head < channels) {
502- accum_ratio += top_diff[head * step] * top_data[head * step] /
503- scale[head * step];
504- accum_ratio -= top_diff[(head - size) * step] *
505- top_data[(head - size) * step] / scale[(head - size) * step];
506- bottom_diff[(head - post_pad) * step] +=
507- top_diff[(head - post_pad) * step] *
508- pow (scale[(head - post_pad) * step], negative_beta) - cache_ratio *
509- bottom_data[(head - post_pad) * step] * accum_ratio;
510- ++head;
511- }
512- // subtract only
513- while (head < channels + post_pad) {
514- accum_ratio -= top_diff[(head - size) * step] *
515- top_data[(head - size) * step] / scale[(head - size) * step];
516- bottom_diff[(head - post_pad) * step] +=
517- top_diff[(head - post_pad) * step] *
518- pow (scale[(head - post_pad) * step], negative_beta) - cache_ratio *
519- bottom_data[(head - post_pad) * step] * accum_ratio;
520- ++head;
521- }
522- }
523- }
524-
525- void hl_CMRNorm_backward (size_t frameCnt, const real* inV,
526- const real* scale,
527- const real* outV, const real* outDiff,
528- real *inDiff, size_t channels,
529- size_t height, size_t width, size_t sizeX,
530- real alpha, real beta) {
531- size_t threadsNum = frameCnt * height * width;
532- size_t blocksX = (threadsNum + 1024 - 1 ) / 1024 ;
533- size_t blocksY = 1 ;
534- dim3 threads (1024 , 1 );
535- dim3 grid (blocksX, blocksY);
536- KeCMRNormDiff <<<grid, threads, 0 , STREAM_DEFAULT>>>
537- (threadsNum, inV, outV, scale, outDiff, channels,
538- height, width, sizeX, alpha, beta, inDiff);
539- CHECK_SYNC (" hl_CMRNorm_backward" );
540- }
541-
542384__global__ void KeBilinearInterpFw (const real* in,
543385 const size_t inImgH,
544386 const size_t inImgW,
0 commit comments