Skip to content
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8f532b0
Merge pull request #1 from PaddlePaddle/develop
AshburnLee Sep 8, 2020
5b5804d
Merge pull request #2 from PaddlePaddle/develop
AshburnLee Sep 17, 2020
cee2470
Merge pull request #3 from PaddlePaddle/develop
AshburnLee Sep 30, 2020
5be3a45
Merge pull request #4 from PaddlePaddle/develop
AshburnLee Oct 13, 2020
a1d92b7
Merge pull request #5 from PaddlePaddle/develop
AshburnLee Oct 20, 2020
e674a5d
Merge pull request #6 from PaddlePaddle/develop
AshburnLee Nov 15, 2020
855d00b
Merge pull request #7 from PaddlePaddle/develop
AshburnLee Nov 18, 2020
20a37a8
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Mar 15, 2021
82328a7
temporary PR for log_softmax
AshburnLee Mar 15, 2021
f6ece4d
Logsoftmax formard case#1: axis=-1
AshburnLee Mar 16, 2021
0f56b5e
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Mar 16, 2021
4d5533b
Changed copyright
AshburnLee Mar 16, 2021
060953b
Made modifications according to PR reviewers
AshburnLee Mar 17, 2021
eb14185
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Mar 17, 2021
302f08d
Dealt with unittest precision errors
AshburnLee Mar 18, 2021
844b880
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Mar 18, 2021
26e1850
change launch cinfigure and code style
AshburnLee Mar 23, 2021
66c48ae
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Mar 23, 2021
f2a2f2e
Removed header file cuda_runtime.h for HIP support
AshburnLee Mar 23, 2021
ab96a80
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Mar 23, 2021
bf320c7
Modified code according to review comments
AshburnLee Mar 24, 2021
c5404ce
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Mar 24, 2021
c7d785e
Reply to review comments
AshburnLee Apr 8, 2021
480a52f
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 8, 2021
0c1aec6
cudaStream_t -> gpuStream_t
AshburnLee Apr 9, 2021
24cd730
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 9, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions paddle/fluid/operators/log_softmax_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,177 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <limits>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"

namespace paddle {
namespace operators {

#define LAUNCH_WARP_FORWAR_COMPUTE(near_greater_power_of_two) \
case near_greater_power_of_two: \
ComputeLogSoftmaxForwardInWarp< \
T, AccT, near_greater_power_of_two><<<blocks, threads, 0, stream>>>( \
dst, src, outer_size, dim_size); \
break;

template <typename T, int KernelWarpSize>
__device__ __forceinline__ T WarpReduceSum(T value) {
#pragma unroll
for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) {
T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, value, offset);
value = value + sum_val;
}
return value;
}

template <typename T, int KernelWarpSize>
__device__ __forceinline__ T WarpReduceMax(T value) {
#pragma unroll
for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) {
T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, value, offset);
value = max(value, max_val);
}
return value;
}

int GetNearGreaterPowerOfTwo(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) {
++log2_value;
}
return 1 << log2_value;
}

