Commit b2eba11

optimize check_finite_and_unscale_op by fused kernel, test=develop
1 parent: d709fcd

File tree

1 file changed (+84, -21)

paddle/fluid/operators/amp/check_finite_and_unscale_op.cu

Lines changed: 84 additions & 21 deletions
@@ -26,18 +26,48 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
 }
 
 template <typename T, typename MT>
-__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num,
-                                      bool* found_inf, T* out) {
-  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
-
-  if (idx < num) {
-    MT val = static_cast<MT>(in[idx]) * (*scale);
+__global__ void CheckFiniteAndUnscale(const T** xs, const MT* scale,
+                                      int64_t size, int64_t* starts,
+                                      bool* found_inf, T** outs) {
+  const int64_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  // copy the starts array from global memory to shared memory
+  extern __shared__ int64_t s_starts[];
+  for (int i = threadIdx.x; i <= size; i += blockDim.x) {
+    s_starts[i] = starts[i];
+  }
+  __syncthreads();
+
+  const int64_t num = s_starts[size];
+  int pre_xs_index = 0;
+  bool t_found_inf = false;
+  const MT t_scale = *scale;
+  for (int64_t idx = tid; idx < num; idx += gridDim.x * blockDim.x) {
+    // find the index in xs of the tensor that owns this element
+    int xs_index = pre_xs_index;
+    while (idx < s_starts[xs_index]) xs_index++;
+    // skip over tensors whose numel is zero
+    while (idx >= s_starts[xs_index]) xs_index++;
+    pre_xs_index = xs_index - 1;
+
+    // get the in data and out data
+    const T* in = xs[pre_xs_index];
+    T* out = outs[pre_xs_index];
+    int64_t in_idx = idx - s_starts[pre_xs_index];
+
+    // Unscale
+    MT val = static_cast<MT>(in[in_idx]) * t_scale;
     T narrow_val = static_cast<T>(val);
-    out[idx] = narrow_val;
+    out[in_idx] = narrow_val;
+
+    // CheckFinite
     if (!isfinite(narrow_val)) {
-      *found_inf = true;
+      t_found_inf = true;
     }
   }
+  if (t_found_inf) {
+    *found_inf = true;
+  }
 }
 
 template <typename T>
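A made-up example of the index mapping in the kernel above (the tensor sizes here are hypothetical, not from the commit): suppose xs holds three tensors with numel 5, 0, and 7, so the kernel sees starts = [0, 5, 5, 12] and num = 12. For the flat index idx = 7, the second while loop advances xs_index while starts[xs_index] <= 7 and stops at starts[3] = 12, so pre_xs_index = 2: the element belongs to the third tensor at offset in_idx = 7 - starts[2] = 2. The empty middle tensor is skipped automatically because its start equals the next tensor's start.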
@@ -63,20 +93,53 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
     InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
         scale_data, inverse_scale_v, found_inf_data);
 
-    for (size_t i = 0; i < xs.size(); ++i) {
-      const auto* x = xs[i];
-      auto* out = outs[i];
-      const T* x_data = x->data<T>();
-      T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
-
-      int num = x->numel();
-      int block = 1024;
-      int grid = (num + block - 1) / block;
-      VLOG(3) << "launch kernel";
-      CheckFiniteAndUnscale<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
-          x_data, inverse_scale_v, num, found_inf_data, out_data);
-      VLOG(3) << "finish kernel";
+    size_t xs_size = xs.size();
+    // calculate each tensor's start index and copy it to the device
+    auto h_starts_tensor =
+        memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
+    int64_t* h_starts = reinterpret_cast<int64_t*>(h_starts_tensor->ptr());
+
+    auto d_starts_tensor =
+        memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
+    int64_t* d_starts = reinterpret_cast<int64_t*>(d_starts_tensor->ptr());
+
+    h_starts[0] = 0;
+    for (int i = 1; i <= xs_size; i++) {
+      // the start index of each tensor is
+      // the sum of the previous tensors' sizes
+      h_starts[i] = h_starts[i - 1] + xs[i - 1]->numel();
+    }
+    int64_t total_num = h_starts[xs_size];
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
+                 d_starts, platform::CPUPlace(), h_starts,
+                 (xs_size + 1) * sizeof(int64_t), dev_ctx.stream());
+
+    // copy each tensor's data address to the device
+    auto h_mem = memory::Alloc(platform::CPUPlace(), 2 * xs_size * sizeof(T*));
+    const T** h_xs = reinterpret_cast<const T**>(h_mem->ptr());
+    T** h_outs = reinterpret_cast<T**>(h_mem->ptr()) + xs_size;
+
+    auto d_mem = memory::Alloc(dev_ctx, 2 * xs_size * sizeof(T*));
+    const T** d_xs = reinterpret_cast<const T**>(d_mem->ptr());
+    T** d_outs = reinterpret_cast<T**>(d_mem->ptr()) + xs_size;
+
+    for (size_t i = 0; i < xs_size; ++i) {
+      h_xs[i] = xs[i]->data<T>();
+      h_outs[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
     }
+    memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()), d_xs,
+                 platform::CPUPlace(), h_xs, 2 * xs_size * sizeof(T*),
+                 dev_ctx.stream());
+
+    // launch the kernel
+    int block = 1024;
+    int block_num = block * 20;  // each thread deals with ~20 elements
+    int grid = (total_num + block_num - 1) / block_num;
+    VLOG(3) << "launch kernel";
+    CheckFiniteAndUnscale<T, MPDType><<<
+        grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
+        d_xs, inverse_scale_v, xs_size, d_starts, found_inf_data, d_outs);
+    VLOG(3) << "finish kernel";
   }
 };
} // namespace operators
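
The same fusion pattern can be exercised outside Paddle. Below is a minimal standalone CUDA sketch of the idea, assuming raw cudaMalloc/cudaMemcpy in place of Paddle's memory::Alloc/memory::Copy and leaving out the found_inf reduction; the kernel name FusedUnscale, the tensor sizes, and the scale value are all hypothetical.

// Standalone sketch (not PaddlePaddle code): fuse the unscale of several
// device arrays into one kernel launch via a prefix-sum "starts" array.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void FusedUnscale(const float** xs, float inv_scale, int64_t size,
                             const int64_t* starts, float** outs) {
  // stage the (size + 1)-entry starts array in shared memory
  extern __shared__ int64_t s_starts[];
  for (int i = threadIdx.x; i <= size; i += blockDim.x) s_starts[i] = starts[i];
  __syncthreads();

  const int64_t total = s_starts[size];
  int pre = 0;  // tensor index found for this thread's previous element
  for (int64_t idx = threadIdx.x + (int64_t)blockIdx.x * blockDim.x;
       idx < total; idx += (int64_t)gridDim.x * blockDim.x) {
    // advance until s_starts[t] > idx; tensor t - 1 then owns element idx,
    // and zero-sized tensors are stepped over automatically
    int t = pre;
    while (idx >= s_starts[t]) ++t;
    pre = t - 1;
    const int64_t in_idx = idx - s_starts[pre];
    outs[pre][in_idx] = xs[pre][in_idx] * inv_scale;
  }
}

int main() {
  const int kTensors = 3;
  const int64_t numels[kTensors] = {5, 0, 7};  // made-up sizes, one empty tensor

  // starts[i] is the sum of the sizes of tensors 0..i-1, so starts = {0, 5, 5, 12}
  int64_t h_starts[kTensors + 1] = {0};
  for (int i = 1; i <= kTensors; ++i) h_starts[i] = h_starts[i - 1] + numels[i - 1];
  const int64_t total = h_starts[kTensors];

  // one device buffer per tensor, recorded in host-side pointer tables
  float* h_xs[kTensors] = {};
  float* h_outs[kTensors] = {};
  for (int i = 0; i < kTensors; ++i) {
    if (numels[i] == 0) continue;  // nothing to allocate for an empty tensor
    cudaMalloc((void**)&h_xs[i], numels[i] * sizeof(float));
    cudaMalloc((void**)&h_outs[i], numels[i] * sizeof(float));
    cudaMemset(h_xs[i], 0, numels[i] * sizeof(float));
  }

  // copy the starts array and both pointer tables to the device
  int64_t* d_starts;
  const float** d_xs;
  float** d_outs;
  cudaMalloc((void**)&d_starts, (kTensors + 1) * sizeof(int64_t));
  cudaMalloc((void**)&d_xs, kTensors * sizeof(float*));
  cudaMalloc((void**)&d_outs, kTensors * sizeof(float*));
  cudaMemcpy(d_starts, h_starts, (kTensors + 1) * sizeof(int64_t),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_xs, h_xs, kTensors * sizeof(float*), cudaMemcpyHostToDevice);
  cudaMemcpy(d_outs, h_outs, kTensors * sizeof(float*), cudaMemcpyHostToDevice);

  // a single launch covers every tensor; dynamic shared memory holds starts
  const int block = 256;
  const int grid = (int)((total + block - 1) / block);
  FusedUnscale<<<grid, block, (kTensors + 1) * sizeof(int64_t)>>>(
      d_xs, 0.5f, kTensors, d_starts, d_outs);
  cudaDeviceSynchronize();
  printf("fused unscale of %lld elements across %d tensors\n",
         (long long)total, kTensors);
  return 0;
}

Note that the commit sizes its grid so each thread walks roughly 20 elements through the grid-stride loop (block_num = block * 20), whereas the sketch simply launches about one thread per element; the grid-stride loop keeps either choice correct when the grid does not cover every element.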
