-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Optimize update_loss_scaling_op #32554
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
ad79dff
527779a
e7b8e48
1edd950
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,13 +34,37 @@ __global__ void GpuUpdateLossScaling( | |
| } | ||
|
|
||
// Fills every element of a batch of output tensors with `value`, but only
// when the device-side flag *has_inf is true (fused replacement for launching
// one FillIf kernel per tensor).
//
// Layout / preconditions:
//   - outs:   device array of xs_size device pointers (one per output tensor).
//   - starts: device array of xs_size + 1 exclusive prefix sums of the tensor
//             sizes; starts[0] == 0 and starts[xs_size] == total element count.
//             starts must be non-decreasing (zero-numel tensors are allowed).
//   - Launch with dynamic shared memory of (xs_size + 1) * sizeof(int64_t),
//     i.e. kernel<<<grid, block, (xs_size + 1) * sizeof(int64_t), stream>>>.
//   - Any 1-D grid/block works: elements are covered by a grid-stride loop.
template <typename T>
__global__ void FusedFillIf(T** outs, const size_t xs_size,
                            const int64_t* starts, const T value,
                            const bool* has_inf) {
  // *has_inf is uniform across all threads, so every thread takes the same
  // path here and the __syncthreads() below is never reached divergently.
  if (!(*has_inf)) return;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Stage the prefix-sum table in shared memory: it is read repeatedly in
  // the search loop below, once per element.
  extern __shared__ int64_t starts_s[];
  for (size_t i = threadIdx.x; i <= xs_size; i += blockDim.x) {
    starts_s[i] = starts[i];
  }
  __syncthreads();

  const int64_t total_num = starts_s[xs_size];
  int out_index = 0;

  // Grid-stride loop over the flattened, concatenated element space.
  for (int64_t id = tid; id < total_num; id += blockDim.x * gridDim.x) {
    // Find which output tensor flat index `id` belongs to. Since `id` only
    // grows across iterations, resume the search from the previous result.
    // Using `>=` (not `<`) also skips zero-numel tensors, whose start equals
    // the next tensor's start. The scan cannot run past starts_s[xs_size]
    // because starts_s[xs_size] == total_num > id.
    int next_out_index = out_index;
    while (id >= starts_s[next_out_index]) next_out_index++;
    out_index = next_out_index - 1;

    // Translate the flat index into (tensor pointer, local offset).
    T* out_data = outs[out_index];
    int64_t idx = id - starts_s[out_index];

    // Overwrite with the fill value (zero, for loss-scaling lazy zeroing).
    out_data[idx] = value;
  }
}
|
|
||
|
|
@@ -68,15 +92,49 @@ class LazyZeros<platform::CUDADeviceContext, T> { | |
| const bool* found_inf_data, | ||
| const std::vector<const framework::Tensor*>& xs, | ||
| const std::vector<framework::Tensor*>& outs) const { | ||
| for (size_t i = 0; i < xs.size(); ++i) { | ||
| auto* out = outs[i]; | ||
| T* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); | ||
| int64_t num = out->numel(); | ||
| int block = 1024; | ||
| int grid = (block - 1 + num) / block; | ||
| FillIf<<<grid, block, 0, dev_ctx.stream()>>>( | ||
| out_data, num, static_cast<T>(0), found_inf_data); | ||
| size_t xs_size = xs.size(); | ||
| // alloc each tensor's start index and copy to device | ||
| auto starts_h_tensor = | ||
| memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t)); | ||
|
||
| int64_t* starts_h = reinterpret_cast<int64_t*>(starts_h_tensor->ptr()); | ||
|
|
||
| auto starts_d_tensor = | ||
| memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t)); | ||
| int64_t* starts_d = reinterpret_cast<int64_t*>(starts_d_tensor->ptr()); | ||
|
||
|
|
||
| starts_h[0] = 0; | ||
| for (int i = 0; i < xs_size; i++) { | ||
| // the start index value of each tensor is | ||
| // the sum of previous tensor's size | ||
| starts_h[i + 1] = starts_h[i] + outs[i]->numel(); | ||
| } | ||
| memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), | ||
| starts_d, platform::CPUPlace(), starts_h, | ||
| (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()); | ||
|
|
||
| // copy each tensor of "outs" data address array to device | ||
| auto outs_addr_h_tensor = | ||
| memory::Alloc(platform::CPUPlace(), xs_size * sizeof(T*)); | ||
| T** outs_addr_h = reinterpret_cast<T**>(outs_addr_h_tensor->ptr()); | ||
|
|
||
| auto outs_addr_d_tensor = memory::Alloc(dev_ctx, xs_size * sizeof(T*)); | ||
| T** outs_addr_d = reinterpret_cast<T**>(outs_addr_d_tensor->ptr()); | ||
|
||
|
|
||
| for (size_t i = 0; i < xs_size; ++i) { | ||
| outs_addr_h[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace()); | ||
| } | ||
| memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), | ||
| outs_addr_d, platform::CPUPlace(), outs_addr_h, | ||
| xs_size * sizeof(T*), dev_ctx.stream()); | ||
|
|
||
| // launch cuda kernel | ||
| int64_t total_num = starts_h[xs_size]; | ||
| int64_t block = std::min(static_cast<int64_t>(1024), total_num); | ||
| int64_t block_num = block * 50; // each thread deal with 50 data | ||
| int64_t grid = (total_num + block_num - 1) / block_num; | ||
|
||
| FusedFillIf< | ||
| T><<<grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>( | ||
| outs_addr_d, xs_size, starts_d, static_cast<T>(0), found_inf_data); | ||
| } | ||
| }; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggested rename: `starts_s` --> `s_starts`.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done