template <typename T, typename AccT, int NearGreaterPowerOfTwo>
__global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src,
int batch_size,
int element_count) {
constexpr int near_greater_power_of_two = NearGreaterPowerOfTwo;
constexpr int kernel_warp_size =
(near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
int batch_id = blockDim.y * blockIdx.x + threadIdx.y;

// set effective_warp_id as 1 when warps do effective work,
// when warps do ineffective work, effective_warp_id remains unchanged.
int effective_warp_id = batch_size - batch_id;
if (effective_warp_id > 1) effective_warp_id = 1;

int thread_in_warp_idx = threadIdx.x;

// 1.read data from global memory to registers
AccT elements[warp_iter];
// set effective_element_count as the num of elements when warps do effective
// work
// set effective_element_count as 0, when warps do ineffective work
int effective_element_count = (effective_warp_id <= 0) ? 0 : element_count;
for (int it = 0; it < warp_iter; ++it) {
int element_index = thread_in_warp_idx + it * kernel_warp_size;
if (element_index < effective_element_count) {
elements[it] =
static_cast<AccT>(src[batch_id * element_count + element_index]);
} else {
elements[it] = -std::numeric_limits<AccT>::infinity();
}
}

// 2.compute max_value. For each thread, loop all registers to find max
AccT max_value = elements[0];
#pragma unroll
for (int it = 1; it < warp_iter; ++it) {
max_value = (max_value > elements[it]) ? max_value : elements[it];
}
max_value = WarpReduceMax<AccT, kernel_warp_size>(max_value);

// 3.For each warp, accumulate all thread registers
AccT sum = 0.0f;
#pragma unroll
for (int it = 0; it < warp_iter; ++it) {
sum += std::exp(elements[it] - max_value);
}
sum = WarpReduceSum<AccT, kernel_warp_size>(sum);

// 4.store result.
sum = std::log(sum);
#pragma unroll
for (int it = 0; it < warp_iter; ++it) {
int element_index = thread_in_warp_idx + it * kernel_warp_size;
if (element_index < element_count) {
dst[batch_id * element_count + element_index] =
static_cast<T>(elements[it] - max_value - sum);
} else {
break;
}
}
}

template <typename T, typename AccT>
void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size,
int outer_size, gpuStream_t stream) {
int threads_per_block = 128;
int near_greater_power_of_two = GetNearGreaterPowerOfTwo(dim_size);
int kernel_warp_size =
(near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
int warps_per_block = (threads_per_block / kernel_warp_size);
int blocks = (outer_size + warps_per_block - 1) / warps_per_block;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对于输入[N, 32]和[N, 128],kernel_warp_size=32,warps_per_block=4,这2种情况都是一个线程block分成4组,每组线程(即一个warp)计算1个batch?

Copy link
Contributor Author

@AshburnLee AshburnLee Mar 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

每组线程确实计算1个batch。

当N确定,通过观察configure<<<blocks, threads>>>随dim_size变化的变化,可以发现:
当dim_size>16时,threads始终为(32,4),变化的是blocks,和warp_iter,batch是1

假设N=128,变量有以下变化

  • 对于输入[N, 32]:dim_size: 32, kernel_warp_size: 32, warp_iter: 1, warp_batch: 1, config<<<4, (32, 4)>>>, numElem: 512 numThreads: 512

  • 对于输入[N, 128]: dim_size: 128, kernel_warp_size: 32, warp_iter: 4, warp_batch: 1, config<<<4, (32, 4)>>>, numElem: 2048 numThreads: 2048

这里numThreads表示线程数及其循环次数

确认是计算1个batch。

dim3 threads(kernel_warp_size, warps_per_block, 1);

switch (near_greater_power_of_two) {
LAUNCH_WARP_FORWAR_COMPUTE(1);
LAUNCH_WARP_FORWAR_COMPUTE(2);
LAUNCH_WARP_FORWAR_COMPUTE(4); // dim_size: 3~4
LAUNCH_WARP_FORWAR_COMPUTE(8); // dim_size: 5~8
LAUNCH_WARP_FORWAR_COMPUTE(16); // dim_size: 9~16
LAUNCH_WARP_FORWAR_COMPUTE(32); // dim_size: 17~32
LAUNCH_WARP_FORWAR_COMPUTE(64); // dim_size: 33~64
LAUNCH_WARP_FORWAR_COMPUTE(128); // dim_size 65~128
LAUNCH_WARP_FORWAR_COMPUTE(256); // dim_size 129~256
LAUNCH_WARP_FORWAR_COMPUTE(512); // dim_size 257~512
LAUNCH_WARP_FORWAR_COMPUTE(1024); // dim_size 513~1024

default:
break;
}
}

template <typename T>
class LogSoftmaxKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;

public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *x = context.Input<framework::Tensor>("X");
auto *out = context.Output<framework::Tensor>("Out");
const auto *input_data = x->data<T>();
auto *output_data = out->mutable_data<T>(context.GetPlace());

const int rank = x->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);

int dim_size = x->dims()[axis];
int inner_size = 1;
for (int i = axis + 1; i < x->dims().size(); ++i) {
inner_size *= x->dims()[i];
}
int outer_size = SizeToAxis(axis, x->dims());
gpuStream_t stream = context.cuda_device_context().stream();

if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if里面为什么要加&& dim_size * sizeof(T) <= 4096这个判断呢?不支持double吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

支持double。当把&& dim_size * sizeof(T) <= 4096删去,可以正确执行,但是一致性的diff从0.0 变为1.0728e-6(atol=1.00e-6)。

&& dim_size <= 1024是必要的。

当outer_size=128,dim_size=1024时,有config<<<32, (32, 4)>>>,warp_iter=32,正确执行。
当outer_size=128,dim_size=1025时,有config<<<32, (32, 4)>>>,warp_iter=64,不能得到结果。

warp_iter表示一个thread使用到的寄存器,应该是warp_iter=64超过硬件限制了。

LaunchSoftmaxForwardForLastAxis<T, MPDType>(output_data, input_data,
dim_size, outer_size, stream);
} else {
LogSoftmaxFunctor<platform::CUDADeviceContext, T>()(
context.template device_context<platform::CUDADeviceContext>(), x,
out, axis);
}
}
};

} // operators
} // paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
Expand Down