@@ -26,18 +26,48 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
 }
 
 template <typename T, typename MT>
-__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num,
-                                      bool* found_inf, T* out) {
-  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
-
-  if (idx < num) {
-    MT val = static_cast<MT>(in[idx]) * (*scale);
+__global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale,
+                                      int64_t size, int64_t* starts,
+                                      bool* found_inf, T** outs) {
+  const int64_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  // copy the starts array from global memory to shared memory
+  extern __shared__ int64_t s_starts[];
+  for (int i = threadIdx.x; i <= size; i += blockDim.x) {
+    s_starts[i] = starts[i];
+  }
+  __syncthreads();
+
+  const int64_t num = s_starts[size];
+  int pre_xs_index = 0;
+  bool t_found_inf = false;
+  const MT t_scale = *scale;
+  for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
+    // find the index in xs that this flattened element belongs to
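+    // e.g. with starts = [0, 10, 10, 20, 30], idx = 15 falls in the third
+    // tensor (note the second tensor has numel 0), since 10 <= 15 < 20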
+    int xs_index = pre_xs_index;
+    while (idx < s_starts[xs_index]) xs_index++;
+    // skip tensors whose numel is zero
+    while (idx >= s_starts[xs_index]) xs_index++;
+    pre_xs_index = xs_index - 1;
+
+    // get the input and output data pointers
+    const T* in = xs[pre_xs_index];
+    T* out = outs[pre_xs_index];
+    int64_t in_idx = idx - s_starts[pre_xs_index];
+
+    // Unscale
+    MT val = static_cast<MT>(in[in_idx]) * t_scale;
     T narrow_val = static_cast<T>(val);
-    out[idx] = narrow_val;
+    out[in_idx] = narrow_val;
+
+    // CheckFinite
     if (!isfinite(narrow_val)) {
-      *found_inf = true;
+      t_found_inf = true;
     }
   }
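+  // write to the global flag at most once per thread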
+  if (t_found_inf) {
+    *found_inf = true;
+  }
 }
 
 template <typename T>
@@ -63,20 +93,53 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
     InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
         scale_data, inverse_scale_v, found_inf_data);
 
-    for (size_t i = 0; i < xs.size(); ++i) {
-      const auto* x = xs[i];
-      auto* out = outs[i];
-      const T* x_data = x->data<T>();
-      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
-
-      int num = x->numel();
-      int block = 1024;
-      int grid = (num + block - 1) / block;
-      VLOG(3) << "launch kernel";
-      CheckFiniteAndUnscale<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
-          x_data, inverse_scale_v, num, found_inf_data, out_data);
-      VLOG(3) << "finish kernel";
+    size_t xs_size = xs.size();
+    // calculate each tensor's start index and copy the array to the device
+    auto h_starts_tensor =
+        memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+    int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());
+
+    auto d_starts_tensor =
+        memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+    int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
+
+    h_starts[0] = 0;
+    for (int i = 1; i <= xs_size; i++) {
+      // the start index of each tensor is the running sum of the
+      // previous tensors' element counts
+      h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
+    }
+    int64_t total_num = h_starts[xs_size];
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_starts, platform::CPUPlace(), h_starts,
+                 (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+
+    // copy each tensor's data address to the device
+    auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*));
+    const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
+    T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
+
+    auto d_mem = memory::Alloc(dev_ctx, 2 * xs_size * sizeof(T*));
+    const T** d_xs = reinterpret_cast<const T**>(d_mem->ptr());
+    T** d_outs = reinterpret_cast<T**>(d_mem->ptr()) + xs_size;
+
+    for (size_t i = 0; i < xs_size; ++i) {
+      h_xs[i] = xs[i]->data<T>();
+      h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
     }
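+    // h_xs and h_outs live in one contiguous host buffer, so a single copy of
+    // 2 * xs_size pointers transfers both arrays to the device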
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs,
+                 platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*),
+                 dev_ctx.stream());
+
+    // Launch Kernel
+    int block = 1024;
+    int block_num = block * 20;  // each thread handles roughly 20 elements
+    int grid = (total_num + block_num - 1) / block_num;
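+    // the kernel uses a grid-stride loop, so the small grid simply means
+    // each thread iterates over multiple elements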
+    VLOG(3) << "launch kernel";
+    CheckFiniteAndUnscale<T, MPDType><<<
+        grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+        d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
+    VLOG(3) << "finish kernel";
   }
 };
 }  // namespace operators
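
For reference, a minimal host-side sketch of the flattened indexing scheme the fused kernel relies on (not part of this commit; the helper name and std::vector types are hypothetical, and the inverse scale is assumed to be precomputed):

#include <cmath>
#include <cstdint>
#include <vector>

// Build the (size + 1)-entry starts array from per-tensor element counts and
// apply the same unscale-and-check logic the kernel performs per flat index.
// Assumes (*outs)[t] is already sized like xs[t].
bool CheckFiniteAndUnscaleCpu(const std::vector<std::vector<float>>& xs,
                              float inverse_scale,
                              std::vector<std::vector<float>>* outs) {
  std::vector<int64_t> starts(xs.size() + 1, 0);
  for (size_t i = 1; i <= xs.size(); ++i) {
    starts[i] = starts[i - 1] + static_cast<int64_t>(xs[i - 1].size());
  }
  bool found_inf = false;
  for (int64_t idx = 0; idx < starts[xs.size()]; ++idx) {
    size_t t = 0;
    while (idx >= starts[t + 1]) ++t;  // skips tensors with zero elements
    int64_t in_idx = idx - starts[t];
    float val = xs[t][in_idx] * inverse_scale;
    (*outs)[t][in_idx] = val;
    if (!std::isfinite(val)) found_inf = true;
  }
  return found_inf;
}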