-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Log_softmax forward case#1: axis=-1 #31630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
8f532b0
5b5804d
cee2470
5be3a45
a1d92b7
e674a5d
855d00b
20a37a8
82328a7
f6ece4d
0f56b5e
4d5533b
060953b
eb14185
302f08d
844b880
26e1850
66c48ae
f2a2f2e
ab96a80
bf320c7
c5404ce
c7d785e
480a52f
0c1aec6
24cd730
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,4 @@ | ||||||||||||||||||||||||
| // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||||||||||||||||||||||||
| // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||||||||||||||||||||||||
| // | ||||||||||||||||||||||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||
| // you may not use this file except in compliance with the License. | ||||||||||||||||||||||||
|
|
@@ -12,15 +12,219 @@ | |||||||||||||||||||||||
| // See the License for the specific language governing permissions and | ||||||||||||||||||||||||
| // limitations under the License. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #include <cuda_runtime.h> | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| #include <cassert> | ||||||||||||||||||||||||
| #include <limits> | ||||||||||||||||||||||||
| #include "paddle/fluid/operators/log_softmax_op.h" | ||||||||||||||||||||||||
| #include "paddle/fluid/platform/cuda_device_function.h" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| namespace paddle { | ||||||||||||||||||||||||
| namespace operators { | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #define WARP_SIZE 32 | ||||||||||||||||||||||||
| int log2_ceil(int value); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \ | ||||||||||||||||||||||||
| case L2E: \ | ||||||||||||||||||||||||
| WarpLogSoftmaxForward<T, L2E><<<blocks, threads, 0>>>( \ | ||||||||||||||||||||||||
| dst, src, batch_count, softmax_elements_stride, softmax_elements); \ | ||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX> | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| __device__ __forceinline__ void warp_reduce_sum(T* sum) { | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < WARP_BATCH; ++i) { | ||||||||||||||||||||||||
| T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); | ||||||||||||||||||||||||
| sum[i] = sum[i] + sum_val; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX> | ||||||||||||||||||||||||
| __device__ __forceinline__ void warp_reduce_max(T* sum) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < WARP_BATCH; ++i) { | ||||||||||||||||||||||||
| T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); | ||||||||||||||||||||||||
| sum[i] = max(sum[i], max_val); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T, int log2_elements> | ||||||||||||||||||||||||
| __global__ void WarpLogSoftmaxForward(T* dst, const T* src, int batch_size, | ||||||||||||||||||||||||
| int stride, int element_count) { | ||||||||||||||||||||||||
| constexpr int next_power_of_two = 1 << log2_elements; | ||||||||||||||||||||||||
| constexpr int KERNEL_WARP_SIZE = | ||||||||||||||||||||||||
| (next_power_of_two < WARP_SIZE) ? next_power_of_two : WARP_SIZE; | ||||||||||||||||||||||||
| constexpr int WARP_ITERATIONS = next_power_of_two / KERNEL_WARP_SIZE; | ||||||||||||||||||||||||
| constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; | ||||||||||||||||||||||||
| int local_batches = batch_size - first_batch; | ||||||||||||||||||||||||
| if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| int local_idx = threadIdx.x; | ||||||||||||||||||||||||
| src += first_batch * stride + local_idx; | ||||||||||||||||||||||||
| dst += first_batch * stride + local_idx; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // 1.load data from global memory | ||||||||||||||||||||||||
| T elements[WARP_BATCH][WARP_ITERATIONS]; | ||||||||||||||||||||||||
| int idx = threadIdx.x + blockDim.x * threadIdx.y; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| for (int i = 0; i < WARP_BATCH; ++i) { | ||||||||||||||||||||||||
| int batch_element_count = (i >= local_batches) ? 0 : element_count; | ||||||||||||||||||||||||
| for (int it = 0; it < WARP_ITERATIONS; ++it) { | ||||||||||||||||||||||||
| int element_index = local_idx + it * KERNEL_WARP_SIZE; | ||||||||||||||||||||||||
| if (element_index < batch_element_count) { | ||||||||||||||||||||||||
| elements[i][it] = src[i * element_count + it * KERNEL_WARP_SIZE]; | ||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||
| elements[i][it] = -std::numeric_limits<T>::infinity(); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // 2.compute max_value | ||||||||||||||||||||||||
| T max_value[WARP_BATCH]; | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < WARP_BATCH; ++i) { | ||||||||||||||||||||||||
| max_value[i] = elements[i][0]; | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int it = 1; it < WARP_ITERATIONS; ++it) { | ||||||||||||||||||||||||
| max_value[i] = | ||||||||||||||||||||||||
| (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| warp_reduce_max<T, WARP_BATCH, KERNEL_WARP_SIZE>(max_value); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| T sum[WARP_BATCH]{0.0f}; | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < WARP_BATCH; ++i) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int it = 0; it < WARP_ITERATIONS; ++it) { | ||||||||||||||||||||||||
| sum[i] += std::exp(elements[i][it] - max_value[i]); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| warp_reduce_sum<T, WARP_BATCH, KERNEL_WARP_SIZE>(sum); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // 3.store result | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < WARP_BATCH; ++i) { | ||||||||||||||||||||||||
| if (i >= local_batches) break; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| sum[i] = std::log(sum[i]); | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int it = 0; it < WARP_ITERATIONS; ++it) { | ||||||||||||||||||||||||
| int element_index = local_idx + it * KERNEL_WARP_SIZE; | ||||||||||||||||||||||||
| if (element_index < element_count) { | ||||||||||||||||||||||||
| dst[i * element_count + it * KERNEL_WARP_SIZE] = | ||||||||||||||||||||||||
| elements[i][it] - max_value[i] - sum[i]; | ||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| template <typename T> | |
| class MPTypeTrait { | |
| public: | |
| using Type = T; | |
| }; | |
| template <> | |
| class MPTypeTrait<platform::float16> { | |
| public: | |
| using Type = float; | |
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. 谢谢提供的解决方案!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个函数主要的功能是启动CUDA Kernel,所以可以叫LaunchLogSoftmaxForwardForLastAxis。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
检查用PADDLE_ENFORCE_XXX。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉这一层的封装没有必要。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CanonicalAxis已经对axis做了换算了。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done。已删除
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SizeToAxis和SizeFromAxis可以分别计算outer_size和inner_size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
outer_size可以用SizeToAxis()得到;inner_size的计算与SizeFromAxis()有差别。这里应该调SizeOutAxis()。但是SizeOutAxis()定义在其他.cu文件中,在该文件中不能直接调用。(nvcc 没有开启 --relocatable-device-code=true --compile,开启后可以调用)。
所以保留inner_size,用SizeToAxis()获得outer_size。
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
变量名命名:axx_bxx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
变量名都改为了这种形式。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
X、Out还没改。
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么把float16类型去掉了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done,支持了float16。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个文件不是今年新增的,不用改copyright吧。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.