@@ -899,10 +899,10 @@ static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
899899 int pixelNum = n * out_cw;
900900
901901 platform::GpuLaunchConfig config =
902- platform::getGpuLaunchConfig (pixelNum, ctx );
902+ platform::GetGpuLaunchConfig1D (ctx. cuda_device_context (), pixelNum );
903903
904904 if (" linear" == interp_method) {
905- KeLinearInterpFw<T><<<config.blocks , config.threads , 0 ,
905+ KeLinearInterpFw<T><<<config.block_per_grid , config.thread_per_block , 0 ,
906906 ctx.cuda_device_context().stream()>>> (
907907 input_data, in_w, in_cw, output_data, out_w, n, out_cw, c, ratio_w,
908908 align_corners, align_mode, data_layout);
@@ -1018,21 +1018,22 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
10181018 int pixelNum = n * out_chw;
10191019
10201020 platform::GpuLaunchConfig config =
1021- platform::getGpuLaunchConfig (pixelNum, ctx );
1021+ platform::GetGpuLaunchConfig1D (ctx. cuda_device_context (), pixelNum );
10221022
10231023 if (" nearest" == interp_method) {
1024- KeNearestNeighborInterpFw<T><<<config.blocks, config.threads, 0 ,
1025- ctx.cuda_device_context().stream()>>> (
1024+ KeNearestNeighborInterpFw<
1025+ T><<<config.block_per_grid, config.thread_per_block, 0 ,
1026+ ctx.cuda_device_context().stream()>>> (
10261027 input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
10271028 out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
10281029 } else if (" bilinear" == interp_method) {
1029- KeBilinearInterpFw<T><<<config.blocks , config.threads , 0 ,
1030+ KeBilinearInterpFw<T><<<config.block_per_grid , config.thread_per_block , 0 ,
10301031 ctx.cuda_device_context().stream()>>> (
10311032 input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
10321033 out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
10331034 } else if (" bicubic" == interp_method) {
1034- KeBicubicInterpFw<
1035- T> <<<config.blocks, 512 , 0 , ctx.cuda_device_context().stream()>>> (
1035+ KeBicubicInterpFw<T> <<<config.block_per_grid, 512 , 0 ,
1036+ ctx.cuda_device_context().stream()>>> (
10361037 input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
10371038 out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
10381039 }
@@ -1167,10 +1168,10 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
11671168 int pixelNum = n * out_cdhw;
11681169
11691170 platform::GpuLaunchConfig config =
1170- platform::getGpuLaunchConfig (pixelNum, ctx );
1171+ platform::GetGpuLaunchConfig1D (ctx. cuda_device_context (), pixelNum );
11711172
11721173 if (" trilinear" == interp_method) {
1173- KeTrilinearInterpFw<T><<<config.blocks , config.threads , 0 ,
1174+ KeTrilinearInterpFw<T><<<config.block_per_grid , config.thread_per_block , 0 ,
11741175 ctx.cuda_device_context().stream()>>> (
11751176 input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h,
11761177 out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
@@ -1259,10 +1260,10 @@ static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
12591260 int pixelNum = n * out_cw;
12601261
12611262 platform::GpuLaunchConfig config =
1262- platform::getGpuLaunchConfig (pixelNum, ctx );
1263+ platform::GetGpuLaunchConfig1D (ctx. cuda_device_context (), pixelNum );
12631264
12641265 if (" linear" == interp_method) {
1265- KeLinearInterpBw<T><<<config.blocks , config.threads , 0 ,
1266+ KeLinearInterpBw<T><<<config.block_per_grid , config.thread_per_block , 0 ,
12661267 ctx.cuda_device_context().stream()>>> (
12671268 input_grad_data, in_w, in_cw, output_grad_data, out_w, n, out_cw, c,
12681269 ratio_w, align_corners, align_mode, data_layout);
@@ -1376,22 +1377,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
13761377 int pixelNum = n * out_chw;
13771378
13781379 platform::GpuLaunchConfig config =
1379- platform::getGpuLaunchConfig (pixelNum, ctx );
1380+ platform::GetGpuLaunchConfig1D (ctx. cuda_device_context (), pixelNum );
13801381
13811382 if (" nearest" == interp_method) {
1382- KeNearestNeighborInterpBw<T><<<config.blocks, config.threads, 0 ,
1383- ctx.cuda_device_context().stream()>>> (
1383+ KeNearestNeighborInterpBw<
1384+ T><<<config.block_per_grid, config.thread_per_block, 0 ,
1385+ ctx.cuda_device_context().stream()>>> (
13841386 input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
13851387 n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
13861388 } else if (" bilinear" == interp_method) {
1387- KeBilinearInterpBw<T><<<config.blocks , config.threads , 0 ,
1389+ KeBilinearInterpBw<T><<<config.block_per_grid , config.thread_per_block , 0 ,
13881390 ctx.cuda_device_context().stream()>>> (
13891391 input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
13901392 n, out_chw, c, ratio_h, ratio_w, align_corners, align_mode,
13911393 data_layout);
13921394 } else if (" bicubic" == interp_method) {
1393- KeBicubicInterpBw<
1394- T> <<<config.blocks, 512 , 0 , ctx.cuda_device_context().stream()>>> (
1395+ KeBicubicInterpBw<T> <<<config.block_per_grid, 512 , 0 ,
1396+ ctx.cuda_device_context().stream()>>> (
13951397 input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w,
13961398 n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
13971399 }
@@ -1520,10 +1522,10 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
15201522 int pixelNum = n * out_cdhw;
15211523
15221524 platform::GpuLaunchConfig config =
1523- platform::getGpuLaunchConfig (pixelNum, ctx );
1525+ platform::GetGpuLaunchConfig1D (ctx. cuda_device_context (), pixelNum );
15241526
15251527 if (" trilinear" == interp_method) {
1526- KeTrilinearInterpBw<T><<<config.blocks , config.threads , 0 ,
1528+ KeTrilinearInterpBw<T><<<config.block_per_grid , config.thread_per_block , 0 ,
15271529 ctx.cuda_device_context().stream()>>> (
15281530 input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d,
15291531 out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners,
0 commit comments