@@ -23,7 +23,9 @@ template <typename T, CopyType Type>
 __global__ void SequencePaddingKernel(
     T* dst, const T* src, const T* pad_value, bool is_constant_pad,
     const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
-    const size_t step_width, bool norm_by_len, const PadLayout layout) {
+    const size_t step_width, bool norm_by_len, bool norm_by_batchsize,
+    bool norm_by_total_logits_len, int total_logits_len,
+    const PadLayout layout) {
   size_t seq_idx = blockIdx.y;
   size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];
 
@@ -38,7 +40,15 @@ __global__ void SequencePaddingKernel(
       src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);
 
   if (step_idx < seq_len) {
-    float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
+    float scale = 1.0f;
+    if (norm_by_total_logits_len) {
+      scale = 1.0f / static_cast<float>(total_logits_len);
+    } else if (norm_by_batchsize) {
+      scale = 1.0f / static_cast<float>(seq_num);
+    } else if (norm_by_len) {
+      scale = 1.0f / static_cast<float>(seq_len);
+    }
+
     for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
       dst_data[i] = scale * src_data[i];
     }
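The three normalization flags select a single mode, checked in priority order: norm_by_total_logits_len first, then norm_by_batchsize, then the pre-existing norm_by_len. A minimal host-side sketch of the same selection, handy for unit-testing the precedence (the helper name and free-function form are illustrative, not part of the patch):

#include <cstddef>

// Mirrors the kernel's scale selection above; returns 1.0f when no
// normalization flag is set.
inline float SelectScale(bool norm_by_total_logits_len, bool norm_by_batchsize,
                         bool norm_by_len, int total_logits_len,
                         size_t seq_num, size_t seq_len) {
  if (norm_by_total_logits_len) return 1.0f / static_cast<float>(total_logits_len);
  if (norm_by_batchsize) return 1.0f / static_cast<float>(seq_num);
  if (norm_by_len) return 1.0f / static_cast<float>(seq_len);
  return 1.0f;
}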
@@ -57,6 +67,8 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                   framework::LoDTensor* pad_tensor,
                   const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
+                  bool norm_by_batchsize = false,
+                  bool norm_by_total_logits_len = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_lod = seq_tensor.lod();
     const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
@@ -107,7 +119,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
         pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
         seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
-        step_width, norm_by_times, layout);
+        step_width, norm_by_times, false, false, 0, layout);
   }
 };
 
@@ -118,6 +130,8 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
                   const framework::LoDTensor& pad_tensor,
                   framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                   int lod_level = 0, bool norm_by_times = false,
+                  bool norm_by_batchsize = false,
+                  bool norm_by_total_logits_len = false,
                   const PadLayout layout = kBatchLengthWidth) {
     auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
     const auto& seq_tensor_dims = seq_tensor->dims();
@@ -126,6 +140,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     if (pad_seq_len == -1) {
       pad_seq_len = max_seq_len;
     }
+    int total_logits_len = TotalSequenceLength(seq_offsets);
     int step_width = seq_tensor->numel() / seq_tensor_dims[0];
     int seq_num = seq_offsets.size() - 1;
 
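TotalSequenceLength is not defined in this hunk; since seq_offsets comes from framework::ToAbsOffset and therefore holds absolute offsets, it presumably reduces to the last offset, i.e. the summed length of all sequences in the batch. A sketch under that assumption (std::vector stands in for the framework's offset container):

#include <cstddef>
#include <vector>

// Assumed behavior only: with absolute offsets {0, len0, len0+len1, ...},
// the total number of time steps (logits) across the batch is the last entry.
inline int TotalSequenceLengthSketch(const std::vector<size_t>& seq_offsets) {
  return seq_offsets.empty() ? 0 : static_cast<int>(seq_offsets.back());
}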
@@ -159,7 +174,8 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
     SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
         seq_data, pad_data, nullptr, false,
         seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
-        step_width, norm_by_times, layout);
+        step_width, norm_by_times, norm_by_batchsize, norm_by_total_logits_len,
+        total_logits_len, layout);
   }
 };
 
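Because the two new parameters default to false, existing call sites keep compiling unchanged; a caller that wants the new behavior passes them explicitly. A hypothetical call site based on the signature above (tensor names and flag values are illustrative only):

// e.g. in a gradient path that normalizes by the total logits length:
math::UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> unpad;
unpad(dev_ctx, padded_grad, &seq_grad,
      /*pad_seq_len=*/-1, /*lod_level=*/0,
      /*norm_by_times=*/false,
      /*norm_by_batchsize=*/false,
      /*norm_by_total_logits_len=*/true,
      math::kBatchLengthWidth);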