From bf318b8b5481cd994e0d1c0dc29482fc8e12da8b Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Mon, 23 Aug 2021 08:16:17 +0000 Subject: [PATCH 01/18] add a fusion op: fused_residual_dropout_bias --- paddle/fluid/operators/fused/CMakeLists.txt | 5 + .../fused/fused_residual_dropout_bias.h | 558 ++++++++++++++++++ .../fused/test_fused_residual_dropout_bias.cu | 441 ++++++++++++++ 3 files changed, 1004 insertions(+) create mode 100644 paddle/fluid/operators/fused/fused_residual_dropout_bias.h create mode 100644 paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 541e5afdf9b71e..525f6504f9fa61 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -71,4 +71,9 @@ if (WITH_GPU OR WITH_ROCM) op_library(fused_bn_add_activation_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n") endif() + # fused_dropout + # only support CUDA + if(NOT WITH_ROCM) + nv_test(test_fused_residual_dropout_bias SRCS test_fused_residual_dropout_bias.cu DEPS tensor op_registry elementwise_add_op dropout_op device_context generator) + endif() endif() diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h new file mode 100644 index 00000000000000..e0b51be9e909e3 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -0,0 +1,558 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include + +#include +#include + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" + +const int VecSize = 4; + +namespace paddle { +namespace operators { + +namespace platform = paddle::platform; + +inline std::pair GetResidualDropoutThreads( + const platform::CUDADeviceContext &ctx, const uint64_t n) { + const uint64_t tmp_n = n / VecSize; + int threads = std::max( + (uint64_t)32, std::min(tmp_n, (uint64_t)ctx.GetMaxThreadsPerBlock())); + int blocks = std::max((uint64_t)1, (tmp_n + threads - 1) / threads); + return std::pair{threads, blocks}; +} + +inline std::pair GetResidualDropoutBiasThreads( + const platform::CUDADeviceContext &ctx, const uint32_t rows, + const uint32_t cols) { + const uint32_t tmp_cols = cols / VecSize; + int threads = std::max( + (uint32_t)32, std::min(tmp_cols, (uint32_t)ctx.GetMaxThreadsPerBlock())); + int blocks_x = std::max((uint32_t)1, (tmp_cols + threads - 1) / threads); + int blocks_y = std::max((uint32_t)1, rows); + dim3 block_dim(threads, 1, 1); + dim3 grid_dim(blocks_x, blocks_y, 1); + return std::pair{block_dim, grid_dim}; +} + +/********Forward**************/ +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(T) * Size) AlignedVector { + T val[Size]; +}; + +template +inline int VectorizedSize(const T *pointer) { + uint64_t address = reinterpret_cast(pointer); + constexpr int vec4 = std::alignment_of>::value; // NOLINT + if (address % vec4 == 0) { + return 4; + } + return 1; +} + +/** + * dst = residual + dropout(src + bias); + */ +template +__global__ void FusedResidualDropoutBias(const size_t rows, const size_t cols, + uint64_t seed, + const float dropout_prob, + const bool is_upscale_in_train, + const T *src, const T *residual, + const T *bias, MaskType *mask_data, + T *dst, uint64_t increment) { + int col_id = blockDim.x * blockIdx.x + threadIdx.x; + int row_id = blockIdx.y; + int idx = row_id * cols + col_id; + curandStatePhilox4_32_10_t state; + curand_init(seed, idx, increment, &state); + + T factor = static_cast(1.0f / (1.0f - dropout_prob)); + const int tmp_cols = cols / VecSize * VecSize; + for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { + for (int i = col_id * VecSize; i < tmp_cols; + i += blockDim.x * gridDim.x * VecSize) { + float4 rand = curand_uniform4(&state); + float *rand_data = &(rand.x); + MaskType mask[VecSize]; + T bias_vec[VecSize]; +#pragma unroll + for (int j = 0; j < VecSize; j++) { + mask[j] = (MaskType)(rand_data[j] > dropout_prob); + bias_vec[j] = bias != nullptr ? bias[i + j] : static_cast(0); + } +#pragma unroll + for (int j = 0; j < VecSize; j++) { + mask_data[r * cols + i + j] = mask[j]; + } + + if (is_upscale_in_train) { +#pragma unroll + for (int j = 0; j < VecSize; j++) { + dst[r * cols + i + j] = (src[r * cols + i + j] + bias_vec[j]) * + static_cast(mask[j]) * factor + + residual[r * cols + i + j]; + } + } else { +#pragma unroll + for (int j = 0; j < VecSize; j++) { + dst[r * cols + i + j] = + (src[r * cols + i + j] + bias_vec[j]) * static_cast(mask[j]) + + residual[r * cols + i + j]; + } + } + } + + int high_index = tmp_cols + col_id; + if (high_index < cols) { + float4 rand = curand_uniform4(&state); + float *rand_data = &(rand.x); + int k = 0; + if (is_upscale_in_train) { + for (int i = high_index; i < cols; i++) { + MaskType m = (MaskType)(rand_data[k++] > dropout_prob); + mask_data[r * cols + i] = m; + dst[r * cols + i] = + (src[r * cols + i] + + (bias != nullptr ? 
bias[i] : static_cast(0.0))) * + static_cast(m) * factor + + residual[r * cols + i]; + } + } else { + for (int i = high_index; i < cols; i++) { + MaskType m = (MaskType)(rand_data[k++] > dropout_prob); + mask_data[r * cols + i] = m; + dst[r * cols + i] = + (src[r * cols + i] + + (bias != nullptr ? bias[i] : static_cast(0.0))) * + static_cast(m) + + residual[r * cols + i]; + } + } + } + } +} + +template +__global__ void FusedResidualDropoutBiasVec(const size_t rows, + const size_t cols, uint64_t seed, + const float dropout_prob, + const bool is_upscale_in_train, + const T *src, const T *residual, + const T *bias, MaskType *mask_data, + T *dst, uint64_t increment) { + int col_id = blockDim.x * blockIdx.x + threadIdx.x; + int row_id = blockIdx.y; + int idx = row_id * cols + col_id; + curandStatePhilox4_32_10_t state; + curand_init(seed, idx, increment, &state); + + T dest; + MaskType mask; + T factor = static_cast(1.0f / (1.0f - dropout_prob)); + using LoadT = AlignedVector; + using MaskLoadT = AlignedVector; + for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { + for (int i = col_id * VecSize; i < cols; + i += blockDim.x * gridDim.x * VecSize) { + T src_vec[VecSize]; + T residual_vec[VecSize]; + T bias_vec[VecSize]; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + bias_vec[ii] = static_cast(0); + } + LoadT *value = reinterpret_cast(&src_vec); + LoadT *residual_value = reinterpret_cast(&residual_vec); + *value = *reinterpret_cast(&src[r * cols + i]); + *residual_value = + *reinterpret_cast(&residual[r * cols + i]); + + LoadT *bias_value = + bias != nullptr ? reinterpret_cast(&bias_vec) : nullptr; + if (bias != nullptr) + *bias_value = *reinterpret_cast(&bias[i]); + + float4 rand = curand_uniform4(&state); + T dest_vec[VecSize]; + MaskType mask_vec[VecSize]; + +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob); + } + + if (is_upscale_in_train) { +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * + static_cast(mask_vec[ii]) * factor + + residual_vec[ii]; + } + } else { +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + dest_vec[ii] = + (src_vec[ii] + bias_vec[ii]) * static_cast(mask_vec[ii]) + + residual_vec[ii]; + } + } + *(reinterpret_cast(&dst[r * cols + i])) = + *reinterpret_cast(&dest_vec[0]); + *(reinterpret_cast(&mask_data[r * cols + i])) = + *reinterpret_cast(&mask_vec[0]); + } + } +} + +template +__global__ void FusedResidualDropoutBiasTest(const size_t rows, + const size_t cols, + const float dropout_prob, + const bool is_upscale_in_train, + const T *src, const T *residual, + const T *bias, T *dst) { + int col_id = blockDim.x * blockIdx.x + threadIdx.x; + int row_id = blockIdx.y; + int idx = row_id * cols + col_id; + + T factor = static_cast(1.0f - dropout_prob); + const int tmp_cols = cols / VecSize * VecSize; + for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { + for (int i = col_id * VecSize; i < tmp_cols; + i += blockDim.x * gridDim.x * VecSize) { + if (is_upscale_in_train) { +#pragma unroll + for (int j = 0; j < VecSize; j++) { + dst[r * cols + i + j] = + (src[r * cols + i + j] + + (bias != nullptr ? bias[i + j] : static_cast(0.0))) + + residual[r * cols + i + j]; + } + } else { +#pragma unroll + for (int j = 0; j < VecSize; j++) { + dst[r * cols + i + j] = + (src[r * cols + i + j] + + (bias != nullptr ? 
bias[i + j] : static_cast(0.0))) * + factor + + residual[r * cols + i + j]; + } + } + } + + int high_index = tmp_cols + col_id; + if (high_index < cols) { + if (is_upscale_in_train) { + for (int i = high_index; i < cols; i++) { + dst[r * cols + i] = + (src[r * cols + i] + + (bias != nullptr ? bias[i] : static_cast(0.0))) + + residual[r * cols + i]; + } + } else { + for (int i = high_index; i < cols; i++) { + dst[r * cols + i] = + (src[r * cols + i] + + (bias != nullptr ? bias[i] : static_cast(0.0))) * + factor + + residual[r * cols + i]; + } + } + } + } +} + +/** + * dst = residual + dropout(src + bias); + */ +template +void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, + const int increment, uint64_t seed, + const float dropout_prob, + bool is_upscale_in_train, const T *src, + const T *residual, const T *bias, + MaskType *mask_data, T *dst, + const platform::CUDADeviceContext &ctx) { + if (std::abs(dropout_prob - 1.0) < 1e-5) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemcpyAsync(dst, residual, rows * cols * sizeof(T), + cudaMemcpyDeviceToDevice, ctx.stream())); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( + mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + return; + } + + auto threads = GetResidualDropoutBiasThreads(ctx, rows, cols); + if (cols % VecSize != 0) + FusedResidualDropoutBias< + T, uint8_t, + VecSize><<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, + bias, mask_data, dst, increment); + else + FusedResidualDropoutBiasVec< + T, uint8_t, + VecSize><<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, + bias, mask_data, dst, increment); +} + +template +void LaunchResidualDropoutBiasTest(const uint32_t rows, const uint32_t cols, + const float dropout_prob, + bool is_upscale_in_train, const T *src, + const T *residual, const T *bias, T *dst, + const platform::CUDADeviceContext &ctx) { + if (std::abs(dropout_prob - 1.0) < 1e-5) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemcpyAsync(dst, residual, rows * cols * sizeof(T), + cudaMemcpyDeviceToDevice, ctx.stream())); + return; + } + auto threads = GetResidualDropoutBiasThreads(ctx, rows, cols); + FusedResidualDropoutBiasTest< + T, VecSize><<>>( + rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, dst); +} + +/********Backward**************/ +template +__global__ void FusedResidualDropoutGrad(const T *dout, const MaskType *mask, + const T factor, const int64_t size, + T *dx, bool is_upscale_in_train) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + + int tmp_size = size / VecSize * VecSize; + for (int i = idx * VecSize; i < tmp_size; + i += blockDim.x * gridDim.x * VecSize) { +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + dx[i + ii] = dout[i + ii] * static_cast(mask[i + ii]) * factor; + } + } + + int high_index = tmp_size + idx; + if (size > high_index) { + for (int i = high_index; i < size; i++) { + if (is_upscale_in_train) + dx[i] = dout[i] * static_cast(mask[i]) * factor; + else + dx[i] = dout[i] * static_cast(mask[i]); + } + } +} + +template +__global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, + const T factor, const int64_t size, + T *dx, bool is_upscale_in_train) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + + using LoadT = AlignedVector; + using MaskLoadT = AlignedVector; + for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { + T dout_vec[VecSize]; + MaskType mask_vec[VecSize]; + LoadT *dout_value = reinterpret_cast(&dout_vec); + MaskLoadT 
*mask_value = reinterpret_cast(&mask_vec); + *dout_value = *reinterpret_cast(&dout[i]); + *mask_value = *reinterpret_cast(&mask[i]); + + T dx_vec[VecSize]; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + dx_vec[ii] = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; + } + *(reinterpret_cast(&dx[i])) = + *reinterpret_cast(&dx_vec[0]); + } +} + +template +__global__ void FusedResidualDropoutBiasGrad( + const T *dout, const MaskType *mask, const T factor, const int64_t rows, + const int64_t cols, T *dx, T *dbias, bool is_upscale_in_train) { + int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ T cache[BSX][BSY]; + T tmp_sum = static_cast(0); + if (col_id < cols) { + for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { + int index = row_id * cols + col_id; + T out_value = dout[index]; + if (is_upscale_in_train) + dx[index] = out_value * static_cast(mask[index]) * factor; + else + dx[index] = out_value * static_cast(mask[index]); + tmp_sum += out_value; + } + } + cache[threadIdx.x][threadIdx.y] = tmp_sum; + __syncthreads(); + + // reduce sum + // TODO(zhangkaihuo) : Replace with ModuleAPI + T sum = static_cast(0); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid / BSY; + int y = tid & (BSY - 1); + + int s = BSY / 2; + while (s > 0) { + if (y < s) { + cache[x][y] += cache[x][y + s]; + } + s /= 2; + __syncthreads(); + } + + if (threadIdx.y == 0 && col_id < cols) { + dbias[col_id] = cache[threadIdx.x][0]; + } +} + +template +__global__ void FusedResidualDropoutBiasGradVec( + const T *dout, const MaskType *mask, const T factor, const int64_t rows, + const int64_t cols, T *dx, T *dbias, bool is_upscale_in_train) { + int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; + + using LoadT = AlignedVector; + using MaskLoadT = AlignedVector; + + T tmp_sum[VecSize] = {static_cast(0)}; + if (col_id * 4 < cols) { + for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { + int index = row_id * cols + col_id * 4; + T out_vec[VecSize]; + MaskType mask_vec[VecSize]; + T dx_vec[VecSize]; + LoadT *out_value = reinterpret_cast(&out_vec); + MaskLoadT *mask_value = reinterpret_cast(&mask_vec); + LoadT *dx_value = reinterpret_cast(&dx_vec); + *out_value = *reinterpret_cast(&dout[index]); + *mask_value = *reinterpret_cast(&mask[index]); + + if (is_upscale_in_train) { +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dx_vec[i] = out_vec[i] * static_cast(mask_vec[i]) * factor; + tmp_sum[i] += out_vec[i]; + } + } else { +#pragma unroll + for (int i = 0; i < VecSize; i++) { + dx_vec[i] = out_vec[i] * static_cast(mask_vec[i]); + tmp_sum[i] += out_vec[i]; + } + } + + *(reinterpret_cast(&dx[index])) = + *reinterpret_cast(&dx_vec[0]); + } + } + + __shared__ T cache[BSX * VecSize][BSY]; + for (int i = 0; i < VecSize; i++) + cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; + __syncthreads(); + + // reduce sum + // TODO(zhangkaihuo) : Replace with ModuleAPI + T sum = static_cast(0); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid / BSY; + int y = tid & (BSY - 1); + + int s = BSY / 2; + while (s > 0) { + if (y < s) { + for (int i = 0; i < VecSize; i++) { + cache[x * VecSize + i][y] += cache[x * VecSize + i][y + s]; + } + } + s /= 2; + __syncthreads(); + } + + if (threadIdx.y == 0 && col_id * VecSize < cols) { + for (int i = 0; i < VecSize; i++) + dbias[col_id * VecSize + i] = cache[threadIdx.x * VecSize + i][0]; + } +} + +template +void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, + const float 
dropout_prob, + const bool is_upscale_in_train, + const uint32_t rows, const uint32_t cols, + T *dx, T *dbias, + const platform::CUDADeviceContext &ctx) { + const T zero = static_cast(0.0); + auto factor = dropout_prob == static_cast(1.0) + ? zero + : static_cast(1.0 / (1.0 - dropout_prob)); + + if (dbias != nullptr) { + if (cols % 4 == 0) { + auto threads = std::min(cols / VecSize, static_cast(8)); + auto blocks = std::max((uint32_t)1, + std::min((cols / VecSize + threads - 1) / threads, + (uint32_t)ctx.GetSMCount())); + dim3 block_dim(threads, 128, 1); + dim3 grid_dim(blocks, 1, 1); + FusedResidualDropoutBiasGradVec< + T, MaskType, 8, 128, + VecSize><<>>( + dout, mask, factor, rows, cols, dx, dbias, is_upscale_in_train); + + } else { + auto threads = std::min(cols, static_cast(8)); + auto blocks = std::max( + (uint32_t)1, + std::min((cols + threads - 1) / threads, (uint32_t)ctx.GetSMCount())); + dim3 block_dim(threads, 128, 1); + dim3 grid_dim(blocks, 1, 1); + FusedResidualDropoutBiasGrad< + T, MaskType, 8, 128><<>>( + dout, mask, factor, rows, cols, dx, dbias, is_upscale_in_train); + } + } else { + const uint64_t n = rows * cols; + auto threads = GetResidualDropoutThreads(ctx, n); + if (n % 4 == 0) { + FusedResidualDropoutGradVec< + T, MaskType, + VecSize><<>>( + dout, mask, factor, n, dx, is_upscale_in_train); + } else { + FusedResidualDropoutGrad< + T, MaskType><<>>( + dout, mask, factor, n, dx, is_upscale_in_train); + } + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu b/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu new file mode 100644 index 00000000000000..c8a1485ab7e0fa --- /dev/null +++ b/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu @@ -0,0 +1,441 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; + +USE_OP(elementwise_add); +USE_OP(dropout); + +template +void Dropout(const std::vector &x, const framework::DDim &x_dim, + std::vector *out, std::vector *mask, + const platform::CUDADeviceContext &ctx, uint64_t seed, + float dropout_prob, bool is_upscale_in_train, bool is_test) { + framework::Scope scope; + auto var_x = scope.Var("X"); + auto tensor_x = var_x->GetMutable(); + tensor_x->Resize(x_dim); + tensor_x->mutable_data(ctx.GetPlace()); + cudaMemcpy(tensor_x->data(), x.data(), x_dim[0] * x_dim[1] * sizeof(T), + cudaMemcpyHostToDevice); + + auto var_out = scope.Var("Out"); + auto tensor_out = var_out->GetMutable(); + + auto var_mask = scope.Var("Mask"); + auto tensor_mask = var_mask->GetMutable(); + + framework::AttributeMap attrs; + attrs.insert({"fix_seed", 1}); + attrs.insert({"seed", static_cast(seed)}); + attrs.insert({"dropout_prob", dropout_prob}); + if (is_upscale_in_train) { + attrs.insert({"dropout_implementation", std::string("upscale_in_train")}); + } + if (is_test) { + attrs.insert({"is_test", 1}); + } + + auto op = framework::OpRegistry::CreateOp( + "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + cudaMemcpy((*out).data(), tensor_out->data(), + x_dim[0] * x_dim[1] * sizeof(T), cudaMemcpyDeviceToHost); + if (!is_test) { + cudaMemcpy((*mask).data(), tensor_mask->data(), + x_dim[0] * x_dim[1] * sizeof(uint8_t), cudaMemcpyDeviceToHost); + } + ctx.Wait(); +} + +template +void DropoutGrad(std::vector *dx, const framework::DDim &x_dim, + const std::vector &dout, const std::vector &mask, + const platform::CUDADeviceContext &ctx, float dropout_prob, + bool is_upscale_in_train) { + framework::Scope scope; + const size_t n = x_dim[0] * x_dim[1]; + auto var_out = scope.Var("DOut"); + auto tensor_out = var_out->GetMutable(); + tensor_out->Resize(x_dim); + tensor_out->mutable_data(ctx.GetPlace()); + cudaMemcpy(tensor_out->data(), dout.data(), n * sizeof(T), + cudaMemcpyHostToDevice); + + auto var_mask = scope.Var("Mask"); + auto tensor_mask = var_mask->GetMutable(); + tensor_mask->Resize(x_dim); + tensor_mask->mutable_data(ctx.GetPlace()); + cudaMemcpy(tensor_mask->data(), mask.data(), n * sizeof(uint8_t), + cudaMemcpyHostToDevice); + + auto var_dx = scope.Var("DX"); + auto tensor_dx = var_dx->GetMutable(); + + framework::AttributeMap attrs; + attrs.insert({"dropout_prob", dropout_prob}); + attrs.insert({"is_test", 0}); + if (is_upscale_in_train) { + attrs.insert({"dropout_implementation", std::string("upscale_in_train")}); + } else { + attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")}); + } + + auto op = framework::OpRegistry::CreateOp( + "dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}}, + {{"X@GRAD", {"DX"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + + cudaMemcpy((*dx).data(), tensor_dx->data(), + x_dim[0] * x_dim[1] * sizeof(T), cudaMemcpyDeviceToHost); + ctx.Wait(); +} + +template +struct TestFusedResidualDropoutBias { + uint32_t _rows; + uint32_t _cols; + uint64_t _seed; + 
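  // probability of dropping an element; 0.0 effectively disables dropout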
float _dropout_prob; + bool _is_upscale_in_train; + bool _is_test; // default false, Set to true for inference only + bool _has_bias = true; + framework::Tensor _src, _residual, _bias, _out, _mask; + framework::Tensor _dsrc, _dbias; + + std::vector _src_vec, _residual_vec, _bias_vec, _out_vec, _mask_vec; + std::vector _correct_out, _correct_dsrc, _correct_dbias; + std::vector _correct_mask; + + platform::CUDAPlace _place; + platform::CUDADeviceContext *_ctx; + + TestFusedResidualDropoutBias() { + _rows = 32; + _cols = 32; + _seed = 0; + _dropout_prob = 0.0; + _is_upscale_in_train = false; + _is_test = false; + _has_bias = true; + _ctx = new platform::CUDADeviceContext(_place); + } + + TestFusedResidualDropoutBias(int rows, int cols, uint64_t seed = 0, + float dropout_prob = 0.0, + bool is_upscale_in_train = false, + bool is_test = false) { + _rows = rows; + _cols = cols; + _seed = seed; + _dropout_prob = dropout_prob; + _is_upscale_in_train = is_upscale_in_train; + _is_test = is_test; + _has_bias = true; + _ctx = new platform::CUDADeviceContext(_place); + } + + ~TestFusedResidualDropoutBias() { delete _ctx; } + + void SetUp() { + const int n = _rows * _cols; + _correct_out.resize(n); + _correct_mask.resize(n); + _correct_dsrc.resize(n); + _correct_dbias.resize(_cols); + + _src_vec.resize(n); + _residual_vec.resize(n); + _bias_vec.resize(_cols); + std::default_random_engine random(time(NULL)); + std::uniform_real_distribution dis(0.0, 1.0); + + for (int i = 0; i < _rows; i++) { + for (int j = 0; j < _cols; j++) { + _src_vec[i * _cols + j] = static_cast(dis(random)); + _residual_vec[i * _cols + j] = static_cast(dis(random)); + if (i == 0) _bias_vec[j] = dis(random); + } + } + + framework::TensorFromVector(_src_vec, *_ctx, &_src); + _src.Resize({_rows, _cols}); + framework::TensorFromVector(_residual_vec, *_ctx, &_residual); + _residual.Resize({_rows, _cols}); + if (_has_bias) { + framework::TensorFromVector(_bias_vec, *_ctx, &_bias); + _bias.Resize({_cols}); + } + + { + _out.Resize({_rows, _cols}); + _out.mutable_data(_place); + _mask.Resize({_rows, _cols}); + _mask.mutable_data(_place); + _dsrc.Resize({_rows, _cols}); + _dsrc.mutable_data(_place); + + if (_has_bias) { + _dbias.Resize({_cols}); + _dbias.mutable_data(_place); + } + } + } + + void BaseForward() { + std::vector out1(_rows * _cols), out2(_rows * _cols); + if (_has_bias) { + for (int i = 0; i < _rows; i++) { + for (int j = 0; j < _cols; j++) { + out1[i * _cols + j] = _src_vec[i * _cols + j] + _bias_vec[j]; + } + } + Dropout(out1, _src.dims(), &out2, &_correct_mask, *_ctx, _seed, + _dropout_prob, _is_upscale_in_train, _is_test); + } else { + Dropout(_src_vec, _src.dims(), &out2, &_correct_mask, *_ctx, _seed, + _dropout_prob, _is_upscale_in_train, _is_test); + } + for (int i = 0; i < _rows; i++) { + for (int j = 0; j < _cols; j++) { + _correct_out[i * _cols + j] = + _residual_vec[i * _cols + j] + out2[i * _cols + j]; + } + } + _ctx->Wait(); + } + + void BaseBackward() { + if (!_is_upscale_in_train) { + for (int i = 0; i < _rows * _cols; i++) { + _correct_dsrc[i] = _correct_out[i] * static_cast(_correct_mask[i]); + } + } else { + DropoutGrad(&_correct_dsrc, _src.dims(), _correct_out, _correct_mask, + *_ctx, _dropout_prob, _is_upscale_in_train); + } + memset(&_correct_dbias[0], 0, _cols * sizeof(T)); + for (int i = 0; i < _rows; i++) { + for (int j = 0; j < _cols; j++) { + _correct_dbias[j] += _correct_out[i * _cols + j]; + } + } + } + + void FusedForward() { + auto threads = paddle::operators::GetResidualDropoutBiasThreads( + *_ctx, 
(uint64_t)_rows, (uint64_t)_cols); + const int increment = + ((_cols - 1) / (threads.first.x * threads.second.x * VecSize) + 1) * + VecSize; + + T *bias_ptr = nullptr; + if (_has_bias) { + bias_ptr = _bias.data(); + } + if (_is_test) { + paddle::operators::LaunchResidualDropoutBiasTest( + _rows, _cols, _dropout_prob, _is_upscale_in_train, _src.data(), + _residual.data(), bias_ptr, _out.data(), *_ctx); + } else { + paddle::operators::LaunchResidualDropoutBias( + _rows, _cols, increment, _seed, _dropout_prob, _is_upscale_in_train, + _src.data(), _residual.data(), bias_ptr, _mask.data(), + _out.data(), *_ctx); + } + _ctx->Wait(); + } + + void FusedBackward() { + if (_is_test) return; + + T *bias_ptr = nullptr; + if (_has_bias) { + bias_ptr = _dbias.data(); + } + paddle::operators::LaunchResidualDropoutBiasGrad( + _out.data(), _mask.data(), _dropout_prob, + _is_upscale_in_train, _rows, _cols, _dsrc.data(), bias_ptr, *_ctx); + } + + void Run() { + SetUp(); + BaseForward(); + FusedForward(); + BaseBackward(); + FusedBackward(); + } + + void CheckOut(const T diff) { + const int n = _rows * _cols; + std::vector out(n); + std::vector mask(n); + cudaMemcpy(out.data(), _out.data(), _rows * _cols * sizeof(T), + cudaMemcpyDeviceToHost); + if (!_is_test) { + cudaMemcpy(mask.data(), _mask.data(), + _rows * _cols * sizeof(uint8_t), cudaMemcpyDeviceToHost); + } + _ctx->Wait(); + + for (int i = 0; i < n; i++) { + EXPECT_LT(std::abs(out[i] - _correct_out[i]), diff); + if (!_is_test) EXPECT_EQ(mask[i], _correct_mask[i]); + } + } + + void CheckGrad(const T diff) { + if (_is_test) return; + + const int n = _rows * _cols; + + std::vector dsrc(n); + cudaMemcpy(dsrc.data(), _dsrc.data(), _rows * _cols * sizeof(T), + cudaMemcpyDeviceToHost); + + for (int i = 0; i < n; i++) { + EXPECT_LT(std::abs(dsrc[i] - _correct_dsrc[i]), diff); + } + + if (_has_bias) { + std::vector dbias(_cols); + cudaMemcpy(dbias.data(), _dbias.data(), _cols * sizeof(T), + cudaMemcpyDeviceToHost); + _ctx->Wait(); + for (int i = 0; i < _cols; i++) { + EXPECT_LT(std::abs(dbias[i] - _correct_dbias[i]), diff); + } + } + } +}; + +TEST(FusedDropout, GPUFusedRedisualDorpoutBias) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasDouble) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasFp16) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols); + test.Run(); + test.CheckOut(static_cast(1e-2)); + // For inference, check_grad is not required. 
ref: test_dropout_op.py + // test.CheckGrad((platform::float16)1e-2); +} + +// test no bias and cols % 4 == 0 +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasNoBias) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols); + test._has_bias = false; + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +// test no bias and cols % 4 != 0 +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasNoBias2) { + const int rows = 16; + const int cols = 17; + TestFusedResidualDropoutBias test(rows, cols); + test._has_bias = false; + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +// test add bias and cols % 4 != 0 +TEST(FusedDropout, GPUFusedRedisualDorpoutBias2) { + const int rows = 16; + const int cols = 17; + TestFusedResidualDropoutBias test(rows, cols); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBias3) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, false, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBias4) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, false, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBias5) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, true, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBias6) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols, 0, 0.35, true, true); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBias7) { + const int rows = 16; + const int cols = 16; + TestFusedResidualDropoutBias test(rows, cols, 125, 0.0, false, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} From 507117a86e68dd4dc572ff167f286b39d7c84416 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Mon, 23 Aug 2021 13:22:26 +0000 Subject: [PATCH 02/18] simplify the code, andd opt reduce sum --- .../fused/fused_residual_dropout_bias.h | 320 +++++------------- .../fused/test_fused_residual_dropout_bias.cu | 52 ++- 2 files changed, 101 insertions(+), 271 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index e0b51be9e909e3..2d0de22952c88a 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -21,6 +21,7 @@ limitations under the License. 
*/ #include #include +#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" @@ -30,6 +31,7 @@ namespace paddle { namespace operators { namespace platform = paddle::platform; +namespace cg = cooperative_groups; inline std::pair GetResidualDropoutThreads( const platform::CUDADeviceContext &ctx, const uint64_t n) { @@ -73,86 +75,6 @@ inline int VectorizedSize(const T *pointer) { /** * dst = residual + dropout(src + bias); */ -template -__global__ void FusedResidualDropoutBias(const size_t rows, const size_t cols, - uint64_t seed, - const float dropout_prob, - const bool is_upscale_in_train, - const T *src, const T *residual, - const T *bias, MaskType *mask_data, - T *dst, uint64_t increment) { - int col_id = blockDim.x * blockIdx.x + threadIdx.x; - int row_id = blockIdx.y; - int idx = row_id * cols + col_id; - curandStatePhilox4_32_10_t state; - curand_init(seed, idx, increment, &state); - - T factor = static_cast(1.0f / (1.0f - dropout_prob)); - const int tmp_cols = cols / VecSize * VecSize; - for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { - for (int i = col_id * VecSize; i < tmp_cols; - i += blockDim.x * gridDim.x * VecSize) { - float4 rand = curand_uniform4(&state); - float *rand_data = &(rand.x); - MaskType mask[VecSize]; - T bias_vec[VecSize]; -#pragma unroll - for (int j = 0; j < VecSize; j++) { - mask[j] = (MaskType)(rand_data[j] > dropout_prob); - bias_vec[j] = bias != nullptr ? bias[i + j] : static_cast(0); - } -#pragma unroll - for (int j = 0; j < VecSize; j++) { - mask_data[r * cols + i + j] = mask[j]; - } - - if (is_upscale_in_train) { -#pragma unroll - for (int j = 0; j < VecSize; j++) { - dst[r * cols + i + j] = (src[r * cols + i + j] + bias_vec[j]) * - static_cast(mask[j]) * factor + - residual[r * cols + i + j]; - } - } else { -#pragma unroll - for (int j = 0; j < VecSize; j++) { - dst[r * cols + i + j] = - (src[r * cols + i + j] + bias_vec[j]) * static_cast(mask[j]) + - residual[r * cols + i + j]; - } - } - } - - int high_index = tmp_cols + col_id; - if (high_index < cols) { - float4 rand = curand_uniform4(&state); - float *rand_data = &(rand.x); - int k = 0; - if (is_upscale_in_train) { - for (int i = high_index; i < cols; i++) { - MaskType m = (MaskType)(rand_data[k++] > dropout_prob); - mask_data[r * cols + i] = m; - dst[r * cols + i] = - (src[r * cols + i] + - (bias != nullptr ? bias[i] : static_cast(0.0))) * - static_cast(m) * factor + - residual[r * cols + i]; - } - } else { - for (int i = high_index; i < cols; i++) { - MaskType m = (MaskType)(rand_data[k++] > dropout_prob); - mask_data[r * cols + i] = m; - dst[r * cols + i] = - (src[r * cols + i] + - (bias != nullptr ? 
bias[i] : static_cast(0.0))) * - static_cast(m) + - residual[r * cols + i]; - } - } - } - } -} - template __global__ void FusedResidualDropoutBiasVec(const size_t rows, const size_t cols, uint64_t seed, @@ -170,6 +92,9 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, T dest; MaskType mask; T factor = static_cast(1.0f / (1.0f - dropout_prob)); + if (!is_upscale_in_train) { + factor = static_cast(1.0); + } using LoadT = AlignedVector; using MaskLoadT = AlignedVector; for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { @@ -202,20 +127,11 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob); } - if (is_upscale_in_train) { #pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * - static_cast(mask_vec[ii]) * factor + - residual_vec[ii]; - } - } else { -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - dest_vec[ii] = - (src_vec[ii] + bias_vec[ii]) * static_cast(mask_vec[ii]) + - residual_vec[ii]; - } + for (int ii = 0; ii < VecSize; ii++) { + dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * + static_cast(mask_vec[ii]) * factor + + residual_vec[ii]; } *(reinterpret_cast(&dst[r * cols + i])) = *reinterpret_cast(&dest_vec[0]); @@ -237,47 +153,31 @@ __global__ void FusedResidualDropoutBiasTest(const size_t rows, int idx = row_id * cols + col_id; T factor = static_cast(1.0f - dropout_prob); + if (is_upscale_in_train) { + factor = static_cast(1.0); + } const int tmp_cols = cols / VecSize * VecSize; for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < tmp_cols; i += blockDim.x * gridDim.x * VecSize) { - if (is_upscale_in_train) { #pragma unroll - for (int j = 0; j < VecSize; j++) { - dst[r * cols + i + j] = - (src[r * cols + i + j] + - (bias != nullptr ? bias[i + j] : static_cast(0.0))) + - residual[r * cols + i + j]; - } - } else { -#pragma unroll - for (int j = 0; j < VecSize; j++) { - dst[r * cols + i + j] = - (src[r * cols + i + j] + - (bias != nullptr ? bias[i + j] : static_cast(0.0))) * - factor + - residual[r * cols + i + j]; - } + for (int j = 0; j < VecSize; j++) { + dst[r * cols + i + j] = + (src[r * cols + i + j] + + (bias != nullptr ? bias[i + j] : static_cast(0.0))) * + factor + + residual[r * cols + i + j]; } } int high_index = tmp_cols + col_id; if (high_index < cols) { - if (is_upscale_in_train) { - for (int i = high_index; i < cols; i++) { - dst[r * cols + i] = - (src[r * cols + i] + - (bias != nullptr ? bias[i] : static_cast(0.0))) + - residual[r * cols + i]; - } - } else { - for (int i = high_index; i < cols; i++) { - dst[r * cols + i] = - (src[r * cols + i] + - (bias != nullptr ? bias[i] : static_cast(0.0))) * - factor + - residual[r * cols + i]; - } + for (int i = high_index; i < cols; i++) { + dst[r * cols + i] = + (src[r * cols + i] + + (bias != nullptr ? 
bias[i] : static_cast(0.0))) * + factor + + residual[r * cols + i]; } } } @@ -305,9 +205,8 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, auto threads = GetResidualDropoutBiasThreads(ctx, rows, cols); if (cols % VecSize != 0) - FusedResidualDropoutBias< - T, uint8_t, - VecSize><<>>( + FusedResidualDropoutBiasVec< + T, uint8_t, 1><<>>( rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, bias, mask_data, dst, increment); else @@ -337,36 +236,10 @@ void LaunchResidualDropoutBiasTest(const uint32_t rows, const uint32_t cols, } /********Backward**************/ -template -__global__ void FusedResidualDropoutGrad(const T *dout, const MaskType *mask, - const T factor, const int64_t size, - T *dx, bool is_upscale_in_train) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - - int tmp_size = size / VecSize * VecSize; - for (int i = idx * VecSize; i < tmp_size; - i += blockDim.x * gridDim.x * VecSize) { -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - dx[i + ii] = dout[i + ii] * static_cast(mask[i + ii]) * factor; - } - } - - int high_index = tmp_size + idx; - if (size > high_index) { - for (int i = high_index; i < size; i++) { - if (is_upscale_in_train) - dx[i] = dout[i] * static_cast(mask[i]) * factor; - else - dx[i] = dout[i] * static_cast(mask[i]); - } - } -} - template __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, const T factor, const int64_t size, - T *dx, bool is_upscale_in_train) { + T *dx) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; using LoadT = AlignedVector; @@ -389,62 +262,33 @@ __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, } } -template -__global__ void FusedResidualDropoutBiasGrad( - const T *dout, const MaskType *mask, const T factor, const int64_t rows, - const int64_t cols, T *dx, T *dbias, bool is_upscale_in_train) { - int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; +template +__device__ void reduce_sum(T cache[BSX * VecSize][BSY]) {} - __shared__ T cache[BSX][BSY]; - T tmp_sum = static_cast(0); - if (col_id < cols) { - for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { - int index = row_id * cols + col_id; - T out_value = dout[index]; - if (is_upscale_in_train) - dx[index] = out_value * static_cast(mask[index]) * factor; - else - dx[index] = out_value * static_cast(mask[index]); - tmp_sum += out_value; - } - } - cache[threadIdx.x][threadIdx.y] = tmp_sum; - __syncthreads(); - - // reduce sum - // TODO(zhangkaihuo) : Replace with ModuleAPI - T sum = static_cast(0); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid / BSY; - int y = tid & (BSY - 1); - - int s = BSY / 2; - while (s > 0) { - if (y < s) { - cache[x][y] += cache[x][y + s]; - } - s /= 2; - __syncthreads(); - } - - if (threadIdx.y == 0 && col_id < cols) { - dbias[col_id] = cache[threadIdx.x][0]; +template +static __forceinline__ __device__ U WarpReduceSum(U val) { + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + const int warpSize = 32; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + val += paddle::platform::CudaShuffleDownSync(mask, val, offset); } + return val; } template __global__ void FusedResidualDropoutBiasGradVec( const T *dout, const MaskType *mask, const T factor, const int64_t rows, - const int64_t cols, T *dx, T *dbias, bool is_upscale_in_train) { + const int64_t cols, T *dx, T *dbias) { int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; using LoadT = AlignedVector; using MaskLoadT = AlignedVector; T 
tmp_sum[VecSize] = {static_cast(0)}; - if (col_id * 4 < cols) { + if (col_id * VecSize < cols) { for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { - int index = row_id * cols + col_id * 4; + int index = row_id * cols + col_id * VecSize; T out_vec[VecSize]; MaskType mask_vec[VecSize]; T dx_vec[VecSize]; @@ -454,18 +298,10 @@ __global__ void FusedResidualDropoutBiasGradVec( *out_value = *reinterpret_cast(&dout[index]); *mask_value = *reinterpret_cast(&mask[index]); - if (is_upscale_in_train) { -#pragma unroll - for (int i = 0; i < VecSize; i++) { - dx_vec[i] = out_vec[i] * static_cast(mask_vec[i]) * factor; - tmp_sum[i] += out_vec[i]; - } - } else { #pragma unroll - for (int i = 0; i < VecSize; i++) { - dx_vec[i] = out_vec[i] * static_cast(mask_vec[i]); - tmp_sum[i] += out_vec[i]; - } + for (int i = 0; i < VecSize; i++) { + dx_vec[i] = out_vec[i] * static_cast(mask_vec[i]) * factor; + tmp_sum[i] += out_vec[i]; } *(reinterpret_cast(&dx[index])) = @@ -479,26 +315,23 @@ __global__ void FusedResidualDropoutBiasGradVec( __syncthreads(); // reduce sum - // TODO(zhangkaihuo) : Replace with ModuleAPI T sum = static_cast(0); int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid / BSY; - int y = tid & (BSY - 1); + int x = tid >> 5; + int y = tid & 31; - int s = BSY / 2; - while (s > 0) { - if (y < s) { - for (int i = 0; i < VecSize; i++) { - cache[x * VecSize + i][y] += cache[x * VecSize + i][y + s]; - } + if (x < BSX * VecSize) { +#pragma unroll + for (int i = 0; i < (BSY >> 5); i++) { + sum += cache[x][y + i * 32]; } - s /= 2; - __syncthreads(); } - if (threadIdx.y == 0 && col_id * VecSize < cols) { - for (int i = 0; i < VecSize; i++) - dbias[col_id * VecSize + i] = cache[threadIdx.x * VecSize + i][0]; + sum = WarpReduceSum(sum); + + int bias_id = blockIdx.x * blockDim.x * VecSize + x; + if (y == 0 && x < VecSize * BSX && bias_id < cols) { + dbias[bias_id] = sum; } } @@ -513,43 +346,42 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, auto factor = dropout_prob == static_cast(1.0) ? 
zero : static_cast(1.0 / (1.0 - dropout_prob)); + if (!is_upscale_in_train) { + factor = static_cast(1.0); + } if (dbias != nullptr) { - if (cols % 4 == 0) { - auto threads = std::min(cols / VecSize, static_cast(8)); - auto blocks = std::max((uint32_t)1, - std::min((cols / VecSize + threads - 1) / threads, - (uint32_t)ctx.GetSMCount())); - dim3 block_dim(threads, 128, 1); - dim3 grid_dim(blocks, 1, 1); + int real_vec_size = VecSize; + if (cols % VecSize != 0) real_vec_size = 1; + auto threads = std::min(cols / real_vec_size, static_cast(8)); + auto blocks = std::max( + (uint32_t)1, std::min((cols / real_vec_size + threads - 1) / threads, + (uint32_t)ctx.GetSMCount())); + dim3 block_dim(threads, 128, 1); + dim3 grid_dim(blocks, 1, 1); + + if (cols % VecSize == 0) { FusedResidualDropoutBiasGradVec< T, MaskType, 8, 128, VecSize><<>>( - dout, mask, factor, rows, cols, dx, dbias, is_upscale_in_train); - + dout, mask, factor, rows, cols, dx, dbias); } else { - auto threads = std::min(cols, static_cast(8)); - auto blocks = std::max( - (uint32_t)1, - std::min((cols + threads - 1) / threads, (uint32_t)ctx.GetSMCount())); - dim3 block_dim(threads, 128, 1); - dim3 grid_dim(blocks, 1, 1); - FusedResidualDropoutBiasGrad< - T, MaskType, 8, 128><<>>( - dout, mask, factor, rows, cols, dx, dbias, is_upscale_in_train); + FusedResidualDropoutBiasGradVec< + T, MaskType, 8, 128, 1><<>>( + dout, mask, factor, rows, cols, dx, dbias); } } else { const uint64_t n = rows * cols; auto threads = GetResidualDropoutThreads(ctx, n); - if (n % 4 == 0) { + if (n % VecSize == 0) { FusedResidualDropoutGradVec< T, MaskType, VecSize><<>>( - dout, mask, factor, n, dx, is_upscale_in_train); + dout, mask, factor, n, dx); } else { - FusedResidualDropoutGrad< - T, MaskType><<>>( - dout, mask, factor, n, dx, is_upscale_in_train); + FusedResidualDropoutGradVec< + T, MaskType, 1><<>>( + dout, mask, factor, n, dx); } } } diff --git a/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu b/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu index c8a1485ab7e0fa..12c2fd6be68360 100644 --- a/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu +++ b/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu @@ -33,16 +33,16 @@ USE_OP(elementwise_add); USE_OP(dropout); template -void Dropout(const std::vector &x, const framework::DDim &x_dim, - std::vector *out, std::vector *mask, - const platform::CUDADeviceContext &ctx, uint64_t seed, - float dropout_prob, bool is_upscale_in_train, bool is_test) { +void Dropout(const T *x, const framework::DDim &x_dim, T *out, + std::vector *mask, const platform::CUDADeviceContext &ctx, + uint64_t seed, float dropout_prob, bool is_upscale_in_train, + bool is_test) { framework::Scope scope; auto var_x = scope.Var("X"); auto tensor_x = var_x->GetMutable(); tensor_x->Resize(x_dim); tensor_x->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_x->data(), x.data(), x_dim[0] * x_dim[1] * sizeof(T), + cudaMemcpy(tensor_x->data(), x, x_dim[0] * x_dim[1] * sizeof(T), cudaMemcpyHostToDevice); auto var_out = scope.Var("Out"); @@ -65,8 +65,8 @@ void Dropout(const std::vector &x, const framework::DDim &x_dim, auto op = framework::OpRegistry::CreateOp( "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs); op->Run(scope, ctx.GetPlace()); - cudaMemcpy((*out).data(), tensor_out->data(), - x_dim[0] * x_dim[1] * sizeof(T), cudaMemcpyDeviceToHost); + cudaMemcpy(out, tensor_out->data(), x_dim[0] * x_dim[1] * sizeof(T), + cudaMemcpyDeviceToHost); if (!is_test) { 
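    // the mask is only needed (and only copied back) when running in training mode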
cudaMemcpy((*mask).data(), tensor_mask->data(), x_dim[0] * x_dim[1] * sizeof(uint8_t), cudaMemcpyDeviceToHost); @@ -75,24 +75,23 @@ void Dropout(const std::vector &x, const framework::DDim &x_dim, } template -void DropoutGrad(std::vector *dx, const framework::DDim &x_dim, - const std::vector &dout, const std::vector &mask, - const platform::CUDADeviceContext &ctx, float dropout_prob, - bool is_upscale_in_train) { +void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout, + const uint8_t *mask, const platform::CUDADeviceContext &ctx, + float dropout_prob, bool is_upscale_in_train) { framework::Scope scope; const size_t n = x_dim[0] * x_dim[1]; auto var_out = scope.Var("DOut"); auto tensor_out = var_out->GetMutable(); tensor_out->Resize(x_dim); tensor_out->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_out->data(), dout.data(), n * sizeof(T), + cudaMemcpy(tensor_out->data(), dout, n * sizeof(T), cudaMemcpyHostToDevice); auto var_mask = scope.Var("Mask"); auto tensor_mask = var_mask->GetMutable(); tensor_mask->Resize(x_dim); tensor_mask->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_mask->data(), mask.data(), n * sizeof(uint8_t), + cudaMemcpy(tensor_mask->data(), mask, n * sizeof(uint8_t), cudaMemcpyHostToDevice); auto var_dx = scope.Var("DX"); @@ -112,8 +111,8 @@ void DropoutGrad(std::vector *dx, const framework::DDim &x_dim, {{"X@GRAD", {"DX"}}}, attrs); op->Run(scope, ctx.GetPlace()); - cudaMemcpy((*dx).data(), tensor_dx->data(), - x_dim[0] * x_dim[1] * sizeof(T), cudaMemcpyDeviceToHost); + cudaMemcpy(dx, tensor_dx->data(), x_dim[0] * x_dim[1] * sizeof(T), + cudaMemcpyDeviceToHost); ctx.Wait(); } @@ -211,17 +210,20 @@ struct TestFusedResidualDropoutBias { void BaseForward() { std::vector out1(_rows * _cols), out2(_rows * _cols); if (_has_bias) { + // add bias for (int i = 0; i < _rows; i++) { for (int j = 0; j < _cols; j++) { out1[i * _cols + j] = _src_vec[i * _cols + j] + _bias_vec[j]; } } - Dropout(out1, _src.dims(), &out2, &_correct_mask, *_ctx, _seed, - _dropout_prob, _is_upscale_in_train, _is_test); + // call dropout + Dropout(out1.data(), _src.dims(), out2.data(), &_correct_mask, *_ctx, + _seed, _dropout_prob, _is_upscale_in_train, _is_test); } else { - Dropout(_src_vec, _src.dims(), &out2, &_correct_mask, *_ctx, _seed, - _dropout_prob, _is_upscale_in_train, _is_test); + Dropout(_src_vec.data(), _src.dims(), out2.data(), &_correct_mask, + *_ctx, _seed, _dropout_prob, _is_upscale_in_train, _is_test); } + // add residual for (int i = 0; i < _rows; i++) { for (int j = 0; j < _cols; j++) { _correct_out[i * _cols + j] = @@ -232,14 +234,10 @@ struct TestFusedResidualDropoutBias { } void BaseBackward() { - if (!_is_upscale_in_train) { - for (int i = 0; i < _rows * _cols; i++) { - _correct_dsrc[i] = _correct_out[i] * static_cast(_correct_mask[i]); - } - } else { - DropoutGrad(&_correct_dsrc, _src.dims(), _correct_out, _correct_mask, - *_ctx, _dropout_prob, _is_upscale_in_train); - } + DropoutGrad(_correct_dsrc.data(), _src.dims(), _correct_out.data(), + _correct_mask.data(), *_ctx, _dropout_prob, + _is_upscale_in_train); + // calc dbias memset(&_correct_dbias[0], 0, _cols * sizeof(T)); for (int i = 0; i < _rows; i++) { for (int j = 0; j < _cols; j++) { From 462caa1f3226012289c5245b9c478294f9951a91 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Tue, 24 Aug 2021 04:46:40 +0000 Subject: [PATCH 03/18] resolve review comments and add comments to the code --- paddle/fluid/operators/fused/CMakeLists.txt | 2 +- paddle/fluid/operators/fused/fused_dropout.h | 70 ++++++ 
.../fused/fused_residual_dropout_bias.h | 204 +++++++++--------- ...cu => fused_residual_dropout_bias_test.cu} | 25 ++- 4 files changed, 192 insertions(+), 109 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_dropout.h rename paddle/fluid/operators/fused/{test_fused_residual_dropout_bias.cu => fused_residual_dropout_bias_test.cu} (95%) diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 525f6504f9fa61..78ff136c4d1038 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -74,6 +74,6 @@ if (WITH_GPU OR WITH_ROCM) # fused_dropout # only support CUDA if(NOT WITH_ROCM) - nv_test(test_fused_residual_dropout_bias SRCS test_fused_residual_dropout_bias.cu DEPS tensor op_registry elementwise_add_op dropout_op device_context generator) + nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry elementwise_add_op dropout_op device_context generator) endif() endif() diff --git a/paddle/fluid/operators/fused/fused_dropout.h b/paddle/fluid/operators/fused/fused_dropout.h new file mode 100644 index 00000000000000..bd6a4122f5830d --- /dev/null +++ b/paddle/fluid/operators/fused/fused_dropout.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include + +#include +#include + +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +/** + * get 1D threads and blocks + */ +template +inline std::pair Get1DThreadsAndBlocks( + const platform::CUDADeviceContext &ctx, const uint64_t n) { + const uint64_t tmp_n = n / VecSize; + int threads = std::max( + (uint64_t)32, std::min(tmp_n, (uint64_t)ctx.GetMaxThreadsPerBlock())); + int blocks = std::max((uint64_t)1, (tmp_n + threads - 1) / threads); + return std::pair{threads, blocks}; +} + +/** + * get the threads for fused_residual_dropout_bias: + * 1D blocks: blockDim.x = cols + * 2D grids: gridDim.y = rows + */ +template +inline std::pair Get1DBlocksAnd2DGrids( + const platform::CUDADeviceContext &ctx, const uint32_t rows, + const uint32_t cols) { + const uint32_t tmp_cols = cols / VecSize; + int threads = std::max( + (uint32_t)32, std::min(tmp_cols, (uint32_t)ctx.GetMaxThreadsPerBlock())); + int blocks_x = std::max((uint32_t)1, (tmp_cols + threads - 1) / threads); + int blocks_y = std::max((uint32_t)1, rows); + dim3 block_dim(threads, 1, 1); + dim3 grid_dim(blocks_x, blocks_y, 1); + return std::pair{block_dim, grid_dim}; +} + +// aligned vector generates vectorized load/store on CUDA +template +struct alignas(sizeof(T) * VecSize) AlignedVector { + T val[VecSize]; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 2d0de22952c88a..16747d7739be1e 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,18 +14,7 @@ limitations under the License. 
*/ #pragma once -#include -#include -#include - -#include -#include - -#include "paddle/fluid/platform/cuda_device_function.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/float16.h" - -const int VecSize = 4; +#include "paddle/fluid/operators/fused/fused_dropout.h" namespace paddle { namespace operators { @@ -33,47 +22,11 @@ namespace operators { namespace platform = paddle::platform; namespace cg = cooperative_groups; -inline std::pair GetResidualDropoutThreads( - const platform::CUDADeviceContext &ctx, const uint64_t n) { - const uint64_t tmp_n = n / VecSize; - int threads = std::max( - (uint64_t)32, std::min(tmp_n, (uint64_t)ctx.GetMaxThreadsPerBlock())); - int blocks = std::max((uint64_t)1, (tmp_n + threads - 1) / threads); - return std::pair{threads, blocks}; -} - -inline std::pair GetResidualDropoutBiasThreads( - const platform::CUDADeviceContext &ctx, const uint32_t rows, - const uint32_t cols) { - const uint32_t tmp_cols = cols / VecSize; - int threads = std::max( - (uint32_t)32, std::min(tmp_cols, (uint32_t)ctx.GetMaxThreadsPerBlock())); - int blocks_x = std::max((uint32_t)1, (tmp_cols + threads - 1) / threads); - int blocks_y = std::max((uint32_t)1, rows); - dim3 block_dim(threads, 1, 1); - dim3 grid_dim(blocks_x, blocks_y, 1); - return std::pair{block_dim, grid_dim}; -} - /********Forward**************/ -// aligned vector generates vectorized load/store on CUDA -template -struct alignas(sizeof(T) * Size) AlignedVector { - T val[Size]; -}; - -template -inline int VectorizedSize(const T *pointer) { - uint64_t address = reinterpret_cast(pointer); - constexpr int vec4 = std::alignment_of>::value; // NOLINT - if (address % vec4 == 0) { - return 4; - } - return 1; -} - /** - * dst = residual + dropout(src + bias); + * @brief dst = residual + dropout(src + bias); + * the src, residual, mask and dst shape is (rows, cols) + * the bias shape is (1, cols) */ template __global__ void FusedResidualDropoutBiasVec(const size_t rows, @@ -81,7 +34,7 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, const float dropout_prob, const bool is_upscale_in_train, const T *src, const T *residual, - const T *bias, MaskType *mask_data, + const T *bias, MaskType *mask, T *dst, uint64_t increment) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; @@ -89,11 +42,9 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); - T dest; - MaskType mask; T factor = static_cast(1.0f / (1.0f - dropout_prob)); if (!is_upscale_in_train) { - factor = static_cast(1.0); + factor = static_cast(1.0f); } using LoadT = AlignedVector; using MaskLoadT = AlignedVector; @@ -107,6 +58,7 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, for (int ii = 0; ii < VecSize; ii++) { bias_vec[ii] = static_cast(0); } + // vectorize load data from global LoadT *value = reinterpret_cast(&src_vec); LoadT *residual_value = reinterpret_cast(&residual_vec); *value = *reinterpret_cast(&src[r * cols + i]); @@ -133,58 +85,77 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, static_cast(mask_vec[ii]) * factor + residual_vec[ii]; } + + // store result to global *(reinterpret_cast(&dst[r * cols + i])) = *reinterpret_cast(&dest_vec[0]); - *(reinterpret_cast(&mask_data[r * cols + i])) = + *(reinterpret_cast(&mask[r * cols + i])) = *reinterpret_cast(&mask_vec[0]); } } } +/** + * @brief for dropout's param is_test = true + * the src, residual and dst shape is 
(rows, cols) + * the bias shape is (1, cols) + */ template -__global__ void FusedResidualDropoutBiasTest(const size_t rows, - const size_t cols, - const float dropout_prob, - const bool is_upscale_in_train, - const T *src, const T *residual, - const T *bias, T *dst) { +__global__ void FusedResidualDropoutBiasIsTest(const size_t rows, + const size_t cols, + const float dropout_prob, + const bool is_upscale_in_train, + const T *src, const T *residual, + const T *bias, T *dst) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; T factor = static_cast(1.0f - dropout_prob); if (is_upscale_in_train) { - factor = static_cast(1.0); + factor = static_cast(1.0f); } + + using LoadT = AlignedVector; + const int tmp_cols = cols / VecSize * VecSize; for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < tmp_cols; i += blockDim.x * gridDim.x * VecSize) { + T src_vec[VecSize]; + T residual_vec[VecSize]; + T bias_vec[VecSize]; #pragma unroll - for (int j = 0; j < VecSize; j++) { - dst[r * cols + i + j] = - (src[r * cols + i + j] + - (bias != nullptr ? bias[i + j] : static_cast(0.0))) * - factor + - residual[r * cols + i + j]; + for (int ii = 0; ii < VecSize; ii++) { + bias_vec[ii] = static_cast(0); } - } + // vectorize load data from global + LoadT *value = reinterpret_cast(&src_vec); + LoadT *residual_value = reinterpret_cast(&residual_vec); + *value = *reinterpret_cast(&src[r * cols + i]); + *residual_value = + *reinterpret_cast(&residual[r * cols + i]); - int high_index = tmp_cols + col_id; - if (high_index < cols) { - for (int i = high_index; i < cols; i++) { - dst[r * cols + i] = - (src[r * cols + i] + - (bias != nullptr ? bias[i] : static_cast(0.0))) * - factor + - residual[r * cols + i]; + LoadT *bias_value = + bias != nullptr ? 
reinterpret_cast(&bias_vec) : nullptr; + if (bias != nullptr) + *bias_value = *reinterpret_cast(&bias[i]); + + T dest_vec[VecSize]; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii]; } + + // store result to global + *(reinterpret_cast(&dst[r * cols + i])) = + *reinterpret_cast(&dest_vec[0]); } } } /** - * dst = residual + dropout(src + bias); + * @brief dst = residual + dropout(src + bias); */ template void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, @@ -194,7 +165,8 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, const T *residual, const T *bias, MaskType *mask_data, T *dst, const platform::CUDADeviceContext &ctx) { - if (std::abs(dropout_prob - 1.0) < 1e-5) { + // dropout_prob == 1.0f + if (std::abs(dropout_prob - 1.0f) < 1e-5) { PADDLE_ENFORCE_CUDA_SUCCESS( cudaMemcpyAsync(dst, residual, rows * cols * sizeof(T), cudaMemcpyDeviceToDevice, ctx.stream())); @@ -203,7 +175,8 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, return; } - auto threads = GetResidualDropoutBiasThreads(ctx, rows, cols); + const int VecSize = 4; + auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols); if (cols % VecSize != 0) FusedResidualDropoutBiasVec< T, uint8_t, 1><<>>( @@ -217,25 +190,39 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, bias, mask_data, dst, increment); } +/** + *@brief to launch kernel FusedResidualDropoutBiasIsTest + */ template -void LaunchResidualDropoutBiasTest(const uint32_t rows, const uint32_t cols, - const float dropout_prob, - bool is_upscale_in_train, const T *src, - const T *residual, const T *bias, T *dst, - const platform::CUDADeviceContext &ctx) { - if (std::abs(dropout_prob - 1.0) < 1e-5) { +void LaunchResidualDropoutBiasIsTest(const uint32_t rows, const uint32_t cols, + const float dropout_prob, + bool is_upscale_in_train, const T *src, + const T *residual, const T *bias, T *dst, + const platform::CUDADeviceContext &ctx) { + if (std::abs(dropout_prob - 1.0f) < 1e-5) { PADDLE_ENFORCE_CUDA_SUCCESS( cudaMemcpyAsync(dst, residual, rows * cols * sizeof(T), cudaMemcpyDeviceToDevice, ctx.stream())); return; } - auto threads = GetResidualDropoutBiasThreads(ctx, rows, cols); - FusedResidualDropoutBiasTest< - T, VecSize><<>>( - rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, dst); + const int VecSize = 4; + auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols); + if (cols % VecSize != 0) + FusedResidualDropoutBiasIsTest< + T, 1><<>>( + rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, + dst); + else + FusedResidualDropoutBiasIsTest< + T, VecSize><<>>( + rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, + dst); } /********Backward**************/ +/* + * @brief calculate the grad of no bias + */ template __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, const T factor, const int64_t size, @@ -262,9 +249,6 @@ __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, } } -template -__device__ void reduce_sum(T cache[BSX * VecSize][BSY]) {} - template static __forceinline__ __device__ U WarpReduceSum(U val) { unsigned mask = 0u; @@ -276,6 +260,12 @@ static __forceinline__ __device__ U WarpReduceSum(U val) { return val; } +/** + * blocks(128 * 8) + * 1. calculate the dx and reduce total rows to 128 rows + * 2. save 128*8 temporary sum in 8*128 shared memory + * 3. 
reduce the sum of 128 rows data by 8*VecSize warps + */ template __global__ void FusedResidualDropoutBiasGradVec( const T *dout, const MaskType *mask, const T factor, const int64_t rows, @@ -286,6 +276,7 @@ __global__ void FusedResidualDropoutBiasGradVec( using MaskLoadT = AlignedVector; T tmp_sum[VecSize] = {static_cast(0)}; + // calculate the dx and temporary sum if (col_id * VecSize < cols) { for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { int index = row_id * cols + col_id * VecSize; @@ -309,6 +300,7 @@ __global__ void FusedResidualDropoutBiasGradVec( } } + // save temporary sum to cache and do transpose __shared__ T cache[BSX * VecSize][BSY]; for (int i = 0; i < VecSize; i++) cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; @@ -317,24 +309,31 @@ __global__ void FusedResidualDropoutBiasGradVec( // reduce sum T sum = static_cast(0); int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 5; - int y = tid & 31; + int x = tid >> 5; // warp id + int y = tid & 31; // thread id on warp 0~31 + // need BSX * VecSize warps if (x < BSX * VecSize) { +// reduce 128 to 32 #pragma unroll for (int i = 0; i < (BSY >> 5); i++) { sum += cache[x][y + i * 32]; } } + // reduce 32 to 1 sum = WarpReduceSum(sum); + // save sum to dbias int bias_id = blockIdx.x * blockDim.x * VecSize + x; if (y == 0 && x < VecSize * BSX && bias_id < cols) { dbias[bias_id] = sum; } } +/** + * @brief to launch kernel FusedResidualDropoutBiasGradVec + */ template void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, const float dropout_prob, @@ -342,14 +341,15 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, const uint32_t rows, const uint32_t cols, T *dx, T *dbias, const platform::CUDADeviceContext &ctx) { - const T zero = static_cast(0.0); - auto factor = dropout_prob == static_cast(1.0) + const T zero = static_cast(0.0f); + auto factor = dropout_prob == static_cast(1.0f) ? zero - : static_cast(1.0 / (1.0 - dropout_prob)); + : static_cast(1.0f / (1.0f - dropout_prob)); if (!is_upscale_in_train) { - factor = static_cast(1.0); + factor = static_cast(1.0f); } + const int VecSize = 4; if (dbias != nullptr) { int real_vec_size = VecSize; if (cols % VecSize != 0) real_vec_size = 1; @@ -372,7 +372,7 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, } } else { const uint64_t n = rows * cols; - auto threads = GetResidualDropoutThreads(ctx, n); + auto threads = Get1DThreadsAndBlocks(ctx, n); if (n % VecSize == 0) { FusedResidualDropoutGradVec< T, MaskType, diff --git a/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu similarity index 95% rename from paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu rename to paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index 12c2fd6be68360..5cd20dce57855b 100644 --- a/paddle/fluid/operators/fused/test_fused_residual_dropout_bias.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,9 +29,19 @@ limitations under the License. 
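WarpReduceSum above relies on paddle::platform::CudaShuffleDownSync, Paddle's wrapper over the warp shuffle intrinsic, to implement the last stage of the dbias reduction: collapsing the 32 partial sums owned by one warp into a single value without touching shared memory. A stand-alone sketch of the same tree reduction written against the raw CUDA intrinsic, for illustration only:

// After log2(32) = 5 steps, lane 0 holds the sum of all 32 lanes of the warp.
__device__ __forceinline__ float WarpSumSketch(float val) {
  const unsigned full_mask = 0xffffffffu;
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(full_mask, val, offset);
  }
  return val;
}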
*/ namespace framework = paddle::framework; namespace platform = paddle::platform; -USE_OP(elementwise_add); USE_OP(dropout); +/** + * @brief the unittest of fused_residual_dropout_bias + * 1. random input data + * 2. add bias, call paddle dropout op, add residual, and get the base result + * 3. call FusedResidualDropoutBias function get fused result + * 4. compare ther base result and fused result + */ + +/** + * @brief call paddle dropout op + */ template void Dropout(const T *x, const framework::DDim &x_dim, T *out, std::vector *mask, const platform::CUDADeviceContext &ctx, @@ -74,6 +84,9 @@ void Dropout(const T *x, const framework::DDim &x_dim, T *out, ctx.Wait(); } +/** + * @brief call paddle dropout_grad op + */ template void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout, const uint8_t *mask, const platform::CUDADeviceContext &ctx, @@ -247,8 +260,9 @@ struct TestFusedResidualDropoutBias { } void FusedForward() { - auto threads = paddle::operators::GetResidualDropoutBiasThreads( + auto threads = paddle::operators::Get1DBlocksAnd2DGrids( *_ctx, (uint64_t)_rows, (uint64_t)_cols); + const int VecSize = 4; const int increment = ((_cols - 1) / (threads.first.x * threads.second.x * VecSize) + 1) * VecSize; @@ -258,7 +272,7 @@ struct TestFusedResidualDropoutBias { bias_ptr = _bias.data(); } if (_is_test) { - paddle::operators::LaunchResidualDropoutBiasTest( + paddle::operators::LaunchResidualDropoutBiasIsTest( _rows, _cols, _dropout_prob, _is_upscale_in_train, _src.data(), _residual.data(), bias_ptr, _out.data(), *_ctx); } else { @@ -351,14 +365,13 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBiasDouble) { test.CheckGrad(static_cast(1e-5)); } +// test fp16, For inference, check_grad is not required. ref: test_dropout_op.py TEST(FusedDropout, GPUFusedRedisualDorpoutBiasFp16) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols); test.Run(); test.CheckOut(static_cast(1e-2)); - // For inference, check_grad is not required. 
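The increment computed in FusedForward above is the number of random values each thread consumes per launch: every loop iteration calls curand_uniform4 once, which yields VecSize = 4 values, so the Philox offset has to move forward by iterations * VecSize. A host-side sketch of that bookkeeping; the helper name is an illustration, not an API in this patch.

// block_x is blockDim.x and grid_x is gridDim.x as returned by
// Get1DBlocksAnd2DGrids; ceil-divide cols by the values produced per pass.
inline int ComputePhiloxIncrementSketch(int cols, int block_x, int grid_x,
                                        int vec_size) {
  const int values_per_pass = block_x * grid_x * vec_size;
  const int iterations = (cols - 1) / values_per_pass + 1;
  return iterations * vec_size;
}

The intent is to advance the offset far enough that successive launches draw fresh random numbers, which is why the test mirrors the kernel's loop structure when it computes this value.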
ref: test_dropout_op.py - // test.CheckGrad((platform::float16)1e-2); } // test no bias and cols % 4 == 0 From 93e063864f56f15e2e88cae77e6a5b5b633ab3e5 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Tue, 24 Aug 2021 12:35:20 +0000 Subject: [PATCH 04/18] fused_dropout: optimize code structure to facilitate reuse --- paddle/fluid/operators/fused/CMakeLists.txt | 2 +- paddle/fluid/operators/fused/fused_dropout.h | 12 ++ .../operators/fused/fused_dropout_test.h | 121 ++++++++++++++++++ .../fused/fused_residual_dropout_bias.h | 15 +-- .../fused/fused_residual_dropout_bias_test.cu | 100 +-------------- 5 files changed, 136 insertions(+), 114 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_dropout_test.h diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 78ff136c4d1038..f3035cddcba020 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -74,6 +74,6 @@ if (WITH_GPU OR WITH_ROCM) # fused_dropout # only support CUDA if(NOT WITH_ROCM) - nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry elementwise_add_op dropout_op device_context generator) + nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator) endif() endif() diff --git a/paddle/fluid/operators/fused/fused_dropout.h b/paddle/fluid/operators/fused/fused_dropout.h index bd6a4122f5830d..4188d935b9e458 100644 --- a/paddle/fluid/operators/fused/fused_dropout.h +++ b/paddle/fluid/operators/fused/fused_dropout.h @@ -66,5 +66,17 @@ struct alignas(sizeof(T) * VecSize) AlignedVector { T val[VecSize]; }; +// reduce sum by a warp +template +static __forceinline__ __device__ U WarpReduceSum(U val) { + unsigned mask = 0u; + CREATE_SHFL_MASK(mask, true); + const int warpSize = 32; + for (int offset = warpSize / 2; offset > 0; offset /= 2) { + val += paddle::platform::CudaShuffleDownSync(mask, val, offset); + } + return val; +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h new file mode 100644 index 00000000000000..6cb8cd19b608d1 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -0,0 +1,121 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/string/printf.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; + +USE_OP(dropout); + +/** + * @brief call paddle dropout op + */ +template +void Dropout(const T *x, const framework::DDim &x_dim, T *out, + std::vector *mask, const platform::CUDADeviceContext &ctx, + uint64_t seed, float dropout_prob, bool is_upscale_in_train, + bool is_test) { + framework::Scope scope; + auto var_x = scope.Var("X"); + auto tensor_x = var_x->GetMutable(); + tensor_x->Resize(x_dim); + tensor_x->mutable_data(ctx.GetPlace()); + cudaMemcpy(tensor_x->data(), x, x_dim[0] * x_dim[1] * sizeof(T), + cudaMemcpyHostToDevice); + + auto var_out = scope.Var("Out"); + auto tensor_out = var_out->GetMutable(); + + auto var_mask = scope.Var("Mask"); + auto tensor_mask = var_mask->GetMutable(); + + framework::AttributeMap attrs; + attrs.insert({"fix_seed", 1}); + attrs.insert({"seed", static_cast(seed)}); + attrs.insert({"dropout_prob", dropout_prob}); + if (is_upscale_in_train) { + attrs.insert({"dropout_implementation", std::string("upscale_in_train")}); + } + if (is_test) { + attrs.insert({"is_test", 1}); + } + + auto op = framework::OpRegistry::CreateOp( + "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + cudaMemcpy(out, tensor_out->data(), x_dim[0] * x_dim[1] * sizeof(T), + cudaMemcpyDeviceToHost); + if (!is_test) { + cudaMemcpy((*mask).data(), tensor_mask->data(), + x_dim[0] * x_dim[1] * sizeof(uint8_t), cudaMemcpyDeviceToHost); + } + ctx.Wait(); +} + +/** + * @brief call paddle dropout_grad op + */ +template +void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout, + const uint8_t *mask, const platform::CUDADeviceContext &ctx, + float dropout_prob, bool is_upscale_in_train) { + framework::Scope scope; + const size_t n = x_dim[0] * x_dim[1]; + auto var_out = scope.Var("DOut"); + auto tensor_out = var_out->GetMutable(); + tensor_out->Resize(x_dim); + tensor_out->mutable_data(ctx.GetPlace()); + cudaMemcpy(tensor_out->data(), dout, n * sizeof(T), + cudaMemcpyHostToDevice); + + auto var_mask = scope.Var("Mask"); + auto tensor_mask = var_mask->GetMutable(); + tensor_mask->Resize(x_dim); + tensor_mask->mutable_data(ctx.GetPlace()); + cudaMemcpy(tensor_mask->data(), mask, n * sizeof(uint8_t), + cudaMemcpyHostToDevice); + + auto var_dx = scope.Var("DX"); + auto tensor_dx = var_dx->GetMutable(); + + framework::AttributeMap attrs; + attrs.insert({"dropout_prob", dropout_prob}); + attrs.insert({"is_test", 0}); + if (is_upscale_in_train) { + attrs.insert({"dropout_implementation", std::string("upscale_in_train")}); + } else { + attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")}); + } + + auto op = framework::OpRegistry::CreateOp( + "dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}}, + {{"X@GRAD", {"DX"}}}, attrs); + op->Run(scope, ctx.GetPlace()); + + cudaMemcpy(dx, tensor_dx->data(), x_dim[0] * x_dim[1] * sizeof(T), + cudaMemcpyDeviceToHost); + ctx.Wait(); +} diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 16747d7739be1e..ce9273dff0a993 100644 --- 
a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -118,9 +118,8 @@ __global__ void FusedResidualDropoutBiasIsTest(const size_t rows, using LoadT = AlignedVector; - const int tmp_cols = cols / VecSize * VecSize; for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { - for (int i = col_id * VecSize; i < tmp_cols; + for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { T src_vec[VecSize]; T residual_vec[VecSize]; @@ -249,17 +248,6 @@ __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, } } -template -static __forceinline__ __device__ U WarpReduceSum(U val) { - unsigned mask = 0u; - CREATE_SHFL_MASK(mask, true); - const int warpSize = 32; - for (int offset = warpSize / 2; offset > 0; offset /= 2) { - val += paddle::platform::CudaShuffleDownSync(mask, val, offset); - } - return val; -} - /** * blocks(128 * 8) * 1. calculate the dx and reduce total rows to 128 rows @@ -285,7 +273,6 @@ __global__ void FusedResidualDropoutBiasGradVec( T dx_vec[VecSize]; LoadT *out_value = reinterpret_cast(&out_vec); MaskLoadT *mask_value = reinterpret_cast(&mask_vec); - LoadT *dx_value = reinterpret_cast(&dx_vec); *out_value = *reinterpret_cast(&dout[index]); *mask_value = *reinterpret_cast(&mask[index]); diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index 5cd20dce57855b..fa119f1132e8f6 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -17,20 +17,12 @@ limitations under the License. */ #include #include -#include "gtest/gtest.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/fused/fused_dropout_test.h" #include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h" -#include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/string/printf.h" namespace framework = paddle::framework; namespace platform = paddle::platform; -USE_OP(dropout); - /** * @brief the unittest of fused_residual_dropout_bias * 1. random input data @@ -39,96 +31,6 @@ USE_OP(dropout); * 4. 
compare ther base result and fused result */ -/** - * @brief call paddle dropout op - */ -template -void Dropout(const T *x, const framework::DDim &x_dim, T *out, - std::vector *mask, const platform::CUDADeviceContext &ctx, - uint64_t seed, float dropout_prob, bool is_upscale_in_train, - bool is_test) { - framework::Scope scope; - auto var_x = scope.Var("X"); - auto tensor_x = var_x->GetMutable(); - tensor_x->Resize(x_dim); - tensor_x->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_x->data(), x, x_dim[0] * x_dim[1] * sizeof(T), - cudaMemcpyHostToDevice); - - auto var_out = scope.Var("Out"); - auto tensor_out = var_out->GetMutable(); - - auto var_mask = scope.Var("Mask"); - auto tensor_mask = var_mask->GetMutable(); - - framework::AttributeMap attrs; - attrs.insert({"fix_seed", 1}); - attrs.insert({"seed", static_cast(seed)}); - attrs.insert({"dropout_prob", dropout_prob}); - if (is_upscale_in_train) { - attrs.insert({"dropout_implementation", std::string("upscale_in_train")}); - } - if (is_test) { - attrs.insert({"is_test", 1}); - } - - auto op = framework::OpRegistry::CreateOp( - "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs); - op->Run(scope, ctx.GetPlace()); - cudaMemcpy(out, tensor_out->data(), x_dim[0] * x_dim[1] * sizeof(T), - cudaMemcpyDeviceToHost); - if (!is_test) { - cudaMemcpy((*mask).data(), tensor_mask->data(), - x_dim[0] * x_dim[1] * sizeof(uint8_t), cudaMemcpyDeviceToHost); - } - ctx.Wait(); -} - -/** - * @brief call paddle dropout_grad op - */ -template -void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout, - const uint8_t *mask, const platform::CUDADeviceContext &ctx, - float dropout_prob, bool is_upscale_in_train) { - framework::Scope scope; - const size_t n = x_dim[0] * x_dim[1]; - auto var_out = scope.Var("DOut"); - auto tensor_out = var_out->GetMutable(); - tensor_out->Resize(x_dim); - tensor_out->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_out->data(), dout, n * sizeof(T), - cudaMemcpyHostToDevice); - - auto var_mask = scope.Var("Mask"); - auto tensor_mask = var_mask->GetMutable(); - tensor_mask->Resize(x_dim); - tensor_mask->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_mask->data(), mask, n * sizeof(uint8_t), - cudaMemcpyHostToDevice); - - auto var_dx = scope.Var("DX"); - auto tensor_dx = var_dx->GetMutable(); - - framework::AttributeMap attrs; - attrs.insert({"dropout_prob", dropout_prob}); - attrs.insert({"is_test", 0}); - if (is_upscale_in_train) { - attrs.insert({"dropout_implementation", std::string("upscale_in_train")}); - } else { - attrs.insert({"dropout_implementation", std::string("downgrade_in_infer")}); - } - - auto op = framework::OpRegistry::CreateOp( - "dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}}, - {{"X@GRAD", {"DX"}}}, attrs); - op->Run(scope, ctx.GetPlace()); - - cudaMemcpy(dx, tensor_dx->data(), x_dim[0] * x_dim[1] * sizeof(T), - cudaMemcpyDeviceToHost); - ctx.Wait(); -} - template struct TestFusedResidualDropoutBias { uint32_t _rows; From 036b4307f58ddfbdeb7b815b138cd59cb09cd600 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Wed, 25 Aug 2021 11:59:17 +0000 Subject: [PATCH 05/18] optimize code structure to facilitate reuse --- paddle/fluid/operators/fused/fused_dropout.h | 12 - .../fused/fused_residual_dropout_bias.h | 258 ++++++++++-------- .../fused/fused_residual_dropout_bias_test.cu | 38 ++- 3 files changed, 159 insertions(+), 149 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout.h b/paddle/fluid/operators/fused/fused_dropout.h index 
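Step 2 of the unittest plan above produces the baseline on the host. For reference, the quantity being checked is dst = residual + mask * (src + bias) * factor, applied element-wise. A compact C++ sketch of that baseline for the upscale_in_train case; the function and parameter names are illustrative and not part of the test.

// mask holds the 0/1 keep flags produced by the paddle dropout op;
// factor = 1 / (1 - dropout_prob) for upscale_in_train.
void ReferenceForwardSketch(const float *src, const float *residual,
                            const float *bias, const unsigned char *mask,
                            float dropout_prob, int rows, int cols,
                            float *dst) {
  const float factor = 1.0f / (1.0f - dropout_prob);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const int i = r * cols + c;
      dst[i] = residual[i] +
               static_cast<float>(mask[i]) * (src[i] + bias[c]) * factor;
    }
  }
}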
4188d935b9e458..bd6a4122f5830d 100644 --- a/paddle/fluid/operators/fused/fused_dropout.h +++ b/paddle/fluid/operators/fused/fused_dropout.h @@ -66,17 +66,5 @@ struct alignas(sizeof(T) * VecSize) AlignedVector { T val[VecSize]; }; -// reduce sum by a warp -template -static __forceinline__ __device__ U WarpReduceSum(U val) { - unsigned mask = 0u; - CREATE_SHFL_MASK(mask, true); - const int warpSize = 32; - for (int offset = warpSize / 2; offset > 0; offset /= 2) { - val += paddle::platform::CudaShuffleDownSync(mask, val, offset); - } - return val; -} - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index ce9273dff0a993..0a263635e46029 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/operators/fused/fused_dropout.h" +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" namespace paddle { namespace operators { @@ -22,7 +23,71 @@ namespace operators { namespace platform = paddle::platform; namespace cg = cooperative_groups; +/** + * @brief fused the add_bias, dropout, add residual into one operators + * + */ + /********Forward**************/ +/** + * @brief the fused function called by every thread + */ +template +__forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread( + const int row_id, const int col_id, const int cols, + curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor, + const T *src, const T *residual, const T *bias, T *dst, MaskType *mask, + U *mean_val, U *var_val) { + using LoadT = AlignedVector; + using MaskLoadT = AlignedVector; + T src_vec[VecSize]; + T residual_vec[VecSize]; + T bias_vec[VecSize]; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + bias_vec[ii] = static_cast(0); + } + // vectorize load data from global + LoadT *value = reinterpret_cast(&src_vec); + LoadT *residual_value = reinterpret_cast(&residual_vec); + *value = *reinterpret_cast(&src[row_id * cols + col_id]); + *residual_value = + *reinterpret_cast(&residual[row_id * cols + col_id]); + + LoadT *bias_value = + bias != nullptr ? 
reinterpret_cast(&bias_vec) : nullptr; + if (bias != nullptr) + *bias_value = *reinterpret_cast(&bias[col_id]); + + float4 rand = curand_uniform4(state); + T dest_vec[VecSize]; + MaskType mask_vec[VecSize]; + +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob); + } + +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + dest_vec[ii] = + (src_vec[ii] + bias_vec[ii]) * static_cast(mask_vec[ii]) * factor + + residual_vec[ii]; + if (layer_norm) { + U tmp = static_cast(dest_vec[ii]); + *mean_val += tmp; + *var_val += (tmp * tmp); + } + } + + // store result to global + *(reinterpret_cast(&dst[row_id * cols + col_id])) = + *reinterpret_cast(&dest_vec[0]); + *(reinterpret_cast(&mask[row_id * cols + col_id])) = + *reinterpret_cast(&mask_vec[0]); +} + /** * @brief dst = residual + dropout(src + bias); * the src, residual, mask and dst shape is (rows, cols) @@ -46,67 +111,71 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, if (!is_upscale_in_train) { factor = static_cast(1.0f); } - using LoadT = AlignedVector; - using MaskLoadT = AlignedVector; for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { - T src_vec[VecSize]; - T residual_vec[VecSize]; - T bias_vec[VecSize]; -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - bias_vec[ii] = static_cast(0); - } - // vectorize load data from global - LoadT *value = reinterpret_cast(&src_vec); - LoadT *residual_value = reinterpret_cast(&residual_vec); - *value = *reinterpret_cast(&src[r * cols + i]); - *residual_value = - *reinterpret_cast(&residual[r * cols + i]); - - LoadT *bias_value = - bias != nullptr ? reinterpret_cast(&bias_vec) : nullptr; - if (bias != nullptr) - *bias_value = *reinterpret_cast(&bias[i]); - - float4 rand = curand_uniform4(&state); - T dest_vec[VecSize]; - MaskType mask_vec[VecSize]; + FusedResidualDropoutBiasVecOneThread( + r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, + mask, NULL, NULL); + } + } +} +/** + * @brief the fused function called by every thread + */ +template +__forceinline__ __device__ void FusedResidualDropoutBiasOnlyInferVecOneThread( + const int row_id, const int col_id, const int cols, + const float dropout_prob, const T factor, const T *src, const T *residual, + const T *bias, T *dst, U *mean_val, U *var_val) { + using LoadT = AlignedVector; + T src_vec[VecSize]; + T residual_vec[VecSize]; + T bias_vec[VecSize]; #pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob); - } + for (int ii = 0; ii < VecSize; ii++) { + bias_vec[ii] = static_cast(0); + } + // vectorize load data from global + LoadT *value = reinterpret_cast(&src_vec); + LoadT *residual_value = reinterpret_cast(&residual_vec); + *value = *reinterpret_cast(&src[row_id * cols + col_id]); + *residual_value = + *reinterpret_cast(&residual[row_id * cols + col_id]); -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * - static_cast(mask_vec[ii]) * factor + - residual_vec[ii]; - } + LoadT *bias_value = + bias != nullptr ? 
reinterpret_cast(&bias_vec) : nullptr; + if (bias != nullptr) + *bias_value = *reinterpret_cast(&bias[col_id]); + + T dest_vec[VecSize]; - // store result to global - *(reinterpret_cast(&dst[r * cols + i])) = - *reinterpret_cast(&dest_vec[0]); - *(reinterpret_cast(&mask[r * cols + i])) = - *reinterpret_cast(&mask_vec[0]); +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii]; + if (layer_norm) { + U tmp = static_cast(dest_vec[ii]); + *mean_val += tmp; + *var_val += (tmp * tmp); } } + + // store result to global + *(reinterpret_cast(&dst[row_id * cols + col_id])) = + *reinterpret_cast(&dest_vec[0]); } /** - * @brief for dropout's param is_test = true + * @brief for dropout's param is_test = true, only used in inference * the src, residual and dst shape is (rows, cols) * the bias shape is (1, cols) */ template -__global__ void FusedResidualDropoutBiasIsTest(const size_t rows, - const size_t cols, - const float dropout_prob, - const bool is_upscale_in_train, - const T *src, const T *residual, - const T *bias, T *dst) { +__global__ void FusedResidualDropoutBiasOnlyInferVec( + const size_t rows, const size_t cols, const float dropout_prob, + const bool is_upscale_in_train, const T *src, const T *residual, + const T *bias, T *dst) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; @@ -116,39 +185,12 @@ __global__ void FusedResidualDropoutBiasIsTest(const size_t rows, factor = static_cast(1.0f); } - using LoadT = AlignedVector; - for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { - T src_vec[VecSize]; - T residual_vec[VecSize]; - T bias_vec[VecSize]; -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - bias_vec[ii] = static_cast(0); - } - // vectorize load data from global - LoadT *value = reinterpret_cast(&src_vec); - LoadT *residual_value = reinterpret_cast(&residual_vec); - *value = *reinterpret_cast(&src[r * cols + i]); - *residual_value = - *reinterpret_cast(&residual[r * cols + i]); - - LoadT *bias_value = - bias != nullptr ? 
reinterpret_cast(&bias_vec) : nullptr; - if (bias != nullptr) - *bias_value = *reinterpret_cast(&bias[i]); - - T dest_vec[VecSize]; -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii]; - } - - // store result to global - *(reinterpret_cast(&dst[r * cols + i])) = - *reinterpret_cast(&dest_vec[0]); + FusedResidualDropoutBiasOnlyInferVecOneThread( + r, i, cols, dropout_prob, factor, src, residual, bias, dst, nullptr, + nullptr); } } } @@ -159,7 +201,7 @@ __global__ void FusedResidualDropoutBiasIsTest(const size_t rows, template void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, const int increment, uint64_t seed, - const float dropout_prob, + const float dropout_prob, const bool is_test, bool is_upscale_in_train, const T *src, const T *residual, const T *bias, MaskType *mask_data, T *dst, @@ -176,46 +218,32 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, const int VecSize = 4; auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols); - if (cols % VecSize != 0) - FusedResidualDropoutBiasVec< - T, uint8_t, 1><<>>( - rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, - bias, mask_data, dst, increment); - else - FusedResidualDropoutBiasVec< - T, uint8_t, - VecSize><<>>( - rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, - bias, mask_data, dst, increment); -} - -/** - *@brief to launch kernel FusedResidualDropoutBiasIsTest - */ -template -void LaunchResidualDropoutBiasIsTest(const uint32_t rows, const uint32_t cols, - const float dropout_prob, - bool is_upscale_in_train, const T *src, - const T *residual, const T *bias, T *dst, - const platform::CUDADeviceContext &ctx) { - if (std::abs(dropout_prob - 1.0f) < 1e-5) { - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaMemcpyAsync(dst, residual, rows * cols * sizeof(T), - cudaMemcpyDeviceToDevice, ctx.stream())); - return; + if (cols % VecSize != 0) { + if (!is_test) { + FusedResidualDropoutBiasVec< + T, uint8_t, 1><<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, + bias, mask_data, dst, increment); + } else { + FusedResidualDropoutBiasOnlyInferVec< + T, 1><<>>( + rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, + dst); + } + } else { + if (!is_test) { + FusedResidualDropoutBiasVec< + T, uint8_t, + VecSize><<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, + bias, mask_data, dst, increment); + } else { + FusedResidualDropoutBiasOnlyInferVec< + T, VecSize><<>>( + rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, + dst); + } } - const int VecSize = 4; - auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols); - if (cols % VecSize != 0) - FusedResidualDropoutBiasIsTest< - T, 1><<>>( - rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, - dst); - else - FusedResidualDropoutBiasIsTest< - T, VecSize><<>>( - rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, - dst); } /********Backward**************/ diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index fa119f1132e8f6..d5377194934ff6 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -43,7 +43,7 @@ struct TestFusedResidualDropoutBias { framework::Tensor _src, _residual, _bias, _out, _mask; framework::Tensor _dsrc, _dbias; - std::vector _src_vec, 
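The training and inference kernels dispatched above differ only in how the dropout scale is applied, following Paddle's dropout_implementation convention. A small host-side summary of the two factors; the helper names are illustrative only.

// Training kernel:  dst = (src + bias) * mask * TrainFactor(p) + residual
// Inference kernel: dst = (src + bias) * InferFactor(p) + residual
inline float TrainFactor(float p, bool upscale_in_train) {
  return upscale_in_train ? 1.0f / (1.0f - p) : 1.0f;
}
inline float InferFactor(float p, bool upscale_in_train) {
  return upscale_in_train ? 1.0f : 1.0f - p;
}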
_residual_vec, _bias_vec, _out_vec, _mask_vec; + std::vector _src_vec, _residual_vec, _bias_vec; std::vector _correct_out, _correct_dsrc, _correct_dbias; std::vector _correct_mask; @@ -173,16 +173,10 @@ struct TestFusedResidualDropoutBias { if (_has_bias) { bias_ptr = _bias.data(); } - if (_is_test) { - paddle::operators::LaunchResidualDropoutBiasIsTest( - _rows, _cols, _dropout_prob, _is_upscale_in_train, _src.data(), - _residual.data(), bias_ptr, _out.data(), *_ctx); - } else { - paddle::operators::LaunchResidualDropoutBias( - _rows, _cols, increment, _seed, _dropout_prob, _is_upscale_in_train, - _src.data(), _residual.data(), bias_ptr, _mask.data(), - _out.data(), *_ctx); - } + paddle::operators::LaunchResidualDropoutBias( + _rows, _cols, increment, _seed, _dropout_prob, _is_test, + _is_upscale_in_train, _src.data(), _residual.data(), bias_ptr, + _mask.data(), _out.data(), *_ctx); _ctx->Wait(); } @@ -249,7 +243,7 @@ struct TestFusedResidualDropoutBias { } }; -TEST(FusedDropout, GPUFusedRedisualDorpoutBias) { +TEST(FusedDropout, GPUFusedResidualDropoutBias) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols); @@ -258,7 +252,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBias) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasDouble) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols); @@ -268,7 +262,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBiasDouble) { } // test fp16, For inference, check_grad is not required. ref: test_dropout_op.py -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasFp16) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols); @@ -277,7 +271,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBiasFp16) { } // test no bias and cols % 4 == 0 -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasNoBias) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasNoBias) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols); @@ -288,7 +282,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBiasNoBias) { } // test no bias and cols % 4 != 0 -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasNoBias2) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasNoBias2) { const int rows = 16; const int cols = 17; TestFusedResidualDropoutBias test(rows, cols); @@ -299,7 +293,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBiasNoBias2) { } // test add bias and cols % 4 != 0 -TEST(FusedDropout, GPUFusedRedisualDorpoutBias2) { +TEST(FusedDropout, GPUFusedResidualDropoutBias2) { const int rows = 16; const int cols = 17; TestFusedResidualDropoutBias test(rows, cols); @@ -308,7 +302,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBias2) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedRedisualDorpoutBias3) { +TEST(FusedDropout, GPUFusedResidualDropoutBias3) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, false, false); @@ -317,7 +311,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBias3) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedRedisualDorpoutBias4) { +TEST(FusedDropout, GPUFusedResidualDropoutBias4) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, false, false); @@ -326,7 +320,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBias4) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, 
GPUFusedRedisualDorpoutBias5) { +TEST(FusedDropout, GPUFusedResidualDropoutBias5) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, true, false); @@ -335,7 +329,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBias5) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedRedisualDorpoutBias6) { +TEST(FusedDropout, GPUFusedResidualDropoutBias6) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 0.35, true, true); @@ -344,7 +338,7 @@ TEST(FusedDropout, GPUFusedRedisualDorpoutBias6) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedRedisualDorpoutBias7) { +TEST(FusedDropout, GPUFusedResidualDropoutBias7) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 125, 0.0, false, false); From 4d33b98f2fc4468f87e2462cfb1e1248e790a93d Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Mon, 30 Aug 2021 10:30:04 +0000 Subject: [PATCH 06/18] modify the code according to the review comments --- paddle/fluid/operators/fused/CMakeLists.txt | 2 +- ...fused_dropout.h => fused_dropout_common.h} | 5 +- .../operators/fused/fused_dropout_test.h | 22 +- .../fused/fused_residual_dropout_bias.h | 42 +-- .../fused/fused_residual_dropout_bias_test.cu | 262 +++++++++--------- 5 files changed, 171 insertions(+), 162 deletions(-) rename paddle/fluid/operators/fused/{fused_dropout.h => fused_dropout_common.h} (95%) diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index f3035cddcba020..3df2144aa35944 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -74,6 +74,6 @@ if (WITH_GPU OR WITH_ROCM) # fused_dropout # only support CUDA if(NOT WITH_ROCM) - nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator) + nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory) endif() endif() diff --git a/paddle/fluid/operators/fused/fused_dropout.h b/paddle/fluid/operators/fused/fused_dropout_common.h similarity index 95% rename from paddle/fluid/operators/fused/fused_dropout.h rename to paddle/fluid/operators/fused/fused_dropout_common.h index bd6a4122f5830d..755153bb07eee9 100644 --- a/paddle/fluid/operators/fused/fused_dropout.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -18,9 +18,8 @@ limitations under the License. */ #include #include -#include -#include - +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index 6cb8cd19b608d1..4a5e088d2013b8 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -22,11 +22,13 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/string/printf.h" namespace framework = paddle::framework; namespace platform = paddle::platform; +namespace memory = paddle::memory; USE_OP(dropout); @@ -34,17 +36,15 @@ USE_OP(dropout); * @brief call paddle dropout op */ template -void Dropout(const T *x, const framework::DDim &x_dim, T *out, - std::vector *mask, const platform::CUDADeviceContext &ctx, - uint64_t seed, float dropout_prob, bool is_upscale_in_train, - bool is_test) { +void Dropout(const std::vector &x, const framework::DDim &x_dim, + std::vector *out, std::vector *mask, + const platform::CUDADeviceContext &ctx, uint64_t seed, + float dropout_prob, bool is_upscale_in_train, bool is_test) { framework::Scope scope; auto var_x = scope.Var("X"); auto tensor_x = var_x->GetMutable(); + framework::TensorFromVector(x, ctx, tensor_x); tensor_x->Resize(x_dim); - tensor_x->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_x->data(), x, x_dim[0] * x_dim[1] * sizeof(T), - cudaMemcpyHostToDevice); auto var_out = scope.Var("Out"); auto tensor_out = var_out->GetMutable(); @@ -59,6 +59,7 @@ void Dropout(const T *x, const framework::DDim &x_dim, T *out, if (is_upscale_in_train) { attrs.insert({"dropout_implementation", std::string("upscale_in_train")}); } + if (is_test) { attrs.insert({"is_test", 1}); } @@ -66,11 +67,10 @@ void Dropout(const T *x, const framework::DDim &x_dim, T *out, auto op = framework::OpRegistry::CreateOp( "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs); op->Run(scope, ctx.GetPlace()); - cudaMemcpy(out, tensor_out->data(), x_dim[0] * x_dim[1] * sizeof(T), - cudaMemcpyDeviceToHost); + + framework::TensorToVector(*tensor_out, ctx, out); if (!is_test) { - cudaMemcpy((*mask).data(), tensor_mask->data(), - x_dim[0] * x_dim[1] * sizeof(uint8_t), cudaMemcpyDeviceToHost); + framework::TensorToVector(*tensor_mask, ctx, mask); } ctx.Wait(); } diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 0a263635e46029..eda633380e07a2 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -14,7 +14,7 @@ limitations under the License. 
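The refactored Dropout helper above replaces the raw cudaMemcpy calls with framework::TensorFromVector / TensorToVector, which infer the copy direction from the tensor's place and the passed device context. A minimal round-trip sketch of that pattern, assuming a live CUDADeviceContext named ctx and the includes already present in fused_dropout_test.h:

// Host vector -> device Tensor -> host vector, using ctx for the copies.
void RoundTripSketch(const platform::CUDADeviceContext &ctx, int rows,
                     int cols) {
  std::vector<float> host_in(rows * cols, 1.0f);
  std::vector<float> host_out;
  framework::Tensor t;
  framework::TensorFromVector(host_in, ctx, &t);  // H2D copy
  t.Resize({rows, cols});
  framework::TensorToVector(t, ctx, &host_out);   // D2H copy
  ctx.Wait();  // make sure the copies finished before reading host_out
}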
*/ #pragma once -#include "paddle/fluid/operators/fused/fused_dropout.h" +#include "paddle/fluid/operators/fused/fused_dropout_common.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" namespace paddle { @@ -22,6 +22,7 @@ namespace operators { namespace platform = paddle::platform; namespace cg = cooperative_groups; +namespace memory = paddle::memory; /** * @brief fused the add_bias, dropout, add residual into one operators @@ -32,15 +33,17 @@ namespace cg = cooperative_groups; /** * @brief the fused function called by every thread */ -template +template __forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread( const int row_id, const int col_id, const int cols, curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor, const T *src, const T *residual, const T *bias, T *dst, MaskType *mask, - U *mean_val, U *var_val) { + typename details::MPTypeTrait::Type *mean_val, + typename details::MPTypeTrait::Type *var_val) { using LoadT = AlignedVector; using MaskLoadT = AlignedVector; + using U = typename details::MPTypeTrait::Type; + T src_vec[VecSize]; T residual_vec[VecSize]; T bias_vec[VecSize]; @@ -74,7 +77,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread( dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * static_cast(mask_vec[ii]) * factor + residual_vec[ii]; - if (layer_norm) { + if (ComputeLayerNorm) { U tmp = static_cast(dest_vec[ii]); *mean_val += tmp; *var_val += (tmp * tmp); @@ -114,7 +117,7 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { - FusedResidualDropoutBiasVecOneThread( + FusedResidualDropoutBiasVecOneThread( r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, mask, NULL, NULL); } @@ -208,9 +211,10 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, const platform::CUDADeviceContext &ctx) { // dropout_prob == 1.0f if (std::abs(dropout_prob - 1.0f) < 1e-5) { - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaMemcpyAsync(dst, residual, rows * cols * sizeof(T), - cudaMemcpyDeviceToDevice, ctx.stream())); + if (residual == dst) return; + auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); + memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T), + ctx.stream()); PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); return; @@ -282,7 +286,8 @@ __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, * 2. save 128*8 temporary sum in 8*128 shared memory * 3. 
reduce the sum of 128 rows data by 8*VecSize warps */ -template +template __global__ void FusedResidualDropoutBiasGradVec( const T *dout, const MaskType *mask, const T factor, const int64_t rows, const int64_t cols, T *dx, T *dbias) { @@ -316,9 +321,10 @@ __global__ void FusedResidualDropoutBiasGradVec( } // save temporary sum to cache and do transpose - __shared__ T cache[BSX * VecSize][BSY]; - for (int i = 0; i < VecSize; i++) + __shared__ T cache[BLOCK_SIZE_X * VecSize][BLOCK_SIZE_Y]; + for (int i = 0; i < VecSize; i++) { cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; + } __syncthreads(); // reduce sum @@ -327,11 +333,11 @@ __global__ void FusedResidualDropoutBiasGradVec( int x = tid >> 5; // warp id int y = tid & 31; // thread id on warp 0~31 - // need BSX * VecSize warps - if (x < BSX * VecSize) { + // need BLOCK_SIZE_X * VecSize warps + if (x < BLOCK_SIZE_X * VecSize) { // reduce 128 to 32 #pragma unroll - for (int i = 0; i < (BSY >> 5); i++) { + for (int i = 0; i < (BLOCK_SIZE_Y >> 5); i++) { sum += cache[x][y + i * 32]; } } @@ -341,7 +347,7 @@ __global__ void FusedResidualDropoutBiasGradVec( // save sum to dbias int bias_id = blockIdx.x * blockDim.x * VecSize + x; - if (y == 0 && x < VecSize * BSX && bias_id < cols) { + if (y == 0 && x < VecSize * BLOCK_SIZE_X && bias_id < cols) { dbias[bias_id] = sum; } } @@ -367,7 +373,9 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, const int VecSize = 4; if (dbias != nullptr) { int real_vec_size = VecSize; - if (cols % VecSize != 0) real_vec_size = 1; + if (cols % VecSize != 0) { + real_vec_size = 1; + } auto threads = std::min(cols / real_vec_size, static_cast(8)); auto blocks = std::max( (uint32_t)1, std::min((cols / real_vec_size + threads - 1) / threads, diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index d5377194934ff6..88438e6e0c36e5 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -24,7 +24,7 @@ namespace framework = paddle::framework; namespace platform = paddle::platform; /** - * @brief the unittest of fused_residual_dropout_bias + * @brief the unittest of fusedresidualdropoutbias * 1. random input data * 2. add bias, call paddle dropout op, add residual, and get the base result * 3. 
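The three steps listed above describe a block of BLOCK_SIZE_X x BLOCK_SIZE_Y threads computing dbias as a column sum. A simplified stand-alone sketch of that shape, with VecSize fixed to 1 and a plain serial second stage in place of the transposed warp reduction; the kernel name and layout are illustrative only.

template <int BLOCK_X, int BLOCK_Y>
__global__ void ColumnSumSketch(const float *in, int rows, int cols,
                                float *col_sum) {
  __shared__ float cache[BLOCK_X][BLOCK_Y];
  const int col = blockIdx.x * BLOCK_X + threadIdx.x;

  // step 1: each of the BLOCK_Y threads of a column folds every BLOCK_Y-th row
  float partial = 0.0f;
  if (col < cols) {
    for (int r = threadIdx.y; r < rows; r += BLOCK_Y) {
      partial += in[r * cols + col];
    }
  }

  // step 2: park the BLOCK_Y partial sums of each column in shared memory
  cache[threadIdx.x][threadIdx.y] = partial;
  __syncthreads();

  // step 3 (simplified): one thread per column adds its partials serially;
  // the real kernel transposes the cache and finishes with warp shuffles.
  if (threadIdx.y == 0 && col < cols) {
    float sum = 0.0f;
    for (int y = 0; y < BLOCK_Y; ++y) {
      sum += cache[threadIdx.x][y];
    }
    col_sum[col] = sum;
  }
}

It would be launched with dim3(BLOCK_X, BLOCK_Y) thread blocks and a 1-D grid covering the columns.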
call FusedResidualDropoutBias function get fused result @@ -33,163 +33,169 @@ namespace platform = paddle::platform; template struct TestFusedResidualDropoutBias { - uint32_t _rows; - uint32_t _cols; - uint64_t _seed; - float _dropout_prob; - bool _is_upscale_in_train; - bool _is_test; // default false, Set to true for inference only - bool _has_bias = true; - framework::Tensor _src, _residual, _bias, _out, _mask; - framework::Tensor _dsrc, _dbias; - - std::vector _src_vec, _residual_vec, _bias_vec; - std::vector _correct_out, _correct_dsrc, _correct_dbias; - std::vector _correct_mask; - - platform::CUDAPlace _place; - platform::CUDADeviceContext *_ctx; + uint32_t rows; + uint32_t cols; + uint64_t seed; + float dropout_prob; + bool is_upscale_in_train; + bool is_test; // default false, Set to true for inference only + bool hasbias = true; + framework::Tensor src, residual, bias, out, mask; + framework::Tensor dsrc, dbias; + + std::vector src_vec, residual_vec, bias_vec; + std::vector correct_out, correct_dsrc, correct_dbias; + std::vector correct_mask; + + platform::CUDAPlace place; + platform::CUDADeviceContext *ctx; TestFusedResidualDropoutBias() { - _rows = 32; - _cols = 32; - _seed = 0; - _dropout_prob = 0.0; - _is_upscale_in_train = false; - _is_test = false; - _has_bias = true; - _ctx = new platform::CUDADeviceContext(_place); + rows = 32; + cols = 32; + seed = 0; + dropout_prob = 0.0; + is_upscale_in_train = false; + is_test = false; + hasbias = true; + // ctx = new platform::CUDADeviceContext(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto devicectx = pool.Get(place); + ctx = reinterpret_cast(devicectx); } - TestFusedResidualDropoutBias(int rows, int cols, uint64_t seed = 0, - float dropout_prob = 0.0, - bool is_upscale_in_train = false, - bool is_test = false) { - _rows = rows; - _cols = cols; - _seed = seed; - _dropout_prob = dropout_prob; - _is_upscale_in_train = is_upscale_in_train; - _is_test = is_test; - _has_bias = true; - _ctx = new platform::CUDADeviceContext(_place); + TestFusedResidualDropoutBias(int rows_, int cols_, uint64_t seed_ = 0, + float dropout_prob_ = 0.0, + bool is_upscale_in_train_ = false, + bool is_test_ = false) { + rows = rows_; + cols = cols_; + seed = seed_; + dropout_prob = dropout_prob_; + is_upscale_in_train = is_upscale_in_train_; + is_test = is_test_; + hasbias = true; + // ctx = new platform::CUDADeviceContext(place); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto devicectx = pool.Get(place); + ctx = reinterpret_cast(devicectx); } - ~TestFusedResidualDropoutBias() { delete _ctx; } + ~TestFusedResidualDropoutBias() {} void SetUp() { - const int n = _rows * _cols; - _correct_out.resize(n); - _correct_mask.resize(n); - _correct_dsrc.resize(n); - _correct_dbias.resize(_cols); - - _src_vec.resize(n); - _residual_vec.resize(n); - _bias_vec.resize(_cols); + const int n = rows * cols; + correct_out.resize(n); + correct_mask.resize(n); + correct_dsrc.resize(n); + correct_dbias.resize(cols); + + src_vec.resize(n); + residual_vec.resize(n); + bias_vec.resize(cols); std::default_random_engine random(time(NULL)); std::uniform_real_distribution dis(0.0, 1.0); - for (int i = 0; i < _rows; i++) { - for (int j = 0; j < _cols; j++) { - _src_vec[i * _cols + j] = static_cast(dis(random)); - _residual_vec[i * _cols + j] = static_cast(dis(random)); - if (i == 0) _bias_vec[j] = dis(random); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + src_vec[i * cols + j] = 
static_cast(dis(random)); + residual_vec[i * cols + j] = static_cast(dis(random)); + if (i == 0) bias_vec[j] = dis(random); } } - framework::TensorFromVector(_src_vec, *_ctx, &_src); - _src.Resize({_rows, _cols}); - framework::TensorFromVector(_residual_vec, *_ctx, &_residual); - _residual.Resize({_rows, _cols}); - if (_has_bias) { - framework::TensorFromVector(_bias_vec, *_ctx, &_bias); - _bias.Resize({_cols}); + framework::TensorFromVector(src_vec, *ctx, &src); + src.Resize({rows, cols}); + framework::TensorFromVector(residual_vec, *ctx, &residual); + residual.Resize({rows, cols}); + if (hasbias) { + framework::TensorFromVector(bias_vec, *ctx, &bias); + bias.Resize({cols}); } { - _out.Resize({_rows, _cols}); - _out.mutable_data(_place); - _mask.Resize({_rows, _cols}); - _mask.mutable_data(_place); - _dsrc.Resize({_rows, _cols}); - _dsrc.mutable_data(_place); - - if (_has_bias) { - _dbias.Resize({_cols}); - _dbias.mutable_data(_place); + out.Resize({rows, cols}); + out.mutable_data(place); + mask.Resize({rows, cols}); + mask.mutable_data(place); + dsrc.Resize({rows, cols}); + dsrc.mutable_data(place); + + if (hasbias) { + dbias.Resize({cols}); + dbias.mutable_data(place); } } } void BaseForward() { - std::vector out1(_rows * _cols), out2(_rows * _cols); - if (_has_bias) { + std::vector out1(rows * cols), out2(rows * cols); + if (hasbias) { // add bias - for (int i = 0; i < _rows; i++) { - for (int j = 0; j < _cols; j++) { - out1[i * _cols + j] = _src_vec[i * _cols + j] + _bias_vec[j]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + out1[i * cols + j] = src_vec[i * cols + j] + bias_vec[j]; } } // call dropout - Dropout(out1.data(), _src.dims(), out2.data(), &_correct_mask, *_ctx, - _seed, _dropout_prob, _is_upscale_in_train, _is_test); + Dropout(out1, src.dims(), &out2, &correct_mask, *ctx, seed, + dropout_prob, is_upscale_in_train, is_test); } else { - Dropout(_src_vec.data(), _src.dims(), out2.data(), &_correct_mask, - *_ctx, _seed, _dropout_prob, _is_upscale_in_train, _is_test); + Dropout(src_vec, src.dims(), &out2, &correct_mask, *ctx, seed, + dropout_prob, is_upscale_in_train, is_test); } + ctx->Wait(); // add residual - for (int i = 0; i < _rows; i++) { - for (int j = 0; j < _cols; j++) { - _correct_out[i * _cols + j] = - _residual_vec[i * _cols + j] + out2[i * _cols + j]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + correct_out[i * cols + j] = + residual_vec[i * cols + j] + out2[i * cols + j]; } } - _ctx->Wait(); } void BaseBackward() { - DropoutGrad(_correct_dsrc.data(), _src.dims(), _correct_out.data(), - _correct_mask.data(), *_ctx, _dropout_prob, - _is_upscale_in_train); + DropoutGrad(correct_dsrc.data(), src.dims(), correct_out.data(), + correct_mask.data(), *ctx, dropout_prob, + is_upscale_in_train); // calc dbias - memset(&_correct_dbias[0], 0, _cols * sizeof(T)); - for (int i = 0; i < _rows; i++) { - for (int j = 0; j < _cols; j++) { - _correct_dbias[j] += _correct_out[i * _cols + j]; + memset(&correct_dbias[0], 0, cols * sizeof(T)); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + correct_dbias[j] += correct_out[i * cols + j]; } } } void FusedForward() { auto threads = paddle::operators::Get1DBlocksAnd2DGrids( - *_ctx, (uint64_t)_rows, (uint64_t)_cols); + *ctx, (uint64_t)rows, (uint64_t)cols); const int VecSize = 4; const int increment = - ((_cols - 1) / (threads.first.x * threads.second.x * VecSize) + 1) * + ((cols - 1) / (threads.first.x * threads.second.x * VecSize) + 1) * VecSize; T *bias_ptr = 
nullptr; - if (_has_bias) { - bias_ptr = _bias.data(); + if (hasbias) { + bias_ptr = bias.data(); } paddle::operators::LaunchResidualDropoutBias( - _rows, _cols, increment, _seed, _dropout_prob, _is_test, - _is_upscale_in_train, _src.data(), _residual.data(), bias_ptr, - _mask.data(), _out.data(), *_ctx); - _ctx->Wait(); + rows, cols, increment, seed, dropout_prob, is_test, is_upscale_in_train, + src.data(), residual.data(), bias_ptr, mask.data(), + out.data(), *ctx); + ctx->Wait(); } void FusedBackward() { - if (_is_test) return; + if (is_test) return; T *bias_ptr = nullptr; - if (_has_bias) { - bias_ptr = _dbias.data(); + if (hasbias) { + bias_ptr = dbias.data(); } paddle::operators::LaunchResidualDropoutBiasGrad( - _out.data(), _mask.data(), _dropout_prob, - _is_upscale_in_train, _rows, _cols, _dsrc.data(), bias_ptr, *_ctx); + out.data(), mask.data(), dropout_prob, is_upscale_in_train, + rows, cols, dsrc.data(), bias_ptr, *ctx); } void Run() { @@ -201,43 +207,39 @@ struct TestFusedResidualDropoutBias { } void CheckOut(const T diff) { - const int n = _rows * _cols; - std::vector out(n); - std::vector mask(n); - cudaMemcpy(out.data(), _out.data(), _rows * _cols * sizeof(T), - cudaMemcpyDeviceToHost); - if (!_is_test) { - cudaMemcpy(mask.data(), _mask.data(), - _rows * _cols * sizeof(uint8_t), cudaMemcpyDeviceToHost); + const int n = rows * cols; + std::vector _out(n); + std::vector _mask(n); + framework::TensorToVector(out, *ctx, &_out); + if (!is_test) { + framework::TensorToVector(mask, *ctx, &_mask); } - _ctx->Wait(); + ctx->Wait(); for (int i = 0; i < n; i++) { - EXPECT_LT(std::abs(out[i] - _correct_out[i]), diff); - if (!_is_test) EXPECT_EQ(mask[i], _correct_mask[i]); + EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff); + if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]); } } void CheckGrad(const T diff) { - if (_is_test) return; + if (is_test) return; - const int n = _rows * _cols; + const int n = rows * cols; - std::vector dsrc(n); - cudaMemcpy(dsrc.data(), _dsrc.data(), _rows * _cols * sizeof(T), - cudaMemcpyDeviceToHost); + std::vector _dsrc(n); + framework::TensorToVector(dsrc, *ctx, &_dsrc); for (int i = 0; i < n; i++) { - EXPECT_LT(std::abs(dsrc[i] - _correct_dsrc[i]), diff); + EXPECT_LT(std::abs(_dsrc[i] - correct_dsrc[i]), diff); } - if (_has_bias) { - std::vector dbias(_cols); - cudaMemcpy(dbias.data(), _dbias.data(), _cols * sizeof(T), - cudaMemcpyDeviceToHost); - _ctx->Wait(); - for (int i = 0; i < _cols; i++) { - EXPECT_LT(std::abs(dbias[i] - _correct_dbias[i]), diff); + if (hasbias) { + std::vector _dbias(cols); + framework::TensorToVector(dbias, *ctx, &_dbias); + ctx->Wait(); + for (int i = 0; i < cols; i++) { + EXPECT_LT(std::abs(_dbias[i] - correct_dbias[i]), diff); } } } @@ -261,7 +263,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { test.CheckGrad(static_cast(1e-5)); } -// test fp16, For inference, check_grad is not required. ref: test_dropout_op.py +// test fp16, For inference, check_grad is not required. 
ref: testdropout_op.py TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) { const int rows = 16; const int cols = 16; @@ -275,7 +277,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasNoBias) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols); - test._has_bias = false; + test.hasbias = false; test.Run(); test.CheckOut(static_cast(1e-5)); test.CheckGrad(static_cast(1e-5)); @@ -286,7 +288,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBiasNoBias2) { const int rows = 16; const int cols = 17; TestFusedResidualDropoutBias test(rows, cols); - test._has_bias = false; + test.hasbias = false; test.Run(); test.CheckOut(static_cast(1e-5)); test.CheckGrad(static_cast(1e-5)); From bd44d043d24b5f76c2f07da6bb0a7a52788c1ace Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Mon, 30 Aug 2021 11:27:40 +0000 Subject: [PATCH 07/18] replace cudaMemcpy with TensorFromVector and TensorToVector in DropoutGrad --- .../operators/fused/fused_dropout_test.h | 18 +++++++----------- .../fused/fused_residual_dropout_bias_test.cu | 19 +++++++++++-------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index 4a5e088d2013b8..e9fd0e6c097851 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -79,24 +79,21 @@ void Dropout(const std::vector &x, const framework::DDim &x_dim, * @brief call paddle dropout_grad op */ template -void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout, - const uint8_t *mask, const platform::CUDADeviceContext &ctx, - float dropout_prob, bool is_upscale_in_train) { +void DropoutGrad(std::vector *dx, const framework::DDim &x_dim, + const std::vector &dout, const std::vector &mask, + const platform::CUDADeviceContext &ctx, float dropout_prob, + bool is_upscale_in_train) { framework::Scope scope; const size_t n = x_dim[0] * x_dim[1]; auto var_out = scope.Var("DOut"); auto tensor_out = var_out->GetMutable(); + framework::TensorFromVector(dout, ctx, tensor_out); tensor_out->Resize(x_dim); - tensor_out->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_out->data(), dout, n * sizeof(T), - cudaMemcpyHostToDevice); auto var_mask = scope.Var("Mask"); auto tensor_mask = var_mask->GetMutable(); + framework::TensorFromVector(mask, ctx, tensor_mask); tensor_mask->Resize(x_dim); - tensor_mask->mutable_data(ctx.GetPlace()); - cudaMemcpy(tensor_mask->data(), mask, n * sizeof(uint8_t), - cudaMemcpyHostToDevice); auto var_dx = scope.Var("DX"); auto tensor_dx = var_dx->GetMutable(); @@ -115,7 +112,6 @@ void DropoutGrad(T *dx, const framework::DDim &x_dim, const T *dout, {{"X@GRAD", {"DX"}}}, attrs); op->Run(scope, ctx.GetPlace()); - cudaMemcpy(dx, tensor_dx->data(), x_dim[0] * x_dim[1] * sizeof(T), - cudaMemcpyDeviceToHost); + framework::TensorToVector(*tensor_dx, ctx, dx); ctx.Wait(); } diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index 88438e6e0c36e5..14267974ff2657 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -58,7 +58,6 @@ struct TestFusedResidualDropoutBias { is_upscale_in_train = false; is_test = false; hasbias = true; - // ctx = new platform::CUDADeviceContext(place); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto devicectx = pool.Get(place); ctx 
= reinterpret_cast(devicectx); @@ -75,7 +74,6 @@ struct TestFusedResidualDropoutBias { is_upscale_in_train = is_upscale_in_train_; is_test = is_test_; hasbias = true; - // ctx = new platform::CUDADeviceContext(place); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto devicectx = pool.Get(place); ctx = reinterpret_cast(devicectx); @@ -100,7 +98,9 @@ struct TestFusedResidualDropoutBias { for (int j = 0; j < cols; j++) { src_vec[i * cols + j] = static_cast(dis(random)); residual_vec[i * cols + j] = static_cast(dis(random)); - if (i == 0) bias_vec[j] = dis(random); + if (i == 0) { + bias_vec[j] = dis(random); + } } } @@ -155,9 +155,8 @@ struct TestFusedResidualDropoutBias { } void BaseBackward() { - DropoutGrad(correct_dsrc.data(), src.dims(), correct_out.data(), - correct_mask.data(), *ctx, dropout_prob, - is_upscale_in_train); + DropoutGrad(&correct_dsrc, src.dims(), correct_out, correct_mask, *ctx, + dropout_prob, is_upscale_in_train); // calc dbias memset(&correct_dbias[0], 0, cols * sizeof(T)); for (int i = 0; i < rows; i++) { @@ -187,7 +186,9 @@ struct TestFusedResidualDropoutBias { } void FusedBackward() { - if (is_test) return; + if (is_test) { + return; + } T *bias_ptr = nullptr; if (hasbias) { @@ -223,7 +224,9 @@ struct TestFusedResidualDropoutBias { } void CheckGrad(const T diff) { - if (is_test) return; + if (is_test) { + return; + } const int n = rows * cols; From d2beab70c620a19a6e26b5e8c16a4f537882fc68 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Tue, 31 Aug 2021 02:19:33 +0000 Subject: [PATCH 08/18] set dropout attr 'is_test':false --- paddle/fluid/operators/fused/fused_dropout_test.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index e9fd0e6c097851..288b415aef31f9 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -61,7 +61,7 @@ void Dropout(const std::vector &x, const framework::DDim &x_dim, } if (is_test) { - attrs.insert({"is_test", 1}); + attrs.insert({"is_test", true}); } auto op = framework::OpRegistry::CreateOp( @@ -100,7 +100,7 @@ void DropoutGrad(std::vector *dx, const framework::DDim &x_dim, framework::AttributeMap attrs; attrs.insert({"dropout_prob", dropout_prob}); - attrs.insert({"is_test", 0}); + attrs.insert({"is_test", false}); if (is_upscale_in_train) { attrs.insert({"dropout_implementation", std::string("upscale_in_train")}); } else { From 5d2bbc889b8d3774fa01003f325fe012da2a11c8 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Thu, 2 Sep 2021 09:22:14 +0000 Subject: [PATCH 09/18] optimize the code according to the review comments --- .../operators/fused/fused_dropout_common.h | 73 ++++++--- .../fused/fused_residual_dropout_bias.h | 142 +++++++++--------- .../fused/fused_residual_dropout_bias_test.cu | 129 ++++++---------- 3 files changed, 167 insertions(+), 177 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index 755153bb07eee9..53ce76826a6793 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -22,48 +22,75 @@ limitations under the License. 
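The attribute change in PATCH 08 above matters because the dropout op declares `is_test` as a boolean attribute, so inserting the integer literals 0/1 produced an attribute of the wrong type. A minimal sketch of how the test helper assembles the attribute map before creating the op (values are illustrative; only calls already present in fused_dropout_test.h are used):

  framework::AttributeMap attrs;
  attrs.insert({"dropout_prob", dropout_prob});
  attrs.insert({"is_test", false});  // a bool, not 0/1, to match the op definition
  attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
  // The op is then created via framework::OpRegistry::CreateOp("dropout_grad", ...)
  // and executed with op->Run(scope, ctx.GetPlace()), exactly as in DropoutGrad above.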
*/ #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { -/** - * get 1D threads and blocks - */ -template -inline std::pair Get1DThreadsAndBlocks( - const platform::CUDADeviceContext &ctx, const uint64_t n) { - const uint64_t tmp_n = n / VecSize; - int threads = std::max( - (uint64_t)32, std::min(tmp_n, (uint64_t)ctx.GetMaxThreadsPerBlock())); - int blocks = std::max((uint64_t)1, (tmp_n + threads - 1) / threads); - return std::pair{threads, blocks}; -} +#define MAX_CACHE_BYTES 16 /** * get the threads for fused_residual_dropout_bias: * 1D blocks: blockDim.x = cols * 2D grids: gridDim.y = rows */ -template -inline std::pair Get1DBlocksAnd2DGrids( +inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids( const platform::CUDADeviceContext &ctx, const uint32_t rows, - const uint32_t cols) { + const uint32_t cols, const int VecSize) { const uint32_t tmp_cols = cols / VecSize; int threads = std::max( (uint32_t)32, std::min(tmp_cols, (uint32_t)ctx.GetMaxThreadsPerBlock())); int blocks_x = std::max((uint32_t)1, (tmp_cols + threads - 1) / threads); int blocks_y = std::max((uint32_t)1, rows); - dim3 block_dim(threads, 1, 1); - dim3 grid_dim(blocks_x, blocks_y, 1); - return std::pair{block_dim, grid_dim}; + platform::GpuLaunchConfig config; + config.block_per_grid.x = blocks_x; + config.block_per_grid.y = blocks_y; + config.thread_per_block.x = threads; + return config; } -// aligned vector generates vectorized load/store on CUDA -template -struct alignas(sizeof(T) * VecSize) AlignedVector { - T val[VecSize]; -}; +__forceinline__ __device__ void Rand1(curandStatePhilox4_32_10_t *state, + float *data) { + data[0] = curand_uniform(state); +} + +__forceinline__ __device__ void Rand2(curandStatePhilox4_32_10_t *state, + float *data) { + data[0] = curand_uniform(state); + data[1] = curand_uniform(state); +} + +__forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state, + float *data) { + float4 rand4 = curand_uniform4(state); + data[0] = rand4.x; + data[1] = rand4.y; + data[2] = rand4.w; + data[3] = rand4.z; +} + +__forceinline__ __device__ void Rand8(curandStatePhilox4_32_10_t *state, + float *data) { + Rand4(state, data); + Rand4(state, data + 4); +} + +__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state, + float *data, const int VecSize) { + if (VecSize == 1) { + Rand1(state, data); + } else if (VecSize == 2) { + Rand2(state, data); + } else if (VecSize == 4) { + Rand4(state, data); + } else if (VecSize == 8) { + Rand8(state, data); + } else { + return; + } +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index eda633380e07a2..bafc8c60040c1e 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -20,28 +20,19 @@ limitations under the License. 
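For orientation, a short worked example of the configuration Get1DBlocksAnd2DGrids returns. This is a sketch assuming float data (VecSize = MAX_CACHE_BYTES / sizeof(float) = 4) and a device limit of 1024 threads per block; the shape matches the large-shape unit test added later in this series.

  // rows = 256, cols = 4096, VecSize = 4:
  //   tmp_cols = 4096 / 4 = 1024
  //   threads  = max(32, min(1024, 1024)) = 1024
  //   blocks_x = max(1, (1024 + 1023) / 1024) = 1
  //   blocks_y = max(1, 256) = 256
  auto config = Get1DBlocksAnd2DGrids(ctx, 256, 4096, 4);
  // => thread_per_block = (1024, 1, 1), block_per_grid = (1, 256, 1):
  //    threads stride over the columns of one row, blockIdx.y strides over rows.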
*/ namespace paddle { namespace operators { -namespace platform = paddle::platform; -namespace cg = cooperative_groups; -namespace memory = paddle::memory; - -/** - * @brief fused the add_bias, dropout, add residual into one operators - * - */ - -/********Forward**************/ /** - * @brief the fused function called by every thread + * @brief The fused function called by every thread + * VecSize can be 1, 2, 4 or 8 */ template -__forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread( +__forceinline__ __device__ void FusedResidualDropoutBiasOneThread( const int row_id, const int col_id, const int cols, curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor, const T *src, const T *residual, const T *bias, T *dst, MaskType *mask, typename details::MPTypeTrait::Type *mean_val, typename details::MPTypeTrait::Type *var_val) { - using LoadT = AlignedVector; - using MaskLoadT = AlignedVector; + using LoadT = platform::CudaAlignedVector; + using MaskLoadT = platform::CudaAlignedVector; using U = typename details::MPTypeTrait::Type; T src_vec[VecSize]; @@ -60,16 +51,19 @@ __forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread( LoadT *bias_value = bias != nullptr ? reinterpret_cast(&bias_vec) : nullptr; - if (bias != nullptr) + if (bias) { *bias_value = *reinterpret_cast(&bias[col_id]); + } + + float rand[VecSize]; + RandVec(state, rand, VecSize); - float4 rand = curand_uniform4(state); T dest_vec[VecSize]; MaskType mask_vec[VecSize]; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { - mask_vec[ii] = (MaskType)((&rand.x)[ii] >= dropout_prob); + mask_vec[ii] = static_cast(rand[ii] >= dropout_prob); } #pragma unroll @@ -97,13 +91,13 @@ __forceinline__ __device__ void FusedResidualDropoutBiasVecOneThread( * the bias shape is (1, cols) */ template -__global__ void FusedResidualDropoutBiasVec(const size_t rows, - const size_t cols, uint64_t seed, - const float dropout_prob, - const bool is_upscale_in_train, - const T *src, const T *residual, - const T *bias, MaskType *mask, - T *dst, uint64_t increment) { +__global__ void FusedResidualDropoutBias(const size_t rows, const size_t cols, + uint64_t seed, + const float dropout_prob, + const bool is_upscale_in_train, + const T *src, const T *residual, + const T *bias, MaskType *mask, T *dst, + uint64_t increment) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; @@ -117,9 +111,9 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { - FusedResidualDropoutBiasVecOneThread( + FusedResidualDropoutBiasOneThread( r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, - mask, NULL, NULL); + mask, nullptr, nullptr); } } } @@ -127,12 +121,14 @@ __global__ void FusedResidualDropoutBiasVec(const size_t rows, /** * @brief the fused function called by every thread */ -template -__forceinline__ __device__ void FusedResidualDropoutBiasOnlyInferVecOneThread( +template +__forceinline__ __device__ void FusedResidualDropoutBiasOnlyInferOneThread( const int row_id, const int col_id, const int cols, const float dropout_prob, const T factor, const T *src, const T *residual, - const T *bias, T *dst, U *mean_val, U *var_val) { - using LoadT = AlignedVector; + const T *bias, T *dst, typename details::MPTypeTrait::Type *mean_val, + typename details::MPTypeTrait::Type *var_val) { + using LoadT = 
platform::CudaAlignedVector; + using U = typename details::MPTypeTrait::Type; T src_vec[VecSize]; T residual_vec[VecSize]; T bias_vec[VecSize]; @@ -149,15 +145,16 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOnlyInferVecOneThread( LoadT *bias_value = bias != nullptr ? reinterpret_cast(&bias_vec) : nullptr; - if (bias != nullptr) + if (bias) { *bias_value = *reinterpret_cast(&bias[col_id]); + } T dest_vec[VecSize]; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii]; - if (layer_norm) { + if (ComputeLayerNorm) { U tmp = static_cast(dest_vec[ii]); *mean_val += tmp; *var_val += (tmp * tmp); @@ -175,7 +172,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOnlyInferVecOneThread( * the bias shape is (1, cols) */ template -__global__ void FusedResidualDropoutBiasOnlyInferVec( +__global__ void FusedResidualDropoutBiasOnlyInfer( const size_t rows, const size_t cols, const float dropout_prob, const bool is_upscale_in_train, const T *src, const T *residual, const T *bias, T *dst) { @@ -191,7 +188,7 @@ __global__ void FusedResidualDropoutBiasOnlyInferVec( for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { - FusedResidualDropoutBiasOnlyInferVecOneThread( + FusedResidualDropoutBiasOnlyInferOneThread( r, i, cols, dropout_prob, factor, src, residual, bias, dst, nullptr, nullptr); } @@ -220,48 +217,46 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, return; } - const int VecSize = 4; - auto threads = Get1DBlocksAnd2DGrids(ctx, rows, cols); + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, VecSize); if (cols % VecSize != 0) { if (!is_test) { - FusedResidualDropoutBiasVec< - T, uint8_t, 1><<>>( + FusedResidualDropoutBias<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, bias, mask_data, dst, increment); } else { - FusedResidualDropoutBiasOnlyInferVec< - T, 1><<>>( + FusedResidualDropoutBiasOnlyInfer<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, dst); } } else { if (!is_test) { - FusedResidualDropoutBiasVec< - T, uint8_t, - VecSize><<>>( + FusedResidualDropoutBias<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, bias, mask_data, dst, increment); } else { - FusedResidualDropoutBiasOnlyInferVec< - T, VecSize><<>>( + FusedResidualDropoutBiasOnlyInfer<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, dst); } } } -/********Backward**************/ /* * @brief calculate the grad of no bias */ template -__global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, - const T factor, const int64_t size, - T *dx) { +__global__ void FusedResidualDropoutGrad(const T *dout, const MaskType *mask, + const T factor, const int64_t size, + T *dx) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - using LoadT = AlignedVector; - using MaskLoadT = AlignedVector; + using LoadT = platform::CudaAlignedVector; + using MaskLoadT = platform::CudaAlignedVector; for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { T dout_vec[VecSize]; MaskType 
mask_vec[VecSize]; @@ -286,15 +281,17 @@ __global__ void FusedResidualDropoutGradVec(const T *dout, const MaskType *mask, * 2. save 128*8 temporary sum in 8*128 shared memory * 3. reduce the sum of 128 rows data by 8*VecSize warps */ -template -__global__ void FusedResidualDropoutBiasGradVec( - const T *dout, const MaskType *mask, const T factor, const int64_t rows, - const int64_t cols, T *dx, T *dbias) { +__global__ void FusedResidualDropoutBiasGrad(const T *dout, + const MaskType *mask, + const T factor, const int64_t rows, + const int64_t cols, T *dx, + T *dbias) { int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; - using LoadT = AlignedVector; - using MaskLoadT = AlignedVector; + using LoadT = platform::CudaAlignedVector; + using MaskLoadT = platform::CudaAlignedVector; T tmp_sum[VecSize] = {static_cast(0)}; // calculate the dx and temporary sum @@ -321,7 +318,7 @@ __global__ void FusedResidualDropoutBiasGradVec( } // save temporary sum to cache and do transpose - __shared__ T cache[BLOCK_SIZE_X * VecSize][BLOCK_SIZE_Y]; + __shared__ T cache[BlockSizeX * VecSize][BlockSizeY]; for (int i = 0; i < VecSize; i++) { cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; } @@ -333,11 +330,11 @@ __global__ void FusedResidualDropoutBiasGradVec( int x = tid >> 5; // warp id int y = tid & 31; // thread id on warp 0~31 - // need BLOCK_SIZE_X * VecSize warps - if (x < BLOCK_SIZE_X * VecSize) { + // need BlockSizeX * VecSize warps + if (x < BlockSizeX * VecSize) { // reduce 128 to 32 #pragma unroll - for (int i = 0; i < (BLOCK_SIZE_Y >> 5); i++) { + for (int i = 0; i < (BlockSizeY >> 5); i++) { sum += cache[x][y + i * 32]; } } @@ -347,7 +344,7 @@ __global__ void FusedResidualDropoutBiasGradVec( // save sum to dbias int bias_id = blockIdx.x * blockDim.x * VecSize + x; - if (y == 0 && x < VecSize * BLOCK_SIZE_X && bias_id < cols) { + if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) { dbias[bias_id] = sum; } } @@ -370,7 +367,7 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, factor = static_cast(1.0f); } - const int VecSize = 4; + const int VecSize = MAX_CACHE_BYTES / sizeof(T); if (dbias != nullptr) { int real_vec_size = VecSize; if (cols % VecSize != 0) { @@ -384,26 +381,27 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, dim3 grid_dim(blocks, 1, 1); if (cols % VecSize == 0) { - FusedResidualDropoutBiasGradVec< + FusedResidualDropoutBiasGrad< T, MaskType, 8, 128, VecSize><<>>( dout, mask, factor, rows, cols, dx, dbias); } else { - FusedResidualDropoutBiasGradVec< - T, MaskType, 8, 128, 1><<>>( + FusedResidualDropoutBiasGrad<<>>( dout, mask, factor, rows, cols, dx, dbias); } } else { const uint64_t n = rows * cols; - auto threads = Get1DThreadsAndBlocks(ctx, n); if (n % VecSize == 0) { - FusedResidualDropoutGradVec< - T, MaskType, - VecSize><<>>( + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx, n / VecSize); + FusedResidualDropoutGrad<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( dout, mask, factor, n, dx); } else { - FusedResidualDropoutGradVec< - T, MaskType, 1><<>>( + platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(ctx, n); + FusedResidualDropoutGrad<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( dout, mask, factor, n, dx); } } diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index 14267974ff2657..b246d9bac9761b 100644 --- 
a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -39,7 +39,7 @@ struct TestFusedResidualDropoutBias { float dropout_prob; bool is_upscale_in_train; bool is_test; // default false, Set to true for inference only - bool hasbias = true; + bool has_bias = true; framework::Tensor src, residual, bias, out, mask; framework::Tensor dsrc, dbias; @@ -57,10 +57,10 @@ struct TestFusedResidualDropoutBias { dropout_prob = 0.0; is_upscale_in_train = false; is_test = false; - hasbias = true; + has_bias = true; platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto devicectx = pool.Get(place); - ctx = reinterpret_cast(devicectx); + auto device_ctx = pool.Get(place); + ctx = reinterpret_cast(device_ctx); } TestFusedResidualDropoutBias(int rows_, int cols_, uint64_t seed_ = 0, @@ -73,10 +73,10 @@ struct TestFusedResidualDropoutBias { dropout_prob = dropout_prob_; is_upscale_in_train = is_upscale_in_train_; is_test = is_test_; - hasbias = true; + has_bias = true; platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto devicectx = pool.Get(place); - ctx = reinterpret_cast(devicectx); + auto device_ctx = pool.Get(place); + ctx = reinterpret_cast(device_ctx); } ~TestFusedResidualDropoutBias() {} @@ -108,7 +108,7 @@ struct TestFusedResidualDropoutBias { src.Resize({rows, cols}); framework::TensorFromVector(residual_vec, *ctx, &residual); residual.Resize({rows, cols}); - if (hasbias) { + if (has_bias) { framework::TensorFromVector(bias_vec, *ctx, &bias); bias.Resize({cols}); } @@ -121,7 +121,7 @@ struct TestFusedResidualDropoutBias { dsrc.Resize({rows, cols}); dsrc.mutable_data(place); - if (hasbias) { + if (has_bias) { dbias.Resize({cols}); dbias.mutable_data(place); } @@ -130,7 +130,7 @@ struct TestFusedResidualDropoutBias { void BaseForward() { std::vector out1(rows * cols), out2(rows * cols); - if (hasbias) { + if (has_bias) { // add bias for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { @@ -167,15 +167,16 @@ struct TestFusedResidualDropoutBias { } void FusedForward() { - auto threads = paddle::operators::Get1DBlocksAnd2DGrids( - *ctx, (uint64_t)rows, (uint64_t)cols); - const int VecSize = 4; - const int increment = - ((cols - 1) / (threads.first.x * threads.second.x * VecSize) + 1) * - VecSize; + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + auto config = paddle::operators::Get1DBlocksAnd2DGrids( + *ctx, (uint64_t)rows, (uint64_t)cols, VecSize); + const int increment = ((cols - 1) / (config.thread_per_block.x * + config.block_per_grid.x * VecSize) + + 1) * + VecSize; T *bias_ptr = nullptr; - if (hasbias) { + if (has_bias) { bias_ptr = bias.data(); } paddle::operators::LaunchResidualDropoutBias( @@ -191,7 +192,7 @@ struct TestFusedResidualDropoutBias { } T *bias_ptr = nullptr; - if (hasbias) { + if (has_bias) { bias_ptr = dbias.data(); } paddle::operators::LaunchResidualDropoutBiasGrad( @@ -237,7 +238,7 @@ struct TestFusedResidualDropoutBias { EXPECT_LT(std::abs(_dsrc[i] - correct_dsrc[i]), diff); } - if (hasbias) { + if (has_bias) { std::vector _dbias(cols); framework::TensorToVector(dbias, *ctx, &_dbias); ctx->Wait(); @@ -248,66 +249,39 @@ struct TestFusedResidualDropoutBias { } }; -TEST(FusedDropout, GPUFusedResidualDropoutBias) { +// test the shape and bias +template +static void BaseTest(const bool is_fp16 = false) { const int rows = 16; - const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols); - test.Run(); - 
test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); + std::vector cols_list = {16, 17}; + bool has_bias[2] = {true, false}; + T default_diff = static_cast(1e-5); + if (is_fp16) { + default_diff = static_cast(1e-2); + } + for (int i = 0; i < cols_list.size(); i++) { + for (int j = 0; j < 2; j++) { + TestFusedResidualDropoutBias test(rows, cols_list[i]); + test.has_bias = has_bias[j]; + test.Run(); + test.CheckOut(default_diff); + if (!is_fp16) { + test.CheckGrad(default_diff); + } + } + } } -TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { - const int rows = 16; - const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols); - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); -} +TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest(); } + +TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest(); } // test fp16, For inference, check_grad is not required. ref: testdropout_op.py TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) { - const int rows = 16; - const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols); - test.Run(); - test.CheckOut(static_cast(1e-2)); -} - -// test no bias and cols % 4 == 0 -TEST(FusedDropout, GPUFusedResidualDropoutBiasNoBias) { - const int rows = 16; - const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols); - test.hasbias = false; - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); -} - -// test no bias and cols % 4 != 0 -TEST(FusedDropout, GPUFusedResidualDropoutBiasNoBias2) { - const int rows = 16; - const int cols = 17; - TestFusedResidualDropoutBias test(rows, cols); - test.hasbias = false; - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); + BaseTest(true); } -// test add bias and cols % 4 != 0 TEST(FusedDropout, GPUFusedResidualDropoutBias2) { - const int rows = 16; - const int cols = 17; - TestFusedResidualDropoutBias test(rows, cols); - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); -} - -TEST(FusedDropout, GPUFusedResidualDropoutBias3) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, false, false); @@ -316,16 +290,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias3) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedResidualDropoutBias4) { - const int rows = 16; - const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, false, false); - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); -} - -TEST(FusedDropout, GPUFusedResidualDropoutBias5) { +TEST(FusedDropout, GPUFusedResidualDropoutBias3) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, true, false); @@ -334,7 +299,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias5) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedResidualDropoutBias6) { +TEST(FusedDropout, GPUFusedResidualDropoutBias4) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 0.35, true, true); @@ -343,7 +308,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias6) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedResidualDropoutBias7) { +TEST(FusedDropout, GPUFusedResidualDropoutBias5) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 125, 0.0, false, false); From 934fcac6826d7559c71109df54ecdc7d3b89e81d Mon Sep 17 00:00:00 2001 From: zkh2016 
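The TEST cases above vary dropout_prob, is_upscale_in_train and is_test; the scaling they exercise can be summarized as follows. This only restates the factor selection already implemented in the forward kernels, with p = dropout_prob and mask values in {0, 1}:

  // training,  is_upscale_in_train = true : out = residual + (src + bias) * mask * 1/(1 - p)
  // training,  is_upscale_in_train = false: out = residual + (src + bias) * mask
  // inference, is_upscale_in_train = true : out = residual + (src + bias)
  // inference, is_upscale_in_train = false: out = residual + (src + bias) * (1 - p)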
Date: Thu, 2 Sep 2021 10:26:33 +0000 Subject: [PATCH 10/18] use static_cast --- paddle/fluid/operators/fused/fused_dropout_common.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index 53ce76826a6793..f159d16d855a7e 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -41,9 +41,11 @@ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids( const uint32_t cols, const int VecSize) { const uint32_t tmp_cols = cols / VecSize; int threads = std::max( - (uint32_t)32, std::min(tmp_cols, (uint32_t)ctx.GetMaxThreadsPerBlock())); - int blocks_x = std::max((uint32_t)1, (tmp_cols + threads - 1) / threads); - int blocks_y = std::max((uint32_t)1, rows); + static_cast(32), + std::min(tmp_cols, static_cast(ctx.GetMaxThreadsPerBlock()))); + const auto blocks_x = + std::max(static_cast(1), (tmp_cols + threads - 1) / threads); + const auto blocks_y = std::max(static_cast(1), rows); platform::GpuLaunchConfig config; config.block_per_grid.x = blocks_x; config.block_per_grid.y = blocks_y; From 44610ea2e651be31fbde3bddecd95ced540c6665 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Wed, 8 Sep 2021 02:00:15 +0000 Subject: [PATCH 11/18] fix the blocks for large shape --- .../operators/fused/fused_residual_dropout_bias.h | 5 ++--- .../fused/fused_residual_dropout_bias_test.cu | 10 ++++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index bafc8c60040c1e..952042d45f47c4 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -374,9 +374,8 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, real_vec_size = 1; } auto threads = std::min(cols / real_vec_size, static_cast(8)); - auto blocks = std::max( - (uint32_t)1, std::min((cols / real_vec_size + threads - 1) / threads, - (uint32_t)ctx.GetSMCount())); + auto blocks = + std::max((uint32_t)1, cols / real_vec_size + threads - 1 / threads); dim3 block_dim(threads, 128, 1); dim3 grid_dim(blocks, 1, 1); diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index b246d9bac9761b..e687823bc8158b 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -316,3 +316,13 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias5) { test.CheckOut(static_cast(1e-5)); test.CheckGrad(static_cast(1e-5)); } + +// test large shape +TEST(FusedDropout, GPUFusedResidualDropoutBias6) { + const int rows = 256; + const int cols = 4096; + TestFusedResidualDropoutBias test(rows, cols); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); +} From 1a83adb08e7cc1914b07b4e8ea3be6d521173a10 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Wed, 8 Sep 2021 03:43:53 +0000 Subject: [PATCH 12/18] merge upstream, and used new AlignedVector --- .../operators/fused/fused_dropout_common.h | 5 +- .../fused/fused_residual_dropout_bias.h | 252 ++++++------------ 2 files changed, 85 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index f159d16d855a7e..24f6f53c63630e 
100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -20,16 +20,17 @@ limitations under the License. */ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { -#define MAX_CACHE_BYTES 16 +#define CACHE_LINE 128 +#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT) /** * get the threads for fused_residual_dropout_bias: diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 952042d45f47c4..cd9dfd1c79ca8f 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -29,43 +29,45 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( const int row_id, const int col_id, const int cols, curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor, const T *src, const T *residual, const T *bias, T *dst, MaskType *mask, - typename details::MPTypeTrait::Type *mean_val, + const bool is_test, typename details::MPTypeTrait::Type *mean_val, typename details::MPTypeTrait::Type *var_val) { - using LoadT = platform::CudaAlignedVector; - using MaskLoadT = platform::CudaAlignedVector; + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + using MaskStoreT = platform::AlignedVector; using U = typename details::MPTypeTrait::Type; - T src_vec[VecSize]; - T residual_vec[VecSize]; - T bias_vec[VecSize]; + LoadT src_vec; + LoadT residual_vec; + LoadT bias_vec; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { bias_vec[ii] = static_cast(0); } // vectorize load data from global - LoadT *value = reinterpret_cast(&src_vec); - LoadT *residual_value = reinterpret_cast(&residual_vec); - *value = *reinterpret_cast(&src[row_id * cols + col_id]); - *residual_value = - *reinterpret_cast(&residual[row_id * cols + col_id]); - - LoadT *bias_value = - bias != nullptr ? 
reinterpret_cast(&bias_vec) : nullptr; + platform::Load(&src[row_id * cols + col_id], &src_vec); + platform::Load(&residual[row_id * cols + col_id], &residual_vec); + if (bias) { - *bias_value = *reinterpret_cast(&bias[col_id]); + platform::Load(&bias[col_id], &bias_vec); } - float rand[VecSize]; - RandVec(state, rand, VecSize); - - T dest_vec[VecSize]; - MaskType mask_vec[VecSize]; - + MaskStoreT mask_vec; + if (!is_test) { + float rand[VecSize]; + RandVec(state, rand, VecSize); #pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - mask_vec[ii] = static_cast(rand[ii] >= dropout_prob); + for (int ii = 0; ii < VecSize; ii++) { + mask_vec[ii] = static_cast(rand[ii] >= dropout_prob); + } + } else { +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + mask_vec[ii] = static_cast(1); + } } + StoreT dest_vec; + #pragma unroll for (int ii = 0; ii < VecSize; ii++) { dest_vec[ii] = @@ -79,25 +81,25 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( } // store result to global - *(reinterpret_cast(&dst[row_id * cols + col_id])) = - *reinterpret_cast(&dest_vec[0]); - *(reinterpret_cast(&mask[row_id * cols + col_id])) = - *reinterpret_cast(&mask_vec[0]); + platform::Store(dest_vec, &dst[row_id * cols + col_id]); + if (!is_test) { + platform::Store(mask_vec, &mask[row_id * cols + col_id]); + } } /** * @brief dst = residual + dropout(src + bias); * the src, residual, mask and dst shape is (rows, cols) * the bias shape is (1, cols) + * is_test: only used in inference + * mask: can be null if is_test=true */ template -__global__ void FusedResidualDropoutBias(const size_t rows, const size_t cols, - uint64_t seed, - const float dropout_prob, - const bool is_upscale_in_train, - const T *src, const T *residual, - const T *bias, MaskType *mask, T *dst, - uint64_t increment) { +__global__ void FusedResidualDropoutBias( + const size_t rows, const size_t cols, uint64_t seed, + const float dropout_prob, const bool is_upscale_in_train, const T *src, + const T *residual, const T *bias, MaskType *mask, T *dst, + uint64_t increment, const bool is_test) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; @@ -108,89 +110,18 @@ __global__ void FusedResidualDropoutBias(const size_t rows, const size_t cols, if (!is_upscale_in_train) { factor = static_cast(1.0f); } - for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { - for (int i = col_id * VecSize; i < cols; - i += blockDim.x * gridDim.x * VecSize) { - FusedResidualDropoutBiasOneThread( - r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, - mask, nullptr, nullptr); - } - } -} - -/** - * @brief the fused function called by every thread - */ -template -__forceinline__ __device__ void FusedResidualDropoutBiasOnlyInferOneThread( - const int row_id, const int col_id, const int cols, - const float dropout_prob, const T factor, const T *src, const T *residual, - const T *bias, T *dst, typename details::MPTypeTrait::Type *mean_val, - typename details::MPTypeTrait::Type *var_val) { - using LoadT = platform::CudaAlignedVector; - using U = typename details::MPTypeTrait::Type; - T src_vec[VecSize]; - T residual_vec[VecSize]; - T bias_vec[VecSize]; -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - bias_vec[ii] = static_cast(0); - } - // vectorize load data from global - LoadT *value = reinterpret_cast(&src_vec); - LoadT *residual_value = reinterpret_cast(&residual_vec); - *value = *reinterpret_cast(&src[row_id * cols + col_id]); - *residual_value = - 
*reinterpret_cast(&residual[row_id * cols + col_id]); - - LoadT *bias_value = - bias != nullptr ? reinterpret_cast(&bias_vec) : nullptr; - if (bias) { - *bias_value = *reinterpret_cast(&bias[col_id]); - } - - T dest_vec[VecSize]; - -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - dest_vec[ii] = (src_vec[ii] + bias_vec[ii]) * factor + residual_vec[ii]; - if (ComputeLayerNorm) { - U tmp = static_cast(dest_vec[ii]); - *mean_val += tmp; - *var_val += (tmp * tmp); + if (is_test) { + factor = static_cast(1.0f - dropout_prob); + if (is_upscale_in_train) { + factor = static_cast(1.0f); } } - - // store result to global - *(reinterpret_cast(&dst[row_id * cols + col_id])) = - *reinterpret_cast(&dest_vec[0]); -} - -/** - * @brief for dropout's param is_test = true, only used in inference - * the src, residual and dst shape is (rows, cols) - * the bias shape is (1, cols) - */ -template -__global__ void FusedResidualDropoutBiasOnlyInfer( - const size_t rows, const size_t cols, const float dropout_prob, - const bool is_upscale_in_train, const T *src, const T *residual, - const T *bias, T *dst) { - int col_id = blockDim.x * blockIdx.x + threadIdx.x; - int row_id = blockIdx.y; - int idx = row_id * cols + col_id; - - T factor = static_cast(1.0f - dropout_prob); - if (is_upscale_in_train) { - factor = static_cast(1.0f); - } - for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { - FusedResidualDropoutBiasOnlyInferOneThread( - r, i, cols, dropout_prob, factor, src, residual, bias, dst, nullptr, - nullptr); + FusedResidualDropoutBiasOneThread( + r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, + mask, is_test, nullptr, nullptr); } } } @@ -212,37 +143,27 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T), ctx.stream()); - PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( - mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + if (!is_test) { + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( + mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + } return; } const int VecSize = MAX_CACHE_BYTES / sizeof(T); - auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, VecSize); - if (cols % VecSize != 0) { - if (!is_test) { - FusedResidualDropoutBias<<< - config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( - rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, - bias, mask_data, dst, increment); - } else { - FusedResidualDropoutBiasOnlyInfer<<< - config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( - rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, - dst); - } + const int real_vec_size = cols % VecSize == 0 ? 
VecSize : 1; + auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size); + if (cols % VecSize == 0) { + FusedResidualDropoutBias<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, + bias, mask_data, dst, increment, is_test); } else { - if (!is_test) { - FusedResidualDropoutBias<<< - config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( - rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, - bias, mask_data, dst, increment); - } else { - FusedResidualDropoutBiasOnlyInfer<<< - config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( - rows, cols, dropout_prob, is_upscale_in_train, src, residual, bias, - dst); - } + FusedResidualDropoutBias< + T, uint8_t, + 1><<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual, + bias, mask_data, dst, increment, is_test); } } @@ -255,23 +176,21 @@ __global__ void FusedResidualDropoutGrad(const T *dout, const MaskType *mask, T *dx) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - using LoadT = platform::CudaAlignedVector; - using MaskLoadT = platform::CudaAlignedVector; + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + using MaskLoadT = platform::AlignedVector; for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { - T dout_vec[VecSize]; - MaskType mask_vec[VecSize]; - LoadT *dout_value = reinterpret_cast(&dout_vec); - MaskLoadT *mask_value = reinterpret_cast(&mask_vec); - *dout_value = *reinterpret_cast(&dout[i]); - *mask_value = *reinterpret_cast(&mask[i]); - - T dx_vec[VecSize]; + LoadT dout_vec; + MaskLoadT mask_vec; + platform::Load(&dout[i], &dout_vec); + platform::Load(&mask[i], &mask_vec); + + StoreT dx_vec; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { dx_vec[ii] = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; } - *(reinterpret_cast(&dx[i])) = - *reinterpret_cast(&dx_vec[0]); + platform::Store(dx_vec, &dx[i]); } } @@ -290,21 +209,20 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, T *dbias) { int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; - using LoadT = platform::CudaAlignedVector; - using MaskLoadT = platform::CudaAlignedVector; + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + using MaskLoadT = platform::AlignedVector; T tmp_sum[VecSize] = {static_cast(0)}; // calculate the dx and temporary sum if (col_id * VecSize < cols) { for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { int index = row_id * cols + col_id * VecSize; - T out_vec[VecSize]; - MaskType mask_vec[VecSize]; - T dx_vec[VecSize]; - LoadT *out_value = reinterpret_cast(&out_vec); - MaskLoadT *mask_value = reinterpret_cast(&mask_vec); - *out_value = *reinterpret_cast(&dout[index]); - *mask_value = *reinterpret_cast(&mask[index]); + LoadT out_vec; + MaskLoadT mask_vec; + StoreT dx_vec; + platform::Load(&dout[index], &out_vec); + platform::Load(&mask[index], &mask_vec); #pragma unroll for (int i = 0; i < VecSize; i++) { @@ -312,8 +230,7 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, tmp_sum[i] += out_vec[i]; } - *(reinterpret_cast(&dx[index])) = - *reinterpret_cast(&dx_vec[0]); + platform::Store(dx_vec, &dx[index]); } } @@ -368,17 +285,13 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, } const int VecSize = MAX_CACHE_BYTES / sizeof(T); + int real_vec_size = cols % VecSize == 0 ? 
VecSize : 1; if (dbias != nullptr) { - int real_vec_size = VecSize; - if (cols % VecSize != 0) { - real_vec_size = 1; - } auto threads = std::min(cols / real_vec_size, static_cast(8)); auto blocks = - std::max((uint32_t)1, cols / real_vec_size + threads - 1 / threads); + std::max((uint32_t)1, (cols / real_vec_size + threads - 1) / threads); dim3 block_dim(threads, 128, 1); dim3 grid_dim(blocks, 1, 1); - if (cols % VecSize == 0) { FusedResidualDropoutBiasGrad< T, MaskType, 8, 128, @@ -391,14 +304,13 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, } } else { const uint64_t n = rows * cols; + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size); if (n % VecSize == 0) { - platform::GpuLaunchConfig config = - platform::GetGpuLaunchConfig1D(ctx, n / VecSize); FusedResidualDropoutGrad<<< config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( dout, mask, factor, n, dx); } else { - platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(ctx, n); FusedResidualDropoutGrad<<< config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( dout, mask, factor, n, dx); From 4dba815f0496ece41da0661b170f467b1f77e619 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Wed, 8 Sep 2021 11:44:59 +0000 Subject: [PATCH 13/18] add a fusion op: fused_dropout_act_bias --- paddle/fluid/operators/fused/CMakeLists.txt | 1 + .../operators/fused/fused_dropout_act_bias.h | 311 ++++++++++++++++ .../fused/fused_dropout_act_bias_test.cu | 347 ++++++++++++++++++ .../operators/fused/fused_dropout_common.h | 47 +-- .../fused/fused_residual_dropout_bias.h | 20 +- .../fused/fused_residual_dropout_bias_test.cu | 49 ++- 6 files changed, 708 insertions(+), 67 deletions(-) create mode 100755 paddle/fluid/operators/fused/fused_dropout_act_bias.h create mode 100755 paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 3df2144aa35944..0a12735acf2a05 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -75,5 +75,6 @@ if (WITH_GPU OR WITH_ROCM) # only support CUDA if(NOT WITH_ROCM) nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory) + nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory) endif() endif() diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h new file mode 100755 index 00000000000000..a348fba335a140 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -0,0 +1,311 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
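A minimal sketch of the platform::AlignedVector load/store idiom that PATCH 12 switches to. This is a device-side fragment; `src`, `dst`, `i` and `scale` are illustrative names, and the element offset is assumed to be a multiple of VecSize so the vectorized access stays aligned.

  constexpr int VecSize = MAX_CACHE_BYTES / sizeof(float);  // 16 / 4 = 4
  using LoadT = platform::AlignedVector<float, VecSize>;

  LoadT vec;
  platform::Load(&src[i], &vec);       // one vectorized (128-bit) global load
#pragma unroll
  for (int ii = 0; ii < VecSize; ii++) {
    vec[ii] = vec[ii] * scale;         // per-element work on registers
  }
  platform::Store(vec, &dst[i]);       // one vectorized (128-bit) global store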
*/ + +#pragma once + +#include "paddle/fluid/operators/fused/fused_dropout_common.h" +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" +#include "paddle/fluid/operators/math/functors.h" + +namespace paddle { +namespace operators { + +typedef platform::float16 fp16; + +/** + * @brief dst = dropout(activation(src + bias)); + * the src, mask and dst shape is (rows, cols) + * the bias shape is (1, cols) + */ +template +__global__ void FusedDropoutActBias(Functor act, const uint64_t seed, + const uint64_t rows, const uint64_t cols, + const int increment, + const float dropout_prob, + const bool is_upscale_in_train, + const bool is_test, const T *src, + const T *bias, T *dst, MaskType *mask) { + int col_id = blockDim.x * blockIdx.x + threadIdx.x; + int row_id = blockIdx.y; + int idx = row_id * cols + col_id; + + curandStatePhilox4_32_10_t state; + curand_init(seed, idx, increment, &state); + + T factor = static_cast(1.0f / (1.0f - dropout_prob)); + if (!is_upscale_in_train) { + factor = static_cast(1.0); + } + if (is_test) { + factor = static_cast(1.0f - dropout_prob); + if (is_upscale_in_train) { + factor = static_cast(1.0f); + } + } + + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + using MaskLoadT = platform::AlignedVector; + using MaskStoreT = platform::AlignedVector; + + const int tmp_cols = cols / VecSize * VecSize; + for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { + for (int i = col_id * VecSize; i < tmp_cols; + i += blockDim.x * gridDim.x * VecSize) { + LoadT src_vec; + LoadT bias_vec; + // vectorize load data from global + platform::Load(&src[r * cols + i], &src_vec); + + if (bias) { + platform::Load(&bias[i], &bias_vec); + } else { +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + bias_vec[ii] = static_cast(0); + } + } + + MaskStoreT mask_vec; + if (!is_test) { + float rand[VecSize]; + RandVec(&state, rand); +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + mask_vec[ii] = static_cast(rand[ii] >= dropout_prob); + } + } else { +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + mask_vec[ii] = static_cast(1); + } + } + + StoreT dest_vec; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + const T tmp = src_vec[ii] + bias_vec[ii]; + dest_vec[ii] = act(tmp) * static_cast(mask_vec[ii]) * factor; + } + // store result to global + platform::Store(dest_vec, &dst[r * cols + i]); + platform::Store(mask_vec, &mask[r * cols + i]); + } + } +} + +/** + * @brief dst = dropout(activation(src + bias)); + */ +template +void LaunchDropoutActBias(Functor act_functor, const uint64_t seed, + const uint32_t rows, const uint32_t cols, + const int increment, const float dropout_prob, + const bool is_upscale_in_train, const bool is_test, + const T *src, const T *bias, T *dst, + MaskType *mask_data, + const platform::CUDADeviceContext &ctx) { + // dropout_prob == 1.0f + if (std::abs(dropout_prob - 1.0f) < 1e-5) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemsetAsync(dst, 0, rows * cols * sizeof(T), ctx.stream())); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( + mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + return; + } + + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + const int real_vec_size = cols % VecSize == 0 ? 
VecSize : 1; + const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size); + if (cols % VecSize == 0) { + FusedDropoutActBias<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( + act_functor, seed, rows, cols, increment, dropout_prob, + is_upscale_in_train, is_test, src, bias, dst, mask_data); + } else { + FusedDropoutActBias<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( + act_functor, seed, rows, cols, increment, dropout_prob, + is_upscale_in_train, is_test, src, bias, dst, mask_data); + } +} + +/* + * @brief calculate the grad of no bias + */ +template +__global__ void FusedDropoutActGrad(Functor act_grad, const T *dout, + const MaskType *mask, const T *src, + const T factor, const int64_t size, T *dx) { + int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + using MaskLoadT = platform::AlignedVector; + for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) { + LoadT dout_vec; + LoadT src_vec; + MaskLoadT mask_vec; + + platform::Load(&dout[i], &dout_vec); + platform::Load(&mask[i], &mask_vec); + platform::Load(&src[i], &src_vec); + + StoreT dx_vec; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + T x = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; + T out = src_vec[ii]; + dx_vec[ii] = act_grad.UseXAndOut(x, out); + } + platform::Store(dx_vec, &dx[i]); + } +} + +/** + * blocks(128 * 8) + * 1. calculate the dx and reduce total rows to 128 rows + * 2. save 128*8 temporary sum in 8*128 shared memory + * 3. reduce the sum of 128 rows data by 8*VecSize warps + */ +template +__global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, + const MaskType *mask, const T *src, + const T *bias, const T factor, + const int64_t rows, const int64_t cols, + T *dx, T *dbias) { + int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; + + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + using MaskLoadT = platform::AlignedVector; + T tmp_sum[VecSize] = {static_cast(0)}; + // calculate the dx and temporary sum + if (col_id * VecSize < cols) { + for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { + int index = row_id * cols + col_id * VecSize; + LoadT dout_vec; + LoadT src_vec; + LoadT bias_vec; + MaskLoadT mask_vec; + + platform::Load(&dout[index], &dout_vec); + platform::Load(&src[index], &src_vec); + platform::Load(&mask[index], &mask_vec); + platform::Load(&bias[col_id * VecSize], &bias_vec); + + StoreT dx_vec; +#pragma unroll + for (int i = 0; i < VecSize; i++) { + T val; + T x = dout_vec[i] * static_cast(mask_vec[i]) * factor; + T out = src_vec[i] + bias_vec[i]; + val = act_grad.UseXAndOut(x, out); + dx_vec[i] = val; + tmp_sum[i] += val; + } + platform::Store(dx_vec, &dx[index]); + } + } + + __shared__ T cache[BlockSizeX * VecSize][BlockSizeY]; + for (int i = 0; i < VecSize; i++) { + cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; + } + __syncthreads(); + + // reduce sum + T sum = static_cast(0); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 5; // warp id + int y = tid & 31; // thread id on warp 0~31 + + // need BlockSizeX * VecSize warps + if (x < BlockSizeX * VecSize) { +// reduce 128 to 32 +#pragma unroll + for (int i = 0; i < (BlockSizeY >> 5); i++) { + sum += cache[x][y + i * 32]; + } + } + + // reduce 32 to 1 + sum = WarpReduceSum(sum); + + // save sum to dbias + int bias_id = blockIdx.x * blockDim.x * VecSize + x; + 
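  // Reduction layout for the (BlockSizeX = 8, BlockSizeY = 128) block launched
  // below: tid runs over 0..1023, x = tid >> 5 picks one of the 32 warps and
  // y = tid & 31 is the lane. Each warp x (for x < BlockSizeX * VecSize) owns
  // one row of cache[BlockSizeX * VecSize][BlockSizeY]; the strided loop above
  // folds its 128 partial sums into 32 lane values (BlockSizeY >> 5 = 4 steps),
  // and WarpReduceSum folds those 32 lanes into the final column sum. bias_id,
  // computed above, maps warp x back to its bias column: each block covers
  // blockDim.x * VecSize consecutive columns starting at
  // blockIdx.x * blockDim.x * VecSize.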
if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) { + dbias[bias_id] = sum; + } +} + +/** + * @brief to launch kernel FusedResidualDropoutBiasGradVec + */ +template +void LaunchDropoutActBiasGrad(Functor act_functor, const T *dout, + const MaskType *mask, const T *src, const T *bias, + const float dropout_prob, + const bool is_upscale_in_train, + const uint32_t rows, const uint32_t cols, T *dx, + T *dbias, + const platform::CUDADeviceContext &ctx) { + const T zero = static_cast(0.0); + auto factor = dropout_prob == static_cast(1.0f) + ? zero + : static_cast(1.0 / (1.0 - dropout_prob)); + if (!is_upscale_in_train) { + factor = static_cast(1.0f); + } + + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + int real_vec_size = cols % VecSize == 0 ? VecSize : 1; + + if (dbias != nullptr) { + const auto threads = 8; + const auto blocks = + std::max(static_cast(1), + (cols / real_vec_size + threads - 1) / threads); + dim3 block_dim(threads, 128, 1); + dim3 grid_dim(blocks, 1, 1); + if (cols % VecSize == 0) { + FusedDropoutActBiasGrad< + T, MaskType, 8, 128, VecSize, + Functor><<>>( + act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias); + } else { + FusedDropoutActBiasGrad< + T, MaskType, 8, 128, 1, + Functor><<>>( + act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias); + } + } else { + const uint64_t n = rows * cols; + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size); + if (n % VecSize == 0) { + FusedDropoutActGrad<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( + act_functor, dout, mask, src, factor, n, dx); + } else { + FusedDropoutActGrad<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( + act_functor, dout, mask, src, factor, n, dx); + } + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu b/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu new file mode 100755 index 00000000000000..cb1f8ac938e252 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu @@ -0,0 +1,347 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include + +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h" +#include "paddle/fluid/operators/fused/fused_dropout_test.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace details = paddle::operators::details; +namespace math = paddle::operators::math; + +/** + * @brief the unittest of fused_dropout_act_bias + * 1. random input data + * 2. add bias, call activation, call paddle dropout, and get the base result + * 3. call FusedDropoutActBias function get fused result + * 4. 
compare ther base result and fused result + */ + +template +struct TestFusedDropoutActBias { + uint32_t rows; + uint32_t cols; + uint64_t seed; + float dropout_prob; + bool is_upscale_in_train; + bool is_test; // default false, Set to true for inference only + bool has_bias = true; + framework::Tensor src, bias, out, mask; + framework::Tensor dsrc, dbias; + + std::vector src_vec, bias_vec, out_vec, mask_vec; + std::vector correct_out, correct_dsrc, correct_dbias; + std::vector correct_mask; + + platform::CUDAPlace place; + platform::CUDADeviceContext *ctx; + + TestFusedDropoutActBias() { + rows = 32; + cols = 32; + seed = 0; + dropout_prob = 0.0; + is_upscale_in_train = false; + is_test = false; + has_bias = true; + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto devicectx = pool.Get(place); + ctx = reinterpret_cast(devicectx); + } + + TestFusedDropoutActBias(int rows_, int cols_, uint64_t seed_ = 0, + float dropout_prob_ = 0.0, + bool is_upscale_in_train_ = false, + bool is_test_ = false) { + rows = rows_; + cols = cols_; + seed = seed_; + dropout_prob = dropout_prob_; + is_upscale_in_train = is_upscale_in_train_; + is_test = is_test_; + has_bias = true; + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto devicectx = pool.Get(place); + ctx = reinterpret_cast(devicectx); + } + + ~TestFusedDropoutActBias() {} + + void SetUp() { + const int n = rows * cols; + correct_out.resize(n); + correct_mask.resize(n); + correct_dsrc.resize(n); + correct_dbias.resize(cols); + + src_vec.resize(n); + bias_vec.resize(cols); + std::default_random_engine random(time(NULL)); + std::uniform_real_distribution dis(0.0, 1.0); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + src_vec[i * cols + j] = static_cast(dis(random)); + if (i == 0) bias_vec[j] = dis(random); + } + } + + framework::TensorFromVector(src_vec, *ctx, &src); + src.Resize({rows, cols}); + if (has_bias) { + framework::TensorFromVector(bias_vec, *ctx, &bias); + bias.Resize({cols}); + } + + { + out.Resize({rows, cols}); + out.mutable_data(place); + mask.Resize({rows, cols}); + mask.mutable_data(place); + dsrc.Resize({rows, cols}); + dsrc.mutable_data(place); + + if (has_bias) { + dbias.Resize({cols}); + dbias.mutable_data(place); + } + } + } + + void BaseForward() { + std::vector out1(rows * cols); + Functor act; + if (has_bias) { + // add bias and call activation + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + const T tmp = src_vec[i * cols + j] + bias_vec[j]; + out1[i * cols + j] = act(tmp); + } + } + // call dropout + Dropout(out1, src.dims(), &correct_out, &correct_mask, *ctx, seed, + dropout_prob, is_upscale_in_train, is_test); + } else { + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + const T tmp = src_vec[i * cols + j]; + out1[i * cols + j] = act(tmp); + } + } + + Dropout(out1, src.dims(), &correct_out, &correct_mask, *ctx, seed, + dropout_prob, is_upscale_in_train, is_test); + } + ctx->Wait(); + } + + void BaseBackward() { + std::vector _out(rows * cols); + // call dropout_grad + DropoutGrad(&_out, src.dims(), correct_out, correct_mask, *ctx, + dropout_prob, is_upscale_in_train); + + // calculate dbias + memset(&correct_dbias[0], 0, cols * sizeof(T)); + GradFunctor act_grad; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + if (has_bias) { + T x = _out[i * cols + j]; + T out = src_vec[i * cols + j] + bias_vec[j]; + T val = act_grad.UseXAndOut(x, out); + correct_dbias[j] += val; + 
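// val is both the reference dsrc entry and this row's contribution to the reference dbias[j] +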
correct_dsrc[i * cols + j] = val; + } else { + T val = + act_grad.UseXAndOut(_out[i * cols + j], src_vec[i * cols + j]); + correct_dsrc[i * cols + j] = val; + } + } + } + } + + void FusedForward() { + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + auto config = paddle::operators::Get1DBlocksAnd2DGrids( + *ctx, static_cast(rows), static_cast(cols), + VecSize); + const int increment = ((cols - 1) / (config.thread_per_block.x * + config.block_per_grid.x * VecSize) + + 1) * + VecSize; + + T *bias_ptr = nullptr; + if (has_bias) { + bias_ptr = bias.data(); + } + Functor act; + paddle::operators::LaunchDropoutActBias( + act, seed, rows, cols, increment, dropout_prob, is_upscale_in_train, + is_test, src.data(), bias_ptr, out.data(), mask.data(), + *ctx); + ctx->Wait(); + } + + void FusedBackward() { + if (is_test) return; + + T *bias_ptr = nullptr; + T *dbias_ptr = nullptr; + if (has_bias) { + dbias_ptr = dbias.data(); + bias_ptr = bias.data(); + } + GradFunctor act_grad; + paddle::operators::LaunchDropoutActBiasGrad( + act_grad, out.data(), mask.data(), src.data(), bias_ptr, + dropout_prob, is_upscale_in_train, rows, cols, dsrc.data(), + dbias_ptr, *ctx); + } + + void Run() { + SetUp(); + BaseForward(); + FusedForward(); + BaseBackward(); + FusedBackward(); + } + + void CheckOut(const T diff) { + const int n = rows * cols; + std::vector _out(n); + std::vector _mask(n); + framework::TensorToVector(out, *ctx, &_out); + if (!is_test) { + framework::TensorToVector(mask, *ctx, &_mask); + } + ctx->Wait(); + + for (int i = 0; i < n; i++) { + EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff); + if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]); + } + } + + void CheckGrad(const T diff) { + if (is_test) return; + + const int n = rows * cols; + + std::vector _dsrc(n); + framework::TensorToVector(dsrc, *ctx, &_dsrc); + + for (int i = 0; i < n; i++) { + EXPECT_LT(std::abs(_dsrc[i] - correct_dsrc[i]), diff); + } + + if (has_bias) { + std::vector _dbias(cols); + framework::TensorToVector(dbias, *ctx, &_dbias); + ctx->Wait(); + for (int i = 0; i < cols; i++) { + EXPECT_LT(std::abs(_dbias[i] - correct_dbias[i]), diff); + } + } + } +}; + +template +static void BaseTest() {} +// test the shape , bias, activation +template +static void BaseTest(const bool is_fp16 = false) { + const int rows = 16; + std::vector cols_list = {16, 17}; + bool has_bias[2] = {true, false}; + T default_diff = !is_fp16 ? static_cast(1e-5) : default_diff = + static_cast(1e-2); + for (auto cols : {16, 17}) { + for (auto has_bias : {true, false}) { + TestFusedDropoutActBias test(rows, cols); + test.has_bias = has_bias; + test.Run(); + test.CheckOut(default_diff); + test.CheckGrad(default_diff); + } + } +} + +TEST(FusedDropout, GPUFusedDorpoutActBias) { + BaseTest, math::ReluGradFunctor>(); + BaseTest, math::GeluGradFunctor>(); +} +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasDouble) { + BaseTest, math::ReluGradFunctor>(); + BaseTest, math::GeluGradFunctor>(); +} + +// test fp16, For inference, check_grad is not required. 
ref: test_dropout_op.py +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasFp16) { + using fp16 = platform::float16; + BaseTest, math::ReluGradFunctor>(true); +} + +TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) { + const int rows = 16; + const int cols = 16; + for (auto is_upscale_in_train : {true, false}) { + TestFusedDropoutActBias, + math::ReluGradFunctor> + test(rows, cols, 0, 1.0, is_upscale_in_train, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); + } +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasIsTest) { + const int rows = 16; + const int cols = 16; + TestFusedDropoutActBias, + math::ReluGradFunctor> + test(rows, cols, 0, 0.35, true, true); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasSeed) { + const int rows = 16; + const int cols = 16; + TestFusedDropoutActBias, + math::ReluGradFunctor> + test(rows, cols, 125, 0.0, false, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} + +TEST(FusedDropout, GPUFusedRedisualDorpoutBiasLargeShape) { + const int rows = 256; + const int cols = 4096; + TestFusedDropoutActBias, + math::ReluGradFunctor> + test(rows, cols); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); +} diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index 24f6f53c63630e..3e4200a717d4eb 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -39,8 +39,8 @@ namespace operators { */ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids( const platform::CUDADeviceContext &ctx, const uint32_t rows, - const uint32_t cols, const int VecSize) { - const uint32_t tmp_cols = cols / VecSize; + const uint32_t cols, const int vec_size) { + const uint32_t tmp_cols = cols / vec_size; int threads = std::max( static_cast(32), std::min(tmp_cols, static_cast(ctx.GetMaxThreadsPerBlock()))); @@ -54,19 +54,26 @@ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids( return config; } -__forceinline__ __device__ void Rand1(curandStatePhilox4_32_10_t *state, - float *data) { +template +__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state, + float *data); + +template <> +__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state, + float *data) { data[0] = curand_uniform(state); } -__forceinline__ __device__ void Rand2(curandStatePhilox4_32_10_t *state, - float *data) { +template <> +__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state, + float *data) { data[0] = curand_uniform(state); data[1] = curand_uniform(state); } -__forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state, - float *data) { +template <> +__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state, + float *data) { float4 rand4 = curand_uniform4(state); data[0] = rand4.x; data[1] = rand4.y; @@ -74,25 +81,11 @@ __forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state, data[3] = rand4.z; } -__forceinline__ __device__ void Rand8(curandStatePhilox4_32_10_t *state, - float *data) { - Rand4(state, data); - Rand4(state, data + 4); -} - -__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state, - float *data, const int VecSize) { - if (VecSize == 1) { - Rand1(state, data); - } else if (VecSize == 2) { - Rand2(state, data); - } else if 
(VecSize == 4) { - Rand4(state, data); - } else if (VecSize == 8) { - Rand8(state, data); - } else { - return; - } +template <> +__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state, + float *data) { + RandVec<4>(state, data); + RandVec<4>(state, data + 4); } } // namespace operators diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index cd9dfd1c79ca8f..3ba91060a4e356 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -54,7 +54,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( MaskStoreT mask_vec; if (!is_test) { float rand[VecSize]; - RandVec(state, rand, VecSize); + RandVec(state, rand); #pragma unroll for (int ii = 0; ii < VecSize; ii++) { mask_vec[ii] = static_cast(rand[ii] >= dropout_prob); @@ -106,15 +106,11 @@ __global__ void FusedResidualDropoutBias( curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); - T factor = static_cast(1.0f / (1.0f - dropout_prob)); - if (!is_upscale_in_train) { - factor = static_cast(1.0f); - } + T factor = is_upscale_in_train ? static_cast(1.0f / (1.0f - dropout_prob)) + : static_cast(1.0f); if (is_test) { - factor = static_cast(1.0f - dropout_prob); - if (is_upscale_in_train) { - factor = static_cast(1.0f); - } + factor = is_upscale_in_train ? static_cast(1.0f) + : static_cast(1.0f - dropout_prob); } for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; @@ -287,9 +283,9 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, const int VecSize = MAX_CACHE_BYTES / sizeof(T); int real_vec_size = cols % VecSize == 0 ? VecSize : 1; if (dbias != nullptr) { - auto threads = std::min(cols / real_vec_size, static_cast(8)); - auto blocks = - std::max((uint32_t)1, (cols / real_vec_size + threads - 1) / threads); + const auto threads = 8; + auto blocks = std::max(static_cast(1), + (cols / real_vec_size + threads - 1) / threads); dim3 block_dim(threads, 128, 1); dim3 grid_dim(blocks, 1, 1); if (cols % VecSize == 0) { diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index e687823bc8158b..9d596a44eaaa69 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -169,7 +169,8 @@ struct TestFusedResidualDropoutBias { void FusedForward() { const int VecSize = MAX_CACHE_BYTES / sizeof(T); auto config = paddle::operators::Get1DBlocksAnd2DGrids( - *ctx, (uint64_t)rows, (uint64_t)cols, VecSize); + *ctx, static_cast(rows), static_cast(cols), + VecSize); const int increment = ((cols - 1) / (config.thread_per_block.x * config.block_per_grid.x * VecSize) + 1) * @@ -255,17 +256,17 @@ static void BaseTest(const bool is_fp16 = false) { const int rows = 16; std::vector cols_list = {16, 17}; bool has_bias[2] = {true, false}; - T default_diff = static_cast(1e-5); - if (is_fp16) { - default_diff = static_cast(1e-2); - } - for (int i = 0; i < cols_list.size(); i++) { - for (int j = 0; j < 2; j++) { - TestFusedResidualDropoutBias test(rows, cols_list[i]); - test.has_bias = has_bias[j]; + T default_diff = !is_fp16 ? 
static_cast(1e-5) : default_diff = + static_cast(1e-2); + for (auto cols : {16, 17}) { + for (auto has_bias : {true, false}) { + TestFusedResidualDropoutBias test(rows, cols); + test.has_bias = has_bias; test.Run(); test.CheckOut(default_diff); if (!is_fp16) { + // test fp16, For inference, check_grad is not required. ref: + // testdropout_op.py test.CheckGrad(default_diff); } } @@ -276,30 +277,23 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest(); } TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest(); } -// test fp16, For inference, check_grad is not required. ref: testdropout_op.py TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) { BaseTest(true); } -TEST(FusedDropout, GPUFusedResidualDropoutBias2) { - const int rows = 16; - const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, false, false); - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); -} - -TEST(FusedDropout, GPUFusedResidualDropoutBias3) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) { const int rows = 16; const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, true, false); - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); + for (auto is_upscale_in_train : {true, false}) { + TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, + is_upscale_in_train, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); + } } -TEST(FusedDropout, GPUFusedResidualDropoutBias4) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 0.35, true, true); @@ -308,7 +302,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias4) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedResidualDropoutBias5) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 125, 0.0, false, false); @@ -317,8 +311,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias5) { test.CheckGrad(static_cast(1e-5)); } -// test large shape -TEST(FusedDropout, GPUFusedResidualDropoutBias6) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) { const int rows = 256; const int cols = 4096; TestFusedResidualDropoutBias test(rows, cols); From f848739c0f890e99d18594b2bd765b57adca966a Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Thu, 9 Sep 2021 02:11:29 +0000 Subject: [PATCH 14/18] remove unused code --- .../fluid/operators/fused/fused_residual_dropout_bias_test.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index 9d596a44eaaa69..47546d5f9ca13a 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -254,8 +254,6 @@ struct TestFusedResidualDropoutBias { template static void BaseTest(const bool is_fp16 = false) { const int rows = 16; - std::vector cols_list = {16, 17}; - bool has_bias[2] = {true, false}; T default_diff = !is_fp16 ? 
static_cast(1e-5) : default_diff = static_cast(1e-2); for (auto cols : {16, 17}) { From b8a986198167733f13fd203828bccfcb6f5d880c Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Thu, 9 Sep 2021 05:39:13 +0000 Subject: [PATCH 15/18] redefine activation functor --- .../operators/fused/fused_dropout_act_bias.h | 59 ++++++++++++-- .../fused/fused_dropout_act_bias_test.cu | 76 ++++++++++--------- 2 files changed, 94 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index a348fba335a140..cd9c66d9f79db6 100755 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -23,6 +23,49 @@ namespace operators { typedef platform::float16 fp16; +/** + *@brief the relu functor + */ +template +struct ReluFunctor { + __host__ __device__ T operator()(const T *args) const { + math::ReluFunctor relu; + return relu(args[0]); + } +}; + +template +struct ReluGradFunctor { + __host__ __device__ __forceinline__ T operator()(const T *args) const { + math::ReluGradFunctor relu_grad; + return args[0] * relu_grad.UseOut(args[1]); + } +}; + +/** + *@brief the gelu functor + */ +template +struct GeluFunctor { + __host__ __device__ T operator()(const T *args) const { + math::GeluFunctor gelu; + return gelu(args[0]); + } +}; + +/** + *@brief the gelu grad functor + */ +template +struct GeluGradFunctor { + __host__ __device__ T operator()(const T *args) const { + const T grad = args[0]; + const T x = args[1]; + math::GeluGradFunctor gelu_grad; + return grad * gelu_grad.UseOut(x); + } +}; + /** * @brief dst = dropout(activation(src + bias)); * the src, mask and dst shape is (rows, cols) @@ -96,7 +139,7 @@ __global__ void FusedDropoutActBias(Functor act, const uint64_t seed, #pragma unroll for (int ii = 0; ii < VecSize; ii++) { const T tmp = src_vec[ii] + bias_vec[ii]; - dest_vec[ii] = act(tmp) * static_cast(mask_vec[ii]) * factor; + dest_vec[ii] = act(&tmp) * static_cast(mask_vec[ii]) * factor; } // store result to global platform::Store(dest_vec, &dst[r * cols + i]); @@ -165,9 +208,10 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout, StoreT dx_vec; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { - T x = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; - T out = src_vec[ii]; - dx_vec[ii] = act_grad.UseXAndOut(x, out); + T args[2]; + args[0] = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; + args[1] = src_vec[ii]; + dx_vec[ii] = act_grad(args); } platform::Store(dx_vec, &dx[i]); } @@ -210,9 +254,10 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, #pragma unroll for (int i = 0; i < VecSize; i++) { T val; - T x = dout_vec[i] * static_cast(mask_vec[i]) * factor; - T out = src_vec[i] + bias_vec[i]; - val = act_grad.UseXAndOut(x, out); + T args[2]; + args[0] = dout_vec[i] * static_cast(mask_vec[i]) * factor; + args[1] = src_vec[i] + bias_vec[i]; + val = act_grad(args); dx_vec[i] = val; tmp_sum[i] += val; } diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu b/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu index cb1f8ac938e252..f33e4d020b2954 100755 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu @@ -24,7 +24,7 @@ limitations under the License. 
*/ namespace framework = paddle::framework; namespace platform = paddle::platform; namespace details = paddle::operators::details; -namespace math = paddle::operators::math; +namespace operators = paddle::operators; /** * @brief the unittest of fused_dropout_act_bias @@ -133,7 +133,7 @@ struct TestFusedDropoutActBias { for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { const T tmp = src_vec[i * cols + j] + bias_vec[j]; - out1[i * cols + j] = act(tmp); + out1[i * cols + j] = act(&tmp); } } // call dropout @@ -143,7 +143,7 @@ struct TestFusedDropoutActBias { for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { const T tmp = src_vec[i * cols + j]; - out1[i * cols + j] = act(tmp); + out1[i * cols + j] = act(&tmp); } } @@ -165,14 +165,17 @@ struct TestFusedDropoutActBias { for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { if (has_bias) { - T x = _out[i * cols + j]; - T out = src_vec[i * cols + j] + bias_vec[j]; - T val = act_grad.UseXAndOut(x, out); + T args[2]; + args[0] = _out[i * cols + j]; + args[1] = src_vec[i * cols + j] + bias_vec[j]; + T val = act_grad(args); correct_dbias[j] += val; correct_dsrc[i * cols + j] = val; } else { - T val = - act_grad.UseXAndOut(_out[i * cols + j], src_vec[i * cols + j]); + T args[2]; + args[0] = _out[i * cols + j]; + args[1] = src_vec[i * cols + j]; + T val = act_grad(args); correct_dsrc[i * cols + j] = val; } } @@ -264,15 +267,13 @@ struct TestFusedDropoutActBias { } }; -template -static void BaseTest() {} // test the shape , bias, activation template static void BaseTest(const bool is_fp16 = false) { const int rows = 16; std::vector cols_list = {16, 17}; bool has_bias[2] = {true, false}; - T default_diff = !is_fp16 ? static_cast(1e-5) : default_diff = + T default_diff = !is_fp16 ? static_cast(1e-3) : default_diff = static_cast(1e-2); for (auto cols : {16, 17}) { for (auto has_bias : {true, false}) { @@ -280,68 +281,75 @@ static void BaseTest(const bool is_fp16 = false) { test.has_bias = has_bias; test.Run(); test.CheckOut(default_diff); - test.CheckGrad(default_diff); + if (!is_fp16) { + test.CheckGrad(default_diff); + } } } } TEST(FusedDropout, GPUFusedDorpoutActBias) { - BaseTest, math::ReluGradFunctor>(); - BaseTest, math::GeluGradFunctor>(); + BaseTest, + paddle::operators::ReluGradFunctor>(); + BaseTest, + operators::GeluGradFunctor>(); } -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasDouble) { - BaseTest, math::ReluGradFunctor>(); - BaseTest, math::GeluGradFunctor>(); +TEST(FusedDropout, GPUFusedDropoutActBiasDouble) { + BaseTest, + operators::ReluGradFunctor>(); + BaseTest, + operators::GeluGradFunctor>(); } // test fp16, For inference, check_grad is not required. 
ref: test_dropout_op.py -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasFp16) { +TEST(FusedDropout, GPUFusedDropoutActBiasFp16) { using fp16 = platform::float16; - BaseTest, math::ReluGradFunctor>(true); + BaseTest, + operators::ReluGradFunctor>(true); } TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) { const int rows = 16; const int cols = 16; for (auto is_upscale_in_train : {true, false}) { - TestFusedDropoutActBias, - math::ReluGradFunctor> + TestFusedDropoutActBias, + operators::ReluGradFunctor> test(rows, cols, 0, 1.0, is_upscale_in_train, false); test.Run(); test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); } } -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasIsTest) { +TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) { const int rows = 16; const int cols = 16; - TestFusedDropoutActBias, - math::ReluGradFunctor> + TestFusedDropoutActBias, + operators::ReluGradFunctor> test(rows, cols, 0, 0.35, true, true); test.Run(); test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); } -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasSeed) { +TEST(FusedDropout, GPUFusedDropoutActBiasSeed) { const int rows = 16; const int cols = 16; - TestFusedDropoutActBias, - math::ReluGradFunctor> + TestFusedDropoutActBias, + operators::ReluGradFunctor> test(rows, cols, 125, 0.0, false, false); test.Run(); test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); } -TEST(FusedDropout, GPUFusedRedisualDorpoutBiasLargeShape) { +TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) { const int rows = 256; const int cols = 4096; - TestFusedDropoutActBias, - math::ReluGradFunctor> + TestFusedDropoutActBias, + operators::ReluGradFunctor> test(rows, cols); test.Run(); test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); } From fd01daa1a87465ec7d29698f2d261ceee85fe07f Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Thu, 9 Sep 2021 08:26:43 +0000 Subject: [PATCH 16/18] implement the same gelu as the baseline for FFN --- .../operators/fused/fused_dropout_act_bias.h | 25 +++++++++++++------ .../operators/fused/fused_dropout_common.h | 3 +-- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index cd9c66d9f79db6..a6351a3910f6cf 100755 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -48,8 +48,11 @@ struct ReluGradFunctor { template struct GeluFunctor { __host__ __device__ T operator()(const T *args) const { - math::GeluFunctor gelu; - return gelu(args[0]); + using U = LayerNormParamType; + U casted_x = static_cast(args[0]); + auto temp = erf(casted_x * static_cast(M_SQRT1_2)); + auto out = (casted_x * static_cast(0.5) * (static_cast(1) + temp)); + return static_cast(out); } }; @@ -59,10 +62,17 @@ struct GeluFunctor { template struct GeluGradFunctor { __host__ __device__ T operator()(const T *args) const { - const T grad = args[0]; - const T x = args[1]; - math::GeluGradFunctor gelu_grad; - return grad * gelu_grad.UseOut(x); + using U = LayerNormParamType; + auto casted_x = static_cast(args[1]); + auto casted_dout = static_cast(args[0]); + + auto first = + static_cast(0.5) * + (static_cast(1) + erf(casted_x * static_cast(M_SQRT1_2))); + + auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x * + 
exp(-static_cast(0.5) * casted_x * casted_x); + return static_cast(casted_dout * (first + second)); } }; @@ -139,7 +149,8 @@ __global__ void FusedDropoutActBias(Functor act, const uint64_t seed, #pragma unroll for (int ii = 0; ii < VecSize; ii++) { const T tmp = src_vec[ii] + bias_vec[ii]; - dest_vec[ii] = act(&tmp) * static_cast(mask_vec[ii]) * factor; + const T act_out = act(&tmp); + dest_vec[ii] = act_out * static_cast(mask_vec[ii]) * factor; } // store result to global platform::Store(dest_vec, &dst[r * cols + i]); diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index b3b5ad5ea4af63..3e4200a717d4eb 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -54,7 +54,6 @@ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids( return config; } - template __forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state, float *data); @@ -71,7 +70,7 @@ __forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state, data[0] = curand_uniform(state); data[1] = curand_uniform(state); } - + template <> __forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state, float *data) { From cabb9d2c11b7d9cf83395267469cb67d992583d5 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Fri, 10 Sep 2021 02:25:44 +0000 Subject: [PATCH 17/18] add #define _USE_MATH_DEFINES for windows --- .../operators/fused/fused_dropout_act_bias.h | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index a6351a3910f6cf..f164b22d54013c 100755 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES +#endif #include "paddle/fluid/operators/fused/fused_dropout_common.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" @@ -28,7 +31,7 @@ typedef platform::float16 fp16; */ template struct ReluFunctor { - __host__ __device__ T operator()(const T *args) const { + inline __host__ __device__ T operator()(const T *args) const { math::ReluFunctor relu; return relu(args[0]); } @@ -36,7 +39,7 @@ struct ReluFunctor { template struct ReluGradFunctor { - __host__ __device__ __forceinline__ T operator()(const T *args) const { + inline __host__ __device__ T operator()(const T *args) const { math::ReluGradFunctor relu_grad; return args[0] * relu_grad.UseOut(args[1]); } @@ -47,11 +50,11 @@ struct ReluGradFunctor { */ template struct GeluFunctor { - __host__ __device__ T operator()(const T *args) const { + inline __host__ __device__ T operator()(const T *args) const { using U = LayerNormParamType; - U casted_x = static_cast(args[0]); - auto temp = erf(casted_x * static_cast(M_SQRT1_2)); - auto out = (casted_x * static_cast(0.5) * (static_cast(1) + temp)); + const U casted_x = static_cast(args[0]); + const U temp = erf(casted_x * static_cast(M_SQRT1_2)); + const U out = (casted_x * static_cast(0.5) * (static_cast(1) + temp)); return static_cast(out); } }; @@ -61,7 +64,7 @@ struct GeluFunctor { */ template struct GeluGradFunctor { - __host__ __device__ T operator()(const T *args) const { + inline __host__ __device__ T operator()(const T *args) const { using U = LayerNormParamType; auto casted_x = static_cast(args[1]); auto casted_dout = static_cast(args[0]); @@ -112,9 +115,8 @@ __global__ void FusedDropoutActBias(Functor act, const uint64_t seed, using MaskLoadT = platform::AlignedVector; using MaskStoreT = platform::AlignedVector; - const int tmp_cols = cols / VecSize * VecSize; for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { - for (int i = col_id * VecSize; i < tmp_cols; + for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { LoadT src_vec; LoadT bias_vec; @@ -154,7 +156,9 @@ __global__ void FusedDropoutActBias(Functor act, const uint64_t seed, } // store result to global platform::Store(dest_vec, &dst[r * cols + i]); - platform::Store(mask_vec, &mask[r * cols + i]); + if (!is_test) { + platform::Store(mask_vec, &mask[r * cols + i]); + } } } } From 3cfdff8c59aff7d630998f6b07a986850e8035c2 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Mon, 13 Sep 2021 03:08:48 +0000 Subject: [PATCH 18/18] modify the code according to the review comment --- .../operators/fused/fused_dropout_act_bias.h | 90 ++++--------------- .../fused/fused_dropout_act_bias_test.cu | 77 +++++++--------- .../operators/fused/fused_dropout_common.h | 45 ++++++++++ .../operators/fused/fused_dropout_test.h | 19 ++++ .../fused/fused_residual_dropout_bias.h | 47 ++-------- .../fused/fused_residual_dropout_bias_test.cu | 27 ++---- 6 files changed, 133 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index f164b22d54013c..7d815bb8c39933 100755 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -18,41 +18,19 @@ limitations under the License. 
*/ #endif #include "paddle/fluid/operators/fused/fused_dropout_common.h" -#include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/operators/math/functors.h" namespace paddle { namespace operators { -typedef platform::float16 fp16; - -/** - *@brief the relu functor - */ -template -struct ReluFunctor { - inline __host__ __device__ T operator()(const T *args) const { - math::ReluFunctor relu; - return relu(args[0]); - } -}; - -template -struct ReluGradFunctor { - inline __host__ __device__ T operator()(const T *args) const { - math::ReluGradFunctor relu_grad; - return args[0] * relu_grad.UseOut(args[1]); - } -}; - /** *@brief the gelu functor */ template struct GeluFunctor { - inline __host__ __device__ T operator()(const T *args) const { + inline __host__ __device__ T operator()(const T x) const { using U = LayerNormParamType; - const U casted_x = static_cast(args[0]); + const U casted_x = static_cast(x); const U temp = erf(casted_x * static_cast(M_SQRT1_2)); const U out = (casted_x * static_cast(0.5) * (static_cast(1) + temp)); return static_cast(out); @@ -64,10 +42,9 @@ struct GeluFunctor { */ template struct GeluGradFunctor { - inline __host__ __device__ T operator()(const T *args) const { + inline __host__ __device__ T UseOut(const T x) const { using U = LayerNormParamType; - auto casted_x = static_cast(args[1]); - auto casted_dout = static_cast(args[0]); + auto casted_x = static_cast(x); auto first = static_cast(0.5) * @@ -75,7 +52,7 @@ struct GeluGradFunctor { auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x * exp(-static_cast(0.5) * casted_x * casted_x); - return static_cast(casted_dout * (first + second)); + return static_cast((first + second)); } }; @@ -85,13 +62,12 @@ struct GeluGradFunctor { * the bias shape is (1, cols) */ template -__global__ void FusedDropoutActBias(Functor act, const uint64_t seed, - const uint64_t rows, const uint64_t cols, - const int increment, - const float dropout_prob, - const bool is_upscale_in_train, - const bool is_test, const T *src, - const T *bias, T *dst, MaskType *mask) { +__global__ void FusedDropoutActBias( + Functor act, const uint64_t seed, const uint64_t rows, const uint64_t cols, + const int increment, const float dropout_prob, + const bool is_upscale_in_train, const bool is_test, + const T *__restrict__ src, const T *__restrict__ bias, T *dst, + MaskType *mask) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; @@ -151,7 +127,7 @@ __global__ void FusedDropoutActBias(Functor act, const uint64_t seed, #pragma unroll for (int ii = 0; ii < VecSize; ii++) { const T tmp = src_vec[ii] + bias_vec[ii]; - const T act_out = act(&tmp); + const T act_out = act(tmp); dest_vec[ii] = act_out * static_cast(mask_vec[ii]) * factor; } // store result to global @@ -176,10 +152,8 @@ void LaunchDropoutActBias(Functor act_functor, const uint64_t seed, const platform::CUDADeviceContext &ctx) { // dropout_prob == 1.0f if (std::abs(dropout_prob - 1.0f) < 1e-5) { - PADDLE_ENFORCE_CUDA_SUCCESS( - cudaMemsetAsync(dst, 0, rows * cols * sizeof(T), ctx.stream())); - PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( - mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + SetZero(ctx, dst, rows * cols); + SetZero(ctx, mask_data, rows * cols); return; } @@ -226,7 +200,7 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout, T args[2]; args[0] = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; args[1] = src_vec[ii]; - dx_vec[ii] = act_grad(args); + 
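// chain rule: dx = (dout * mask * factor) * act'(x), evaluated here at x = src +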
dx_vec[ii] = args[0] * act_grad.UseOut(args[1]); } platform::Store(dx_vec, &dx[i]); } @@ -236,7 +210,7 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout, * blocks(128 * 8) * 1. calculate the dx and reduce total rows to 128 rows * 2. save 128*8 temporary sum in 8*128 shared memory - * 3. reduce the sum of 128 rows data by 8*VecSize warps + * 3. reduce the sum of 128 cols data by 8*VecSize warps */ template @@ -272,7 +246,7 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, T args[2]; args[0] = dout_vec[i] * static_cast(mask_vec[i]) * factor; args[1] = src_vec[i] + bias_vec[i]; - val = act_grad(args); + val = args[0] * act_grad.UseOut(args[1]); dx_vec[i] = val; tmp_sum[i] += val; } @@ -280,35 +254,7 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, } } - __shared__ T cache[BlockSizeX * VecSize][BlockSizeY]; - for (int i = 0; i < VecSize; i++) { - cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; - } - __syncthreads(); - - // reduce sum - T sum = static_cast(0); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 5; // warp id - int y = tid & 31; // thread id on warp 0~31 - - // need BlockSizeX * VecSize warps - if (x < BlockSizeX * VecSize) { -// reduce 128 to 32 -#pragma unroll - for (int i = 0; i < (BlockSizeY >> 5); i++) { - sum += cache[x][y + i * 32]; - } - } - - // reduce 32 to 1 - sum = WarpReduceSum(sum); - - // save sum to dbias - int bias_id = blockIdx.x * blockDim.x * VecSize + x; - if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) { - dbias[bias_id] = sum; - } + CalculateDBias(tmp_sum, dbias, cols); } /** diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu b/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu index f33e4d020b2954..0adbf0be4e28aa 100755 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu @@ -20,11 +20,12 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/fused/fused_dropout_act_bias.h" #include "paddle/fluid/operators/fused/fused_dropout_test.h" +#include "paddle/fluid/operators/math/functors.h" namespace framework = paddle::framework; namespace platform = paddle::platform; namespace details = paddle::operators::details; -namespace operators = paddle::operators; +namespace math = paddle::operators::math; /** * @brief the unittest of fused_dropout_act_bias @@ -111,16 +112,12 @@ struct TestFusedDropoutActBias { } { - out.Resize({rows, cols}); - out.mutable_data(place); - mask.Resize({rows, cols}); - mask.mutable_data(place); - dsrc.Resize({rows, cols}); - dsrc.mutable_data(place); + out.mutable_data({rows, cols}, place); + mask.mutable_data({rows, cols}, place); + dsrc.mutable_data({rows, cols}, place); if (has_bias) { - dbias.Resize({cols}); - dbias.mutable_data(place); + dbias.mutable_data({cols}, place); } } } @@ -133,7 +130,7 @@ struct TestFusedDropoutActBias { for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { const T tmp = src_vec[i * cols + j] + bias_vec[j]; - out1[i * cols + j] = act(&tmp); + out1[i * cols + j] = act(tmp); } } // call dropout @@ -143,7 +140,7 @@ struct TestFusedDropoutActBias { for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { const T tmp = src_vec[i * cols + j]; - out1[i * cols + j] = act(&tmp); + out1[i * cols + j] = act(tmp); } } @@ -164,22 +161,22 @@ struct TestFusedDropoutActBias { GradFunctor act_grad; for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { + T args[2]; + args[0] = _out[i * cols + j]; if (has_bias) { - T args[2]; - args[0] = _out[i * cols + j]; args[1] = src_vec[i * cols + j] + bias_vec[j]; - T val = act_grad(args); - correct_dbias[j] += val; - correct_dsrc[i * cols + j] = val; } else { - T args[2]; - args[0] = _out[i * cols + j]; args[1] = src_vec[i * cols + j]; - T val = act_grad(args); - correct_dsrc[i * cols + j] = val; } + T val = args[0] * act_grad.UseOut(args[1]); + correct_dsrc[i * cols + j] = val; } } + + if (has_bias) { + // reduce_sum: keep the same calculate order as the GPU + ReduceSum(correct_dsrc, &correct_dbias, rows, cols); + } } void FusedForward() { @@ -273,47 +270,41 @@ static void BaseTest(const bool is_fp16 = false) { const int rows = 16; std::vector cols_list = {16, 17}; bool has_bias[2] = {true, false}; - T default_diff = !is_fp16 ? static_cast(1e-3) : default_diff = - static_cast(1e-2); + T default_diff = !is_fp16 ? static_cast(1e-5) : static_cast(1e-1); for (auto cols : {16, 17}) { for (auto has_bias : {true, false}) { TestFusedDropoutActBias test(rows, cols); test.has_bias = has_bias; test.Run(); test.CheckOut(default_diff); - if (!is_fp16) { - test.CheckGrad(default_diff); - } + test.CheckGrad(default_diff); } } } TEST(FusedDropout, GPUFusedDorpoutActBias) { - BaseTest, - paddle::operators::ReluGradFunctor>(); - BaseTest, - operators::GeluGradFunctor>(); + BaseTest, math::ReluGradFunctor>(); + BaseTest, + paddle::operators::GeluGradFunctor>(); } TEST(FusedDropout, GPUFusedDropoutActBiasDouble) { - BaseTest, - operators::ReluGradFunctor>(); - BaseTest, - operators::GeluGradFunctor>(); + BaseTest, math::ReluGradFunctor>(); + BaseTest, + paddle::operators::GeluGradFunctor>(); } // test fp16, For inference, check_grad is not required. 
ref: test_dropout_op.py TEST(FusedDropout, GPUFusedDropoutActBiasFp16) { using fp16 = platform::float16; - BaseTest, - operators::ReluGradFunctor>(true); + BaseTest, math::ReluGradFunctor>(true); } TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) { const int rows = 16; const int cols = 16; for (auto is_upscale_in_train : {true, false}) { - TestFusedDropoutActBias, - operators::ReluGradFunctor> + TestFusedDropoutActBias, + math::ReluGradFunctor> test(rows, cols, 0, 1.0, is_upscale_in_train, false); test.Run(); test.CheckOut(static_cast(1e-5)); @@ -324,8 +315,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) { TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) { const int rows = 16; const int cols = 16; - TestFusedDropoutActBias, - operators::ReluGradFunctor> + TestFusedDropoutActBias, + math::ReluGradFunctor> test(rows, cols, 0, 0.35, true, true); test.Run(); test.CheckOut(static_cast(1e-5)); @@ -335,8 +326,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) { TEST(FusedDropout, GPUFusedDropoutActBiasSeed) { const int rows = 16; const int cols = 16; - TestFusedDropoutActBias, - operators::ReluGradFunctor> + TestFusedDropoutActBias, + math::ReluGradFunctor> test(rows, cols, 125, 0.0, false, false); test.Run(); test.CheckOut(static_cast(1e-5)); @@ -346,8 +337,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasSeed) { TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) { const int rows = 256; const int cols = 4096; - TestFusedDropoutActBias, - operators::ReluGradFunctor> + TestFusedDropoutActBias, + math::ReluGradFunctor> test(rows, cols); test.Run(); test.CheckOut(static_cast(1e-5)); diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index 3e4200a717d4eb..02c3a2d6f1a12f 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -20,6 +20,7 @@ limitations under the License. 
*/ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device_context.h" @@ -88,5 +89,49 @@ __forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state, RandVec<4>(state, data + 4); } +template +inline void SetZero(const platform::CUDADeviceContext &ctx, T *ptr, + const size_t size) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemsetAsync(ptr, 0, size * sizeof(T), ctx.stream())); +} + +/** + * reduce the sum of 128 cols data by 8*VecSize warps + */ +template +inline __device__ void CalculateDBias(const T *tmp_sum, T *dbias, + const int cols) { + // save temporary sum to cache and do transpose + __shared__ T cache[BlockSizeX * VecSize][BlockSizeY]; + for (int i = 0; i < VecSize; i++) { + cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; + } + __syncthreads(); + // reduce sum + T sum = static_cast(0); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 5; // warp id + int y = tid & 31; // thread id on warp 0~31 + + // need BlockSizeX * VecSize warps + if (x < BlockSizeX * VecSize) { +// reduce 128 to 32 +#pragma unroll + for (int i = 0; i < (BlockSizeY >> 5); i++) { + sum += cache[x][y + i * 32]; + } + } + + // reduce 32 to 1 + sum = WarpReduceSum(sum); + + // save sum to dbias + int bias_id = blockIdx.x * blockDim.x * VecSize + x; + if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) { + dbias[bias_id] = sum; + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index 288b415aef31f9..eae2f5457b07f8 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -115,3 +115,22 @@ void DropoutGrad(std::vector *dx, const framework::DDim &x_dim, framework::TensorToVector(*tensor_dx, ctx, dx); ctx.Wait(); } + +template +inline void ReduceSum(const std::vector &dout, std::vector *dbias, + const int rows, const int cols) { + for (int j = 0; j < cols; j++) { + std::vector tmp_dbias(rows); + for (int i = 0; i < rows; i++) { + tmp_dbias[i] = dout[i * cols + j]; + } + int tmp_rows = rows / 2; + while (tmp_rows) { + for (int i = 0; i < tmp_rows; i++) { + tmp_dbias[i] += tmp_dbias[i + tmp_rows]; + } + tmp_rows /= 2; + } + (*dbias)[j] = tmp_dbias[0]; + } +} diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 3ba91060a4e356..0230244c981555 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -15,7 +15,6 @@ limitations under the License. 
*/ #pragma once #include "paddle/fluid/operators/fused/fused_dropout_common.h" -#include "paddle/fluid/operators/layer_norm_kernel.cu.h" namespace paddle { namespace operators { @@ -28,8 +27,9 @@ template __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( const int row_id, const int col_id, const int cols, curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor, - const T *src, const T *residual, const T *bias, T *dst, MaskType *mask, - const bool is_test, typename details::MPTypeTrait::Type *mean_val, + const T *__restrict__ src, const T *__restrict__ residual, + const T *__restrict__ bias, T *dst, MaskType *mask, const bool is_test, + typename details::MPTypeTrait::Type *mean_val, typename details::MPTypeTrait::Type *var_val) { using LoadT = platform::AlignedVector; using StoreT = platform::AlignedVector; @@ -97,9 +97,10 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( template __global__ void FusedResidualDropoutBias( const size_t rows, const size_t cols, uint64_t seed, - const float dropout_prob, const bool is_upscale_in_train, const T *src, - const T *residual, const T *bias, MaskType *mask, T *dst, - uint64_t increment, const bool is_test) { + const float dropout_prob, const bool is_upscale_in_train, + const T *__restrict__ src, const T *__restrict__ residual, + const T *__restrict__ bias, MaskType *mask, T *dst, uint64_t increment, + const bool is_test) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; @@ -140,8 +141,7 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T), ctx.stream()); if (!is_test) { - PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( - mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + SetZero(ctx, mask_data, rows * cols); } return; } @@ -230,36 +230,7 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, } } - // save temporary sum to cache and do transpose - __shared__ T cache[BlockSizeX * VecSize][BlockSizeY]; - for (int i = 0; i < VecSize; i++) { - cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; - } - __syncthreads(); - - // reduce sum - T sum = static_cast(0); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 5; // warp id - int y = tid & 31; // thread id on warp 0~31 - - // need BlockSizeX * VecSize warps - if (x < BlockSizeX * VecSize) { -// reduce 128 to 32 -#pragma unroll - for (int i = 0; i < (BlockSizeY >> 5); i++) { - sum += cache[x][y + i * 32]; - } - } - - // reduce 32 to 1 - sum = WarpReduceSum(sum); - - // save sum to dbias - int bias_id = blockIdx.x * blockDim.x * VecSize + x; - if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) { - dbias[bias_id] = sum; - } + CalculateDBias(tmp_sum, dbias, cols); } /** diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index 47546d5f9ca13a..d44df536bdd10c 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -114,16 +114,12 @@ struct TestFusedResidualDropoutBias { } { - out.Resize({rows, cols}); - out.mutable_data(place); - mask.Resize({rows, cols}); - mask.mutable_data(place); - dsrc.Resize({rows, cols}); - dsrc.mutable_data(place); + out.mutable_data({rows, cols}, place); + mask.mutable_data({rows, cols}, place); + dsrc.mutable_data({rows, cols}, 
place); if (has_bias) { - dbias.Resize({cols}); - dbias.mutable_data(place); + dbias.mutable_data({cols}, place); } } } @@ -159,10 +155,8 @@ struct TestFusedResidualDropoutBias { dropout_prob, is_upscale_in_train); // calc dbias memset(&correct_dbias[0], 0, cols * sizeof(T)); - for (int i = 0; i < rows; i++) { - for (int j = 0; j < cols; j++) { - correct_dbias[j] += correct_out[i * cols + j]; - } + if (has_bias) { + ReduceSum(correct_out, &correct_dbias, rows, cols); } } @@ -254,19 +248,14 @@ struct TestFusedResidualDropoutBias { template static void BaseTest(const bool is_fp16 = false) { const int rows = 16; - T default_diff = !is_fp16 ? static_cast(1e-5) : default_diff = - static_cast(1e-2); + T default_diff = !is_fp16 ? static_cast(1e-5) : static_cast(1e-1); for (auto cols : {16, 17}) { for (auto has_bias : {true, false}) { TestFusedResidualDropoutBias test(rows, cols); test.has_bias = has_bias; test.Run(); test.CheckOut(default_diff); - if (!is_fp16) { - // test fp16, For inference, check_grad is not required. ref: - // testdropout_op.py - test.CheckGrad(default_diff); - } + test.CheckGrad(default_diff); } } }
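
As a usage illustration only (not part of the patch series), the fused forward kernel added above can be driven from host code the same way TestFusedDropoutActBias::FusedForward does. A minimal sketch, assuming the headers introduced by this series are on the include path and a uint8_t dropout mask as used by Paddle's dropout op; the wrapper name RunFusedGeluDropoutBias and its parameters are hypothetical:

#include <cstdint>
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"

template <typename T>
void RunFusedGeluDropoutBias(const paddle::platform::CUDADeviceContext &ctx,
                             const T *src, const T *bias, T *out,
                             uint8_t *mask, uint32_t rows, uint32_t cols,
                             uint64_t seed, float dropout_prob) {
  namespace ops = paddle::operators;
  // vectorization width and launch config follow the unit test above
  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
  auto config = ops::Get1DBlocksAnd2DGrids(ctx, rows, cols, VecSize);
  // advance the Philox offset by the number of random values each thread draws
  const int increment =
      ((cols - 1) /
           (config.thread_per_block.x * config.block_per_grid.x * VecSize) +
       1) *
      VecSize;
  ops::GeluFunctor<T> gelu;  // erf-based gelu added by this series
  ops::LaunchDropoutActBias(gelu, seed, rows, cols, increment, dropout_prob,
                            /*is_upscale_in_train=*/true, /*is_test=*/false,
                            src, bias, out, mask, ctx);
}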