@@ -39,33 +39,36 @@ __global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale,
   __syncthreads();
 
   const int64_t num = s_starts[size];
-  int pre_xs_index = 0;
-  bool t_found_inf = false;
-  const MT t_scale = *scale;
+  int xs_index = 0;
+  bool local_found_inf = false;
+  const MT local_scale = *scale;
   for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
-    // get the xs's index of thread
-    int xs_index = pre_xs_index;
-    while (idx < s_starts[xs_index]) xs_index++;
-    // avoid some tensor's numel is zero
-    while (idx >= s_starts[xs_index]) xs_index++;
-    pre_xs_index = xs_index - 1;
+    // get the xs index that idx falls into
+    // For example:
+    //   idx = 15, starts = [0, 10, 10, 20, 30]
+    //   because 10 <= idx < 20 ==>
+    //   the idx-th element is located in the 3rd tensor
+    //   (note that the 2nd tensor's size is 0)
+    int next_xs_index = xs_index;
+    while (idx >= s_starts[next_xs_index]) next_xs_index++;
+    xs_index = next_xs_index - 1;
 
     // get in data and out data
-    const T* in = xs[pre_xs_index];
-    T* out = outs[pre_xs_index];
-    int64_t in_idx = idx - s_starts[pre_xs_index];
+    const T* in = xs[xs_index];
+    T* out = outs[xs_index];
+    int64_t in_idx = idx - s_starts[xs_index];
 
     // Unscale
-    MT val = static_cast<MT>(in[in_idx]) * t_scale;
+    MT val = static_cast<MT>(in[in_idx]) * local_scale;
     T narrow_val = static_cast<T>(val);
     out[in_idx] = narrow_val;
 
     // CheckFinite
     if (!isfinite(narrow_val)) {
-      t_found_inf = true;
+      local_found_inf = true;
     }
   }
-  if (t_found_inf) {
+  if (local_found_inf) {
     *found_inf = true;
   }
 }
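
For reference, the lookup in the loop above can be checked in isolation: advancing past every start that is less than or equal to idx lands one position past the owning tensor, so the answer is that position minus one. Below is a minimal host-side sketch; FindTensorIndex and the main() driver are illustrative names, not part of the Paddle sources.

    #include <cstdint>
    #include <cstdio>

    // Host-side sketch of the lookup done inside the kernel loop: find which
    // tensor the flat element index idx belongs to. Requires idx < starts[n],
    // where n is the number of tensors, so the while loop always terminates.
    int FindTensorIndex(const int64_t* starts, int64_t idx) {
      int next = 0;
      // skip every start <= idx; tensors with zero elements contribute equal
      // consecutive starts and are skipped automatically
      while (idx >= starts[next]) next++;
      return next - 1;
    }

    int main() {
      // xs sizes = [10, 0, 10, 10]  ==>  starts = [0, 10, 10, 20, 30]
      const int64_t starts[] = {0, 10, 10, 20, 30};
      printf("%d\n", FindTensorIndex(starts, 15));  // prints 2 (the 3rd tensor)
      return 0;
    }
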
@@ -94,28 +97,30 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
         scale_data, inverse_scale_v, found_inf_data);
 
     size_t xs_size = xs.size();
+    const auto& cpu_place = platform::CPUPlace();
     // calculate each tensor's start index and copy to device
     auto h_starts_tensor =
-        memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+        memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
     int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());
 
     auto d_starts_tensor =
         memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
     int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
 
+    // the start index of each tensor is the sum of the
+    // previous tensors' sizes. For example:
+    //   xs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
     h_starts[0] = 0;
     for (int i = 1; i <= xs_size; i++) {
-      // the start index value of each tensor is
-      // the sum of previous tensor's size
       h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
     }
     int64_t total_num = h_starts[xs_size];
     memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
-                 d_starts, platform::CPUPlace(), h_starts,
-                 (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+                 d_starts, cpu_place, h_starts, (xs_size + 1) * sizeof(int64_t),
+                 dev_ctx.stream());
 
     // copy each tensor's data address to device
-    auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*));
+    auto h_mem = memory::Alloc(cpu_place, 2 * xs_size * sizeof(T*));
     const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
     T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
 
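
The starts array built above is a plain exclusive prefix sum over the tensor sizes. A hedged sketch of the same computation follows, with made-up sizes; BuildStarts is an illustrative helper, not a Paddle API.

    #include <cstdint>
    #include <vector>

    // Illustrative helper: builds the cumulative starts array from per-tensor
    // element counts, mirroring the h_starts loop above.
    std::vector<int64_t> BuildStarts(const std::vector<int64_t>& numels) {
      std::vector<int64_t> starts(numels.size() + 1, 0);
      for (size_t i = 1; i <= numels.size(); ++i) {
        starts[i] = starts[i - 1] + numels[i - 1];
      }
      return starts;  // e.g. {10, 0, 10, 10} -> {0, 10, 10, 20, 30}
    }
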
@@ -128,16 +133,18 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
       h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
     }
     memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs,
-                 platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*),
-                 dev_ctx.stream());
+                 cpu_place, h_xs, 2 * xs_size * sizeof(T*), dev_ctx.stream());
 
     // Launch Kernel
-    int block = 1024;
-    int block_num = block * 20;  // each thread deal with 20 number
-    int grid = (total_num + block_num - 1) / block_num;
+    int threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
+    int elements_per_block =
+        threads_per_block * 20;  // each thread deals with 20 numbers
+    int blocks_per_grid =
+        (total_num + elements_per_block - 1) / elements_per_block;
     VLOG(3) << "launch kernel";
-    CheckFiniteAndUnscale<T, MPDType><<<
-        grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+    CheckFiniteAndUnscale<
+        T, MPDType><<<blocks_per_grid, threads_per_block,
+                      (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
         d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
     VLOG(3) << "finish kernel";
   }
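
The launch configuration above assumes each thread strides over roughly 20 elements, so the grid is sized to total_num divided by threads_per_block * 20, rounded up. A standalone sketch of that arithmetic is shown below; the numbers are made up for illustration.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t total_num = 30;  // e.g. three tensors of 10 elements each
      // at most 1024 threads per block, and no more threads than elements
      int threads_per_block =
          static_cast<int>(std::min(static_cast<int64_t>(1024), total_num));
      int elements_per_block = threads_per_block * 20;  // ~20 elements per thread
      int blocks_per_grid = static_cast<int>(
          (total_num + elements_per_block - 1) / elements_per_block);
      printf("threads=%d blocks=%d\n", threads_per_block, blocks_per_grid);
      // prints: threads=30 blocks=1
      return 0;
    }
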