@@ -42,26 +42,28 @@ __global__ void FusedFillIf(T** outs, const size_t xs_size,
4242 const int tid = threadIdx .x + blockIdx .x * blockDim .x ;
4343
4444 // copy starts array from global memory to shared memory
45- extern __shared__ int64_t starts_s [];
45+ extern __shared__ int64_t s_starts [];
4646 for (int i = threadIdx .x ; i <= xs_size; i += blockDim .x ) {
47- starts_s [i] = starts[i];
47+ s_starts [i] = starts[i];
4848 }
4949 __syncthreads ();
5050
51- const int64_t total_num = starts_s [xs_size];
51+ const int64_t total_num = s_starts [xs_size];
5252 int out_index = 0 ;
5353
5454 for (int64_t id = tid; id < total_num; id += blockDim .x * gridDim .x ) {
5555 // get the "out" index of "id"
56+ // For example:
57+ // id = 15, starts = [0, 10, 10, 20, 30]
58+ // because 10 <= id < 20 ==>
59+ // the id element locate in the 3rd tensor (notice the 2nd tensor size is 0)
5660 int next_out_index = out_index;
57- while (id < starts_s[next_out_index]) next_out_index++;
58- // avoid some tensor's numel is zero
59- while (id >= starts_s[next_out_index]) next_out_index++;
61+ while (id >= s_starts[next_out_index]) next_out_index++;
6062 out_index = next_out_index - 1 ;
6163
6264 // get data pointer and index
6365 T* out_data = outs[out_index];
64- int64_t idx = id - starts_s [out_index];
66+ int64_t idx = id - s_starts [out_index];
6567
6668 // set value
6769 out_data[idx] = value;
@@ -93,48 +95,51 @@ class LazyZeros<platform::CUDADeviceContext, T> {
9395 const std::vector<const framework::Tensor*>& xs,
9496 const std::vector<framework::Tensor*>& outs) const {
9597 size_t xs_size = xs.size ();
98+ const auto & cpu_place = platform::CPUPlace ();
9699 // alloc each tensor's start index and copy to device
97- auto starts_h_tensor =
98- memory::Alloc (platform::CPUPlace () , (xs_size + 1 ) * sizeof (int64_t ));
99- int64_t * starts_h = reinterpret_cast <int64_t *>(starts_h_tensor ->ptr ());
100+ auto h_in_starts_mem =
101+ memory::Alloc (cpu_place , (xs_size + 1 ) * sizeof (int64_t ));
102+ int64_t * h_starts = reinterpret_cast <int64_t *>(h_in_starts_mem ->ptr ());
100103
101- auto starts_d_tensor =
104+ auto d_in_starts_mem =
102105 memory::Alloc (dev_ctx, (xs_size + 1 ) * sizeof (int64_t ));
103- int64_t * starts_d = reinterpret_cast <int64_t *>(starts_d_tensor ->ptr ());
106+ int64_t * d_starts = reinterpret_cast <int64_t *>(d_in_starts_mem ->ptr ());
104107
105- starts_h[0 ] = 0 ;
108+ // the start index value of each tensor is
109+ // the sum of previous tensor's size. For example:
110+ // outs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
111+ h_starts[0 ] = 0 ;
106112 for (int i = 0 ; i < xs_size; i++) {
107- // the start index value of each tensor is
108- // the sum of previous tensor's size
109- starts_h[i + 1 ] = starts_h[i] + outs[i]->numel ();
113+ h_starts[i + 1 ] = h_starts[i] + outs[i]->numel ();
110114 }
111115 memory::Copy (BOOST_GET_CONST (platform::CUDAPlace, dev_ctx.GetPlace ()),
112- starts_d, platform::CPUPlace (), starts_h ,
113- (xs_size + 1 ) * sizeof ( int64_t ), dev_ctx.stream ());
116+ d_starts, cpu_place, h_starts, (xs_size + 1 ) * sizeof ( int64_t ) ,
117+ dev_ctx.stream ());
114118
115119 // copy each tensor of "outs" data address array to device
116- auto outs_addr_h_tensor =
117- memory::Alloc (platform::CPUPlace (), xs_size * sizeof (T*));
118- T** outs_addr_h = reinterpret_cast <T**>(outs_addr_h_tensor->ptr ());
120+ auto h_out_addrs_tensor = memory::Alloc (cpu_place, xs_size * sizeof (T*));
121+ T** h_out_addrs = reinterpret_cast <T**>(h_out_addrs_tensor->ptr ());
119122
120- auto outs_addr_d_tensor = memory::Alloc (dev_ctx, xs_size * sizeof (T*));
121- T** outs_addr_d = reinterpret_cast <T**>(outs_addr_d_tensor ->ptr ());
123+ auto d_out_addrs_tensor = memory::Alloc (dev_ctx, xs_size * sizeof (T*));
124+ T** d_out_addrs = reinterpret_cast <T**>(d_out_addrs_tensor ->ptr ());
122125
123126 for (size_t i = 0 ; i < xs_size; ++i) {
124- outs_addr_h [i] = outs[i]->mutable_data <T>(dev_ctx.GetPlace ());
127+ h_out_addrs [i] = outs[i]->mutable_data <T>(dev_ctx.GetPlace ());
125128 }
126129 memory::Copy (BOOST_GET_CONST (platform::CUDAPlace, dev_ctx.GetPlace ()),
127- outs_addr_d, platform::CPUPlace (), outs_addr_h ,
128- xs_size * sizeof (T*), dev_ctx.stream ());
130+ d_out_addrs, cpu_place, h_out_addrs, xs_size * sizeof (T*) ,
131+ dev_ctx.stream ());
129132
130133 // launch cuda kernel
131- int64_t total_num = starts_h[xs_size];
132- int64_t block = std::min (static_cast <int64_t >(1024 ), total_num);
133- int64_t block_num = block * 50 ; // each thread deal with 50 data
134- int64_t grid = (total_num + block_num - 1 ) / block_num;
135- FusedFillIf<
136- T><<<grid, block, (xs_size + 1 ) * sizeof (int64_t ), dev_ctx.stream()>>> (
137- outs_addr_d, xs_size, starts_d, static_cast <T>(0 ), found_inf_data);
134+ int64_t total_num = h_starts[xs_size];
135+ int64_t threads_per_block = std::min (static_cast <int64_t >(1024 ), total_num);
136+ int64_t elements_per_block =
137+ threads_per_block * 50 ; // each thread deal with 50 data
138+ int64_t blocks_per_grid =
139+ (total_num + elements_per_block - 1 ) / elements_per_block;
140+ FusedFillIf<T><<<blocks_per_grid, threads_per_block,
141+ (xs_size + 1 ) * sizeof (int64_t ), dev_ctx.stream()>>> (
142+ d_out_addrs, xs_size, d_starts, static_cast <T>(0 ), found_inf_data);
138143 }
139144};
140145
0 commit comments