Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 19 additions & 27 deletions csrc/gpu/get_padding_offset_v2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,14 @@
// limitations under the License.

#include "paddle/extension.h"
#include "helper.h"

// Packs padded [bsz, sequence_length] token ids into a contiguous buffer by
// dropping the padding tail of each row.
// Launch layout: one block per batch row (blockIdx.x == row); threads stride
// over that row's valid tokens (seq_lens[row]).
// cum_offsets[row] is the total padding of all rows up to and including this
// one, so `row * sequence_length - cum_offsets[row]` is the packed row start.
__global__ void RemovePaddingV2(int64_t *output_data,
                                const int64_t *input_data,
                                const int *seq_lens,
                                const int *cum_offsets,
                                const int sequence_length) {
  const int batch_idx = blockIdx.x;
  const int valid_tokens = seq_lens[batch_idx];
  const int padded_base = batch_idx * sequence_length;      // row start, padded layout
  const int packed_base = padded_base - cum_offsets[batch_idx];  // row start, packed layout

  for (int token = threadIdx.x; token < valid_tokens; token += blockDim.x) {
    output_data[packed_base + token] = input_data[padded_base + token];
  }
}

__global__ void GetPaddingOffsetKernelV2(int *padding_offset,
__global__ void GetPaddingOffsetV2Kernel(int *padding_offset,
int *cum_offsets_out,
int *cu_seqlens_q,
int *cu_seqlens_k,
int64_t *output_data,
const int64_t *input_data,
const int *cum_offsets,
const int *seq_lens,
const int max_seq_len) {
Expand All @@ -42,8 +30,15 @@ __global__ void GetPaddingOffsetKernelV2(int *padding_offset,
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
const int tgt_seq_id = bi * max_seq_len - cum_offset + i;
const int src_seq_id = bi * max_seq_len + i;
output_data[tgt_seq_id] = input_data[src_seq_id];
}
if (ti == 0) {
if (bi == 0) {
cu_seqlens_q[0] = 0;
cu_seqlens_k[0] = 0;
}
cum_offsets_out[bi] = cum_offset;
int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
cu_seqlens_q[bi + 1] = cum_seq_len;
Expand All @@ -64,24 +59,21 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);

const int token_num_data = cpu_token_num.data<int64_t>()[0];
auto x_remove_padding = paddle::full({token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
auto padding_offset = paddle::full({token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
GetPaddingOffsetKernelV2<<<bsz, 128, 0, cu_stream>>>(

auto x_remove_padding = GetEmptyTensor({token_num_data}, paddle::DataType::INT64, input_ids.place());
auto padding_offset = GetEmptyTensor({token_num_data}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_q = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());
auto cu_seqlens_k = GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, input_ids.place());

GetPaddingOffsetV2Kernel<<<bsz, 128, 0, cu_stream>>>(
padding_offset.data<int>(),
cum_offsets_out.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
seq_length);
RemovePaddingV2<<<bsz, blockSize, 0, cu_stream>>>(
x_remove_padding.data<int64_t>(),
input_ids.data<int64_t>(),
cum_offsets.data<int>(),
seq_len.data<int>(),
cum_offsets_out.data<int>(),
seq_length);
return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num};
}
Expand Down
258 changes: 258 additions & 0 deletions csrc/gpu/set_preids_token_penalty_multi_scores.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// Fused sampling-preparation kernel. For each batch row it:
//   1. appends the newest token of the row to `pre_ids`,
//   2. enforces `min_len` by masking the EOS logits,
//   3. applies repetition / frequency / presence penalties plus temperature,
//   4. masks the `bad_words_list` token ids.
// Launch layout: gridDim.x == bs (one block per batch row), blockDim.x
// threads cooperate over the vocabulary dimension `length`.
// `length_id` is the row stride of `pre_ids`; `length_input_ids` the row
// stride of `input_ids`; `end_length` the number of EOS token ids.
template<typename T>
__global__ void set_preids_token_penalty_multi_scores_kernel(const bool *stop_flags,
                                                             int64_t *pre_ids,
                                                             const int64_t *input_ids,
                                                             const int *seq_lens_encoder,
                                                             const int *seq_lens_decoder,
                                                             const int64_t *step_idx,
                                                             const T *penalty_scores,
                                                             const T *frequency_score,
                                                             const T *presence_score,
                                                             const float *temperatures,
                                                             const int64_t *cur_len,
                                                             const int64_t *min_len,
                                                             const int64_t *eos_token_id,
                                                             const int64_t *bad_words_list,
                                                             int *repeat_times,
                                                             T *logits,
                                                             const int64_t bs,
                                                             const int64_t length,
                                                             const int64_t end_length,
                                                             const int64_t length_id,
                                                             const int64_t bad_words_length,
                                                             const int64_t length_input_ids) {
    const int bi = blockIdx.x;   // batch row handled by this block
    const int tid = threadIdx.x;
    T *logits_now = logits + bi * length;

    // 1) Record the newest token of row `bi` into pre_ids (thread 0 only).
    // Fixes vs. the original: index pre_ids with its real row stride
    // `length_id` (not the vocab size `length`); use `bi` consistently for
    // the whole row (the original mixed `tid`-based rows with `step_idx[bi]`
    // and so only ever wrote row 0); and no early `return` here — every
    // thread must reach the __syncthreads() below.
    if (tid == 0 && bi < bs && !stop_flags[bi]) {
        const int seq_len_dec = seq_lens_decoder[bi];
        const int seq_len_enc = seq_lens_encoder[bi];
        if (seq_len_dec != 0 || seq_len_enc != 0) {  // both zero => row already stopped
            const int64_t step_idx_now = step_idx[bi];
            if (step_idx_now >= 0) {
                int64_t *pre_ids_now = pre_ids + bi * length_id;
                const int64_t *input_ids_now = input_ids + bi * length_input_ids;
                if (seq_len_enc > 0) {
                    // encoder phase: last prompt token according to seq_lens_encoder
                    pre_ids_now[step_idx_now] = input_ids_now[seq_len_enc - 1];
                } else {
                    // decoder phase: first token
                    pre_ids_now[step_idx_now] = input_ids_now[0];
                }
            }
        }
    }
    __syncthreads();

    // 2) min_length: forbid EOS until the row has generated min_len tokens.
    // NOTE(review): -1e10 overflows float16 (converts to -inf) — consider a
    // dtype-dependent mask value if float16 logits are to be supported.
    if (bi < bs && cur_len[bi] < min_len[bi] && tid < end_length) {
        logits_now[eos_token_id[tid]] = -1e10;
    }

    // 3a) Count how often each vocab id appears among the row's previous ids.
    int *repeat_times_now = repeat_times + bi * length;
    const int64_t *pre_ids_now = pre_ids + bi * length_id;
    for (int i = tid; i < length_id; i += blockDim.x) {
        const int64_t id = pre_ids_now[i];
        // assumes negative ids pad the unused tail of the row — TODO confirm
        if (id < 0) break;
        atomicAdd(&repeat_times_now[id], 1);
    }
    __syncthreads();

    // 3b) Apply penalties and temperature; math in float32 for stability.
    const float alpha = static_cast<float>(penalty_scores[bi]);  // repetition penalty
    const float beta = static_cast<float>(frequency_score[bi]);  // frequency penalty
    const float gamma = static_cast<float>(presence_score[bi]);  // presence penalty
    for (int i = tid; i < length; i += blockDim.x) {
        const int times = repeat_times_now[i];
        float logit_now = static_cast<float>(logits_now[i]);
        if (times != 0) {
            logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
            logit_now = logit_now - times * beta - gamma;
        }
        logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
    }
    __syncthreads();

    // 4) Mask banned tokens (out-of-range ids in the list are skipped).
    for (int i = tid; i < bad_words_length; i += blockDim.x) {
        const int64_t bad_words_token_id = bad_words_list[i];
        if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
        logits_now[bad_words_token_id] = -1e10;
    }
}

// Dtype-concrete launcher: allocates the per-row repeat counters and launches
// the fused pre-ids / penalty kernel on the logits' stream with one block per
// batch row and 1024 threads per block.
template <paddle::DataType D>
void set_preids_token_penalty_multi_scores(const paddle::Tensor& pre_ids,
                                           const paddle::Tensor& input_ids,
                                           const paddle::Tensor& seq_lens_encoder,
                                           const paddle::Tensor& seq_lens_decoder,
                                           const paddle::Tensor& step_idx,
                                           const paddle::Tensor& stop_flags,
                                           const paddle::Tensor& logits,
                                           const paddle::Tensor& penalty_scores,
                                           const paddle::Tensor& frequency_score,
                                           const paddle::Tensor& presence_score,
                                           const paddle::Tensor& temperatures,
                                           const paddle::Tensor& bad_tokens,
                                           const paddle::Tensor& cur_len,
                                           const paddle::Tensor& min_len,
                                           const paddle::Tensor& eos_token_id) {
    using traits_ = PDTraits<D>;
    using DataType_ = typename traits_::DataType;
    using data_t = typename traits_::data_t;

    // Dimensions taken from the tensors themselves.
    const std::vector<int64_t> logits_shape = logits.shape();
    const int64_t bs = logits_shape[0];            // batch size
    const int64_t length = logits_shape[1];        // vocab size
    const int64_t length_id = pre_ids.shape()[1];  // pre_ids row stride
    const int64_t length_bad_words = bad_tokens.shape()[0];
    const int64_t length_input_ids = input_ids.shape()[1];
    const int64_t end_length = eos_token_id.shape()[0];

    // Zero-initialised occurrence counters, same [bs, vocab] shape as logits.
    auto repeat_times =
        paddle::full(logits_shape, 0, paddle::DataType::INT32, pre_ids.place());

    auto cu_stream = logits.stream();
    set_preids_token_penalty_multi_scores_kernel<DataType_><<<bs, 1024, 0, cu_stream>>>(
        stop_flags.data<bool>(),
        const_cast<int64_t*>(pre_ids.data<int64_t>()),
        input_ids.data<int64_t>(),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        step_idx.data<int64_t>(),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(penalty_scores.data<data_t>())),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(frequency_score.data<data_t>())),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(presence_score.data<data_t>())),
        temperatures.data<float>(),
        cur_len.data<int64_t>(),
        min_len.data<int64_t>(),
        eos_token_id.data<int64_t>(),
        bad_tokens.data<int64_t>(),
        repeat_times.data<int>(),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())),
        bs,
        length,
        end_length,
        length_id,
        length_bad_words,
        length_input_ids);
}

