Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 73 additions & 15 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,37 @@ __global__ void GpuUpdateLossScaling(
}

template <typename T>
__global__ void FusedFillIf(T** outs, const size_t xs_size,
                            const int64_t* starts, const T value,
                            const bool* has_inf) {
  // Fill every element of every tensor in "outs" with "value", but only when
  // *has_inf is true. "has_inf" is a device-side flag so no host/device sync
  // is needed to decide whether to fill.
  //
  // Launch contract:
  //   - "starts" holds xs_size + 1 entries: the exclusive prefix sum of the
  //     output tensors' numels, so starts[xs_size] is the total element count.
  //   - dynamic shared memory of (xs_size + 1) * sizeof(int64_t) bytes must be
  //     passed as the third launch-config argument.
  if (!(*has_inf)) return;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Stage the prefix-sum array in shared memory: it is re-read on every
  // iteration of the search loop below, so keep it off global memory.
  extern __shared__ int64_t s_starts[];
  for (int i = threadIdx.x; i <= xs_size; i += blockDim.x) {
    s_starts[i] = starts[i];
  }
  __syncthreads();

  const int64_t total_num = s_starts[xs_size];
  int out_index = 0;

  // Grid-stride loop over the flattened concatenation of all output tensors.
  for (int64_t id = tid; id < total_num; id += blockDim.x * gridDim.x) {
    // Locate the tensor that owns element "id". Because "id" only increases
    // across iterations, the scan can resume from the previous out_index.
    // The ">=" condition also naturally skips tensors whose numel is zero.
    // (A backward "while (id < s_starts[...])" scan is provably unreachable
    // here and was removed per review.)
    int next_out_index = out_index;
    while (id >= s_starts[next_out_index]) next_out_index++;
    out_index = next_out_index - 1;

    // Translate the global id into an offset local to the owning tensor.
    T* out_data = outs[out_index];
    int64_t idx = id - s_starts[out_index];

    out_data[idx] = value;
  }
}

Expand Down Expand Up @@ -68,15 +92,49 @@ class LazyZeros<platform::CUDADeviceContext, T> {
const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int64_t num = out->numel();
int block = 1024;
int grid = (block - 1 + num) / block;
FillIf<<<grid, block, 0, dev_ctx.stream()>>>(
out_data, num, static_cast<T>(0), found_inf_data);
size_t xs_size = xs.size();
// alloc each tensor's start index and copy to device
auto starts_h_tensor =
memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可构造一次platform::CPUPlace()对象,后续使用,不需要多次构建临时对象。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

int64_t* starts_h = reinterpret_cast<int64_t*>(starts_h_tensor->ptr());

auto starts_d_tensor =
memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
int64_t* starts_d = reinterpret_cast<int64_t*>(starts_d_tensor->ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some suggestions about variable names:
starts_h_tensor --> h_in_starts_mem
starts_h --> h_in_starts
starts_d_tensor --> d_in_starts_mem
starts_d --> d_in_starts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


starts_h[0] = 0;
for (int i = 0; i < xs_size; i++) {
// the start index value of each tensor is
// the sum of previous tensor's size
starts_h[i + 1] = starts_h[i] + outs[i]->numel();
}
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
starts_d, platform::CPUPlace(), starts_h,
(xs_size + 1) * sizeof(int64_t), dev_ctx.stream());

// copy each tensor of "outs" data address array to device
auto outs_addr_h_tensor =
memory::Alloc(platform::CPUPlace(), xs_size * sizeof(T*));
T** outs_addr_h = reinterpret_cast<T**>(outs_addr_h_tensor->ptr());

auto outs_addr_d_tensor = memory::Alloc(dev_ctx, xs_size * sizeof(T*));
T** outs_addr_d = reinterpret_cast<T**>(outs_addr_d_tensor->ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some suggestions about variable names:
outs_addr_h_tensor --> h_out_addrs_mem
outs_addr_h --> h_out_addrs
outs_addr_d_tensor --> d_out_addrs_mem
outs_addr_d --> d_out_addrs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


for (size_t i = 0; i < xs_size; ++i) {
outs_addr_h[i] = outs[i]->mutable_data<T>(dev_ctx.GetPlace());
}
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
outs_addr_d, platform::CPUPlace(), outs_addr_h,
xs_size * sizeof(T*), dev_ctx.stream());

// launch cuda kernel
int64_t total_num = starts_h[xs_size];
int64_t block = std::min(static_cast<int64_t>(1024), total_num);
int64_t block_num = block * 50; // each thread deal with 50 data
int64_t grid = (total_num + block_num - 1) / block_num;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

block --> threads_per_block
block_num --> elements_per_block
grid --> blocks_per_grid

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

FusedFillIf<
T><<<grid, block, (xs_size + 1) * sizeof(int64_t), dev_ctx.stream()>>>(
outs_addr_d, xs_size, starts_d, static_cast<T>(0), found_inf_data);
}
};

Expand Down