-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Log_softmax forward case#1: axis=-1 #31630
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
8f532b0
5b5804d
cee2470
5be3a45
a1d92b7
e674a5d
855d00b
20a37a8
82328a7
f6ece4d
0f56b5e
4d5533b
060953b
eb14185
302f08d
844b880
26e1850
66c48ae
f2a2f2e
ab96a80
bf320c7
c5404ce
c7d785e
480a52f
0c1aec6
24cd730
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,200 @@ | |||||||||||||||||||||||
| // See the License for the specific language governing permissions and | ||||||||||||||||||||||||
| // limitations under the License. | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #include <cuda_runtime.h> | ||||||||||||||||||||||||
| #include <limits> | ||||||||||||||||||||||||
| #include "paddle/fluid/operators/log_softmax_op.h" | ||||||||||||||||||||||||
| #include "paddle/fluid/platform/cuda_device_function.h" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| namespace paddle { | ||||||||||||||||||||||||
| namespace operators { | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #define WARP_SIZE 32 | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) \ | ||||||||||||||||||||||||
| case L2E: \ | ||||||||||||||||||||||||
| WarpLogSoftmaxForward<T, double, L2E><<<blocks, threads, 0>>>( \ | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| dst, src, batch_count, softmax_elements_stride, softmax_elements); \ | ||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| int LogTwoCeil(int value) { | ||||||||||||||||||||||||
| int log2_value = 0; | ||||||||||||||||||||||||
| while ((1 << log2_value) < value) ++log2_value; | ||||||||||||||||||||||||
| return log2_value; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T, int NumBatch, int KernelWarpSize> | ||||||||||||||||||||||||
| __device__ __forceinline__ void ReduceSumForWarpBatch(T* sum) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < NumBatch; ++i) { | ||||||||||||||||||||||||
| T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); | ||||||||||||||||||||||||
| sum[i] = sum[i] + sum_val; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T, int NumBatch, int KernelWarpSize> | ||||||||||||||||||||||||
| __device__ __forceinline__ void ReduceMaxForWarpBatch(T* sum) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < NumBatch; ++i) { | ||||||||||||||||||||||||
| T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); | ||||||||||||||||||||||||
| sum[i] = max(sum[i], max_val); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T, typename AccT, int log2_elements> | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| __global__ void WarpLogSoftmaxForward(T* dst, const T* src, int batch_size, | ||||||||||||||||||||||||
| int stride, int element_count) { | ||||||||||||||||||||||||
| constexpr int next_power_of_two = 1 << log2_elements; | ||||||||||||||||||||||||
| constexpr int kernel_warp_size = | ||||||||||||||||||||||||
| (next_power_of_two < WARP_SIZE) ? next_power_of_two : WARP_SIZE; | ||||||||||||||||||||||||
| constexpr int warp_iterations = next_power_of_two / kernel_warp_size; | ||||||||||||||||||||||||
| constexpr int num_batch = (next_power_of_two <= 128) ? 2 : 1; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * num_batch; | ||||||||||||||||||||||||
| int local_batches = batch_size - first_batch; | ||||||||||||||||||||||||
| if (local_batches > num_batch) local_batches = num_batch; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| int local_idx = threadIdx.x; | ||||||||||||||||||||||||
| src += first_batch * stride + local_idx; | ||||||||||||||||||||||||
| dst += first_batch * stride + local_idx; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // 1.load data from global memory | ||||||||||||||||||||||||
| AccT elements[num_batch][warp_iterations]; | ||||||||||||||||||||||||
| int idx = threadIdx.x + blockDim.x * threadIdx.y; | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| for (int i = 0; i < num_batch; ++i) { | ||||||||||||||||||||||||
| int batch_element_count = (i >= local_batches) ? 0 : element_count; | ||||||||||||||||||||||||
| for (int it = 0; it < warp_iterations; ++it) { | ||||||||||||||||||||||||
| int element_index = local_idx + it * kernel_warp_size; | ||||||||||||||||||||||||
| if (element_index < batch_element_count) { | ||||||||||||||||||||||||
| elements[i][it] = | ||||||||||||||||||||||||
| static_cast<double>(src[i * element_count + it * kernel_warp_size]); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||
| elements[i][it] = -std::numeric_limits<AccT>::infinity(); | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // 2.compute max_value | ||||||||||||||||||||||||
| AccT max_value[num_batch]; | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < num_batch; ++i) { | ||||||||||||||||||||||||
| max_value[i] = elements[i][0]; | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int it = 1; it < warp_iterations; ++it) { | ||||||||||||||||||||||||
| max_value[i] = | ||||||||||||||||||||||||
| (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| ReduceMaxForWarpBatch<AccT, num_batch, kernel_warp_size>(max_value); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| AccT sum[num_batch]{0.0f}; | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < num_batch; ++i) { | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int it = 0; it < warp_iterations; ++it) { | ||||||||||||||||||||||||
| sum[i] += std::exp(elements[i][it] - max_value[i]); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| ReduceSumForWarpBatch<AccT, num_batch, kernel_warp_size>(sum); | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| // 3.store result | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int i = 0; i < num_batch; ++i) { | ||||||||||||||||||||||||
| if (i >= local_batches) break; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| sum[i] = std::log(sum[i]); | ||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||
| for (int it = 0; it < warp_iterations; ++it) { | ||||||||||||||||||||||||
| int element_index = local_idx + it * kernel_warp_size; | ||||||||||||||||||||||||
| if (element_index < element_count) { | ||||||||||||||||||||||||
| dst[i * element_count + it * kernel_warp_size] = | ||||||||||||||||||||||||
| elements[i][it] - max_value[i] - sum[i]; | ||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
| } | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| template <typename T> | |
| class MPTypeTrait { | |
| public: | |
| using Type = T; | |
| }; | |
| template <> | |
| class MPTypeTrait<platform::float16> { | |
| public: | |
| using Type = float; | |
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. 谢谢提供的解决方案!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
变量名命名:axx_bxx
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
变量名都改为了这种形式。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
X、Out还没改。
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
191和192可以合成1行。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if里面为什么要加&& dim_size * sizeof(T) <= 4096这个判断呢?不支持double吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
支持double。当把&& dim_size * sizeof(T) <= 4096删去,可以正确执行,但是一致性的diff从0.0 变为1.0728e-6(atol=1.00e-6)。
&& dim_size <= 1024是必要的。
当outer_size=128,dim_size=1024时,有config<<<32, (32, 4)>>>,warp_iter=32,正确执行。
当outer_size=128,dim_size=1025时,有config<<<32, (32, 4)>>>,warp_iter=64,不能得到结果。
warp_iter表示一个thread使用到的寄存器,应该是warp_iter=64超过硬件限制了。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
HIP上会找不到cuda_runtime.h,可以试试看删掉这个头文件应该也可以运行,或者写成
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.