// Public op entry point: dispatches on the logits dtype to the templated
// launcher. Supports bfloat16, float16 and float32; anything else throws.
// `logits` and `pre_ids` are modified in place.
void SetPreidsTokenPenaltyMultiScores(const paddle::Tensor& pre_ids,
                                      const paddle::Tensor& input_ids,
                                      const paddle::Tensor& seq_lens_encoder,
                                      const paddle::Tensor& seq_lens_decoder,
                                      const paddle::Tensor& step_idx,
                                      const paddle::Tensor& stop_flags,
                                      const paddle::Tensor& logits,
                                      const paddle::Tensor& penalty_scores,
                                      const paddle::Tensor& frequency_scores,
                                      const paddle::Tensor& presence_scores,
                                      const paddle::Tensor& temperatures,
                                      const paddle::Tensor& bad_tokens,
                                      const paddle::Tensor& cur_len,
                                      const paddle::Tensor& min_len,
                                      const paddle::Tensor& eos_token_id) {
    switch (logits.type()) {
        case paddle::DataType::BFLOAT16:
            set_preids_token_penalty_multi_scores<paddle::DataType::BFLOAT16>(
                pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder,
                step_idx, stop_flags, logits, penalty_scores, frequency_scores,
                presence_scores, temperatures, bad_tokens, cur_len, min_len,
                eos_token_id);
            break;
        case paddle::DataType::FLOAT16:
            set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT16>(
                pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder,
                step_idx, stop_flags, logits, penalty_scores, frequency_scores,
                presence_scores, temperatures, bad_tokens, cur_len, min_len,
                eos_token_id);
            break;
        case paddle::DataType::FLOAT32:
            set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT32>(
                pre_ids, input_ids, seq_lens_encoder, seq_lens_decoder,
                step_idx, stop_flags, logits, penalty_scores, frequency_scores,
                presence_scores, temperatures, bad_tokens, cur_len, min_len,
                eos_token_id);
            break;
        default:
            PD_THROW(
                "NOT supported data type. "
                "Only float16, bfloat16 and float32 are supported. ");
            break;
    }
}

// Registers the custom op `set_preids_token_penalty_multi_scores` with the
// Paddle extension framework. The inplace map declares that "logits" and
// "pre_ids" are mutated in place and re-exposed as "logits_out" /
// "pre_ids_out".
PD_BUILD_OP(set_preids_token_penalty_multi_scores)
    .Inputs({"pre_ids",
             "input_ids",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "stop_flags",
             "logits",
             "penalty_scores",
             "frequency_scores",
             "presence_scores",
             "temperatures",
             "bad_tokens",
             "cur_len",
             "min_len",
             "eos_token_id"})
    .Outputs({"logits_out", "pre_ids_out"})
    .SetInplaceMap({{"logits", "logits_out"}, {"pre_ids", "pre_ids_out"}})
    .SetKernelFn(PD_KERNEL(SetPreidsTokenPenaltyMultiScores));
Loading