Commit 33703da

[Cherry-pick] Optimize update_loss_scaling_op(#32554) (#32606)
* optimize update_loss_scaling_op by fusing the for loop into one kernel, test=develop
* remove useless while loop and optimize variable names, test=develop
* optimize variable name from out_addrs_tensor to out_addrs_mem, test=develop
* optimize variable names for readability by changing the prefix identifier from t_ to local_
1 parent 32203c3 commit 33703da

2 files changed: 113 additions and 43 deletions

paddle/fluid/operators/amp/check_finite_and_unscale_op.cu

Lines changed: 35 additions & 28 deletions
@@ -39,33 +39,36 @@ __global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale,
   __syncthreads();

   const int64_t num = s_starts[size];
-  int pre_xs_index = 0;
-  bool t_found_inf = false;
-  const MT t_scale = *scale;
+  int xs_index = 0;
+  bool local_found_inf = false;
+  const MT local_scale = *scale;
   for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
-    // get the xs's index of thread
-    int xs_index = pre_xs_index;
-    while (idx < s_starts[xs_index]) xs_index++;
-    // avoid some tensor's numel is zero
-    while (idx >= s_starts[xs_index]) xs_index++;
-    pre_xs_index = xs_index - 1;
+    // get the "xs" index of "idx". For example:
+    // idx = 15, starts = [0, 10, 10, 20, 30]
+    // because 10 <= idx < 20, the element lies in the 3rd tensor
+    // (note that the 2nd tensor's size is 0)
+    int next_xs_index = xs_index;
+    while (idx >= s_starts[next_xs_index]) next_xs_index++;
+    xs_index = next_xs_index - 1;

     // get in data and out data
-    const T* in = xs[pre_xs_index];
-    T* out = outs[pre_xs_index];
-    int64_t in_idx = idx - s_starts[pre_xs_index];
+    const T* in = xs[xs_index];
+    T* out = outs[xs_index];
+    int64_t in_idx = idx - s_starts[xs_index];

     // Unscale
-    MT val = static_cast<MT>(in[in_idx]) * t_scale;
+    MT val = static_cast<MT>(in[in_idx]) * local_scale;
     T narrow_val = static_cast<T>(val);
     out[in_idx] = narrow_val;

     // CheckFinite
     if (!isfinite(narrow_val)) {
-      t_found_inf = true;
+      local_found_inf = true;
     }
   }
-  if (t_found_inf) {
+  if (local_found_inf) {
     *found_inf = true;
   }
 }
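
A minimal host-side sketch (not part of the commit; the helper name find_tensor_index is hypothetical) of the flat-index-to-tensor lookup used above. Because consecutive entries of starts are equal for zero-sized tensors, empty inputs are stepped over automatically, which is easy to verify without a GPU:

// Host-only sketch of the kernel's index lookup over the starts array.
#include <cassert>
#include <cstdint>
#include <vector>

// starts has one entry per tensor plus a trailing total element count.
int find_tensor_index(const std::vector<int64_t>& starts, int64_t idx) {
  int next = 0;
  // advance until starts[next] > idx; the owning tensor is then next - 1,
  // and zero-sized tensors (equal consecutive starts) are skipped
  while (idx >= starts[next]) next++;
  return next - 1;
}

int main() {
  // xs numels = [10, 0, 10, 10]  ==>  starts = [0, 10, 10, 20, 30]
  const std::vector<int64_t> starts = {0, 10, 10, 20, 30};
  assert(find_tensor_index(starts, 15) == 2);  // 3rd tensor, 0-based index 2
  assert(find_tensor_index(starts, 9) == 0);
  assert(find_tensor_index(starts, 29) == 3);
  return 0;
}
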
@@ -94,28 +97,30 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
         scale_data, inverse_scale_v, found_inf_data);

     size_t xs_size = xs.size();
+    const auto& cpu_place = platform::CPUPlace();
     // calculate each tensor's start index and copy to device
     auto h_starts_tensor =
-        memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+        memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
     int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());

     auto d_starts_tensor =
         memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
     int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());

+    // the start index value of each tensor is
+    // the sum of previous tensor's size. For example:
+    // xs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
     h_starts[0] = 0;
     for (int i = 1; i <= xs_size; i++) {
-      // the start index value of each tensor is
-      // the sum of previous tensor's size
       h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
     }
     int64_t total_num = h_starts[xs_size];
     memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
-                 d_starts, platform::CPUPlace(), h_starts,
-                 (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+                 d_starts, cpu_place, h_starts, (xs_size + 1) * sizeof(int64_t),
+                 dev_ctx.stream());

     // copy each tensor's data address to device
-    auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*));
+    auto h_mem = memory::Alloc(cpu_place, 2 * xs_size * sizeof(T*));
     const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
     T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
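
The starts array built here is an inclusive prefix sum of the tensor sizes behind a leading zero, so its last entry is the total element count. A host-only sketch of the same construction (the helper name BuildStarts is hypothetical; std::partial_sum stands in for the hand-written loop):

// Host-only sketch: build the (xs_size + 1)-entry starts array.
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> BuildStarts(const std::vector<int64_t>& numels) {
  std::vector<int64_t> starts(numels.size() + 1, 0);
  // starts[i] = numels[0] + ... + numels[i - 1]; starts.back() is the total
  std::partial_sum(numels.begin(), numels.end(), starts.begin() + 1);
  return starts;
}

int main() {
  assert((BuildStarts({10, 0, 10, 10}) ==
          std::vector<int64_t>{0, 10, 10, 20, 30}));
  return 0;
}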

@@ -128,16 +133,18 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
       h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
     }
     memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs,
-                 platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*),
-                 dev_ctx.stream());
+                 cpu_place, h_xs, 2 * xs_size * sizeof(T*), dev_ctx.stream());

     // Launch Kernel
-    int block = 1024;
-    int block_num = block * 20;  // each thread deal with 20 number
-    int grid = (total_num + block_num - 1) / block_num;
+    int threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
+    int elements_per_block =
+        threads_per_block * 20;  // each thread deals with 20 numbers
+    int blocks_per_grid =
+        (total_num + elements_per_block - 1) / elements_per_block;
     VLOG(3) << "launch kernel";
-    CheckFiniteAndUnscale<T, MPDType><<<
-        grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+    CheckFiniteAndUnscale<
+        T, MPDType><<<blocks_per_grid, threads_per_block,
+                      (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
         d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
     VLOG(3) << "finish kernel";
   }
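
The new launch configuration clamps the block size to total_num, so tiny inputs no longer launch 1024 threads for a handful of elements, and each thread covers roughly 20 values through the grid-stride loop. A host-only worked example of the resulting geometry (the sample sizes are made up):

// Host-only sketch of the launch geometry above for two sample sizes.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  for (int64_t total_num : {100, 1 << 20}) {
    int threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
    int elements_per_block = threads_per_block * 20;
    int blocks_per_grid =
        (total_num + elements_per_block - 1) / elements_per_block;
    std::printf("total=%lld -> <<<%d, %d>>>\n",
                static_cast<long long>(total_num), blocks_per_grid,
                threads_per_block);
    // total=100     -> <<<1, 100>>>
    // total=1048576 -> <<<52, 1024>>>
  }
  return 0;
}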

paddle/fluid/operators/amp/update_loss_scaling_op.cu

Lines changed: 78 additions & 15 deletions
@@ -34,13 +34,39 @@ __global__ void GpuUpdateLossScaling(
 }

 template <typename T>
-__global__ void FillIf(T* data, const int64_t num, const T value,
-                       const bool* has_inf) {
-  if (*has_inf) {
-    int tid = threadIdx.x + blockIdx.x * blockDim.x;
-    for (int i = tid; i < num; i += blockDim.x * gridDim.x) {
-      data[i] = value;
-    }
+__global__ void FusedFillIf(T** outs, const size_t xs_size,
+                            const int64_t* starts, const T value,
+                            const bool* has_inf) {
+  if (!(*has_inf)) return;
+
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  // copy starts array from global memory to shared memory
+  extern __shared__ int64_t s_starts[];
+  for (int i = threadIdx.x; i <= xs_size; i += blockDim.x) {
+    s_starts[i] = starts[i];
+  }
+  __syncthreads();
+
+  const int64_t total_num = s_starts[xs_size];
+  int out_index = 0;
+
+  for (int64_t id = tid; id < total_num; id += blockDim.x * gridDim.x) {
+    // get the "outs" index of "id". For example:
+    // id = 15, starts = [0, 10, 10, 20, 30]
+    // because 10 <= id < 20, the element lies in the 3rd tensor
+    // (note that the 2nd tensor's size is 0)
+    int next_out_index = out_index;
+    while (id >= s_starts[next_out_index]) next_out_index++;
+    out_index = next_out_index - 1;
+
+    // get data pointer and index
+    T* out_data = outs[out_index];
+    int64_t idx = id - s_starts[out_index];
+
+    // set value
+    out_data[idx] = value;
   }
 }
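
FusedFillIf expects (xs_size + 1) * sizeof(int64_t) bytes of dynamic shared memory for the start offsets, and the early return on *has_inf is uniform across all threads, so no thread is left waiting at __syncthreads(). Below is a self-contained sketch, not Paddle code: the buffer sizes and names are invented and the has_inf guard is dropped for brevity, but it shows the same fused-fill pattern end to end, including the dynamic shared-memory launch:

// Standalone CUDA sketch of the fused fill pattern (assumptions noted above).
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
__global__ void FusedFill(T** outs, size_t n, const int64_t* starts, T value) {
  // stage the start offsets in shared memory, as the Paddle kernel does
  extern __shared__ int64_t s_starts[];
  for (int i = threadIdx.x; i <= n; i += blockDim.x) s_starts[i] = starts[i];
  __syncthreads();

  const int64_t total = s_starts[n];
  int out_index = 0;
  for (int64_t id = threadIdx.x + blockIdx.x * blockDim.x; id < total;
       id += blockDim.x * gridDim.x) {
    // advance to the tensor that owns this flat index
    int next = out_index;
    while (id >= s_starts[next]) next++;
    out_index = next - 1;
    outs[out_index][id - s_starts[out_index]] = value;
  }
}

int main() {
  // two output buffers of 3 and 5 floats ==> starts = [0, 3, 8]
  std::vector<int64_t> h_starts = {0, 3, 8};
  float *d_a, *d_b;
  cudaMalloc(&d_a, 3 * sizeof(float));
  cudaMalloc(&d_b, 5 * sizeof(float));
  std::vector<float*> h_outs = {d_a, d_b};

  float** d_outs;
  int64_t* d_starts;
  cudaMalloc(&d_outs, 2 * sizeof(float*));
  cudaMalloc(&d_starts, 3 * sizeof(int64_t));
  cudaMemcpy(d_outs, h_outs.data(), 2 * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(d_starts, h_starts.data(), 3 * sizeof(int64_t),
             cudaMemcpyHostToDevice);

  // third launch parameter: dynamic shared memory for the (n + 1) offsets
  FusedFill<float><<<1, 128, 3 * sizeof(int64_t)>>>(d_outs, 2, d_starts, 0.f);
  cudaDeviceSynchronize();

  float h_b[5];
  cudaMemcpy(h_b, d_b, sizeof(h_b), cudaMemcpyDeviceToHost);
  std::printf("b[0] = %f\n", h_b[0]);  // expected: 0.000000
  return 0;
}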

@@ -68,15 +94,52 @@ class LazyZeros<platform::CUDADeviceContext, T> {
                   const bool* found_inf_data,
                   const std::vector<const framework::Tensor*>& xs,
                   const std::vector<framework::Tensor*>& outs) const {
-    for (size_t i = 0; i < xs.size(); ++i) {
-      auto* out = outs[i];
-      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
-      int64_t num = out->numel();
-      int block = 1024;
-      int grid = (block - 1 + num) / block;
-      FillIf<<<grid, block, 0, dev_ctx.stream()>>>(
-          out_data, num, static_cast<T>(0), found_inf_data);
+    size_t xs_size = xs.size();
+    const auto& cpu_place = platform::CPUPlace();
+    // alloc each tensor's start index and copy to device
+    auto h_in_starts_mem =
+        memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
+    int64_t* h_starts = reinterpret_cast<int64_t*>(h_in_starts_mem->ptr());
+
+    auto d_in_starts_mem =
+        memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+    int64_t* d_starts = reinterpret_cast<int64_t*>(d_in_starts_mem->ptr());
+
+    // the start index value of each tensor is
+    // the sum of previous tensor's size. For example:
+    // outs = [10, 0, 10, 10] ==> starts = [0, 10, 10, 20, 30]
+    h_starts[0] = 0;
+    for (int i = 0; i < xs_size; i++) {
+      h_starts[i + 1] = h_starts[i] + outs[i]->numel();
     }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_starts, cpu_place, h_starts, (xs_size + 1) * sizeof(int64_t),
+                 dev_ctx.stream());
+
+    // copy each "outs" tensor's data address to device
+    auto h_out_addrs_mem = memory::Alloc(cpu_place, xs_size * sizeof(T*));
+    T** h_out_addrs = reinterpret_cast<T**>(h_out_addrs_mem->ptr());
+
+    auto d_out_addrs_mem = memory::Alloc(dev_ctx, xs_size * sizeof(T*));
+    T** d_out_addrs = reinterpret_cast<T**>(d_out_addrs_mem->ptr());
+
+    for (size_t i = 0; i < xs_size; ++i) {
+      h_out_addrs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
+    }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_out_addrs, cpu_place, h_out_addrs, xs_size * sizeof(T*),
+                 dev_ctx.stream());
+
+    // launch cuda kernel
+    int64_t total_num = h_starts[xs_size];
+    int64_t threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
+    int64_t elements_per_block =
+        threads_per_block * 50;  // each thread deals with 50 elements
+    int64_t blocks_per_grid =
+        (total_num + elements_per_block - 1) / elements_per_block;
+    FusedFillIf<T><<<blocks_per_grid, threads_per_block,
+                     (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+        d_out_addrs, xs_size, d_starts, static_cast<T>(0), found_inf_data);
   }
 };
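
With this change LazyZeros issues two small host-to-device copies and a single kernel launch instead of one FillIf launch per output tensor. A framework-free sketch of shipping the outs address array to the device (plain CUDA runtime calls instead of Paddle's memory::Alloc/memory::Copy; error handling is omitted, and it assumes the output buffers already exist, whereas mutable_data<T>() above also allocates them):

// Host-side sketch: copy an array of device pointers to the device.
#include <cuda_runtime.h>
#include <vector>

template <typename T>
T** CopyOutAddrsToDevice(const std::vector<T*>& outs) {
  T** d_out_addrs = nullptr;
  cudaMalloc(&d_out_addrs, outs.size() * sizeof(T*));
  // outs.data() is already a contiguous host array of device pointers,
  // so a single host-to-device copy is enough
  cudaMemcpy(d_out_addrs, outs.data(), outs.size() * sizeof(T*),
             cudaMemcpyHostToDevice);
  return d_out_addrs;
}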
