-
Notifications
You must be signed in to change notification settings - Fork 3k
[LLM INFER] Optimize fuse some kernels in postprocess #9201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
yuanlehome
merged 10 commits into
PaddlePaddle:develop
from
gzy19990617:fuse_kernels_in_preprocess_postprocess
Nov 6, 2024
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
a5fd26f
optimize fuse some kernels
gzy19990617 b9b4dc4
optimize fuse some kernels
gzy19990617 00c018d
fix top_p reject
gzy19990617 e8b751f
fix
gzy19990617 149df24
Merge branch 'develop' into fuse_kernels_in_preprocess_postprocess
gzy19990617 a489d73
ci
gzy19990617 7ed0de5
Merge branch 'develop' into fuse_kernels_in_preprocess_postprocess
gzy19990617 6b3f627
Merge branch 'develop' into fuse_kernels_in_preprocess_postprocess
gzy19990617 55aacac
fix review
gzy19990617 3e5afae
fix
gzy19990617 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,258 @@ | ||
| // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "helper.h" | ||
|
|
||
// Fused sampling post-process kernel, launched with grid = bs (one block per
// batch row, bi = blockIdx.x) and 1024 threads per block. In a single pass it:
//   1. appends the newest generated token to pre_ids,
//   2. suppresses EOS logits while the sequence is shorter than min_len,
//   3. applies repetition / frequency / presence penalties plus temperature,
//   4. hard-masks logits of tokens in bad_words_list.
// Layouts assumed from the launcher in this file -- confirm against callers:
//   logits / repeat_times: [bs, length], pre_ids: [bs, length_id],
//   input_ids: [bs, length_input_ids], eos_token_id: [end_length],
//   bad_words_list: [bad_words_length].
template<typename T>
__global__ void set_preids_token_penalty_multi_scores_kernel(const bool *stop_flags,
                                                             int64_t *pre_ids,
                                                             const int64_t *input_ids,
                                                             const int *seq_lens_encoder,
                                                             const int *seq_lens_decoder,
                                                             const int64_t *step_idx,
                                                             const T *penalty_scores,
                                                             const T *frequency_score,
                                                             const T *presence_score,
                                                             const float *temperatures,
                                                             const int64_t *cur_len,
                                                             const int64_t *min_len,
                                                             const int64_t *eos_token_id,
                                                             const int64_t *bad_words_list,
                                                             int *repeat_times,
                                                             T *logits,
                                                             const int64_t bs,
                                                             const int64_t length,
                                                             const int64_t end_length,
                                                             const int64_t length_id,
                                                             const int64_t bad_words_length,
                                                             const int64_t length_input_ids) {
    int bi = blockIdx.x;
    T *logits_now = logits + bi * length;  // this block's row of logits
    int tid = threadIdx.x;

    // Step 1: record the latest token into pre_ids.
    // NOTE(review): this section indexes the batch by tid (stop_flags[tid],
    // seq_lens_*[tid], pre_ids/input_ids offset by tid) while step_idx is read
    // with bi, and only tid == 0 ever writes -- confirm the intended batch
    // index is consistent across the section.
    // NOTE(review): pre_ids is offset by tid * length (vocab size) here but by
    // bi * length_id in the repeat_times pass below; verify which stride
    // matches the real pre_ids layout.
    if (tid < bs && !stop_flags[tid]) {
        int64_t *pre_ids_now = pre_ids + tid * length;
        const int64_t *input_ids_now = input_ids + tid * length_input_ids;
        const int seq_len_dec = seq_lens_decoder[tid];
        const int seq_len_enc = seq_lens_encoder[tid];
        // NOTE(review): an early return here while other threads of the same
        // block fall through to the __syncthreads() below is divergent --
        // potentially undefined behavior; confirm this path is safe.
        if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
        const int step_idx_now = step_idx[bi];
        if (tid == 0 && step_idx_now >= 0) {
            if (seq_len_enc > 0) { // encoder step: take the last prompt token per seq_lens_encoder
                pre_ids_now[step_idx_now] = input_ids_now[seq_len_enc - 1];
            } else { // decoder step: take the first token
                pre_ids_now[step_idx_now] = input_ids_now[0];
            }
        }
    }
    __syncthreads();
    // Step 2: min_length -- suppress every EOS token until cur_len >= min_len.
    if (bi < bs) {
        if (cur_len[bi] < min_len[bi]) {
            if (tid < end_length) {
                // NOTE(review): the hard-coded -1e10 exceeds float16 range
                // (converts to -inf when T is half); per the PR review, a
                // dtype-dependent sentinel would be safer.
                logits_now[eos_token_id[tid]] = -1e10;
            }
        }
    }
    // Step 3a: count prior occurrences of each token id for this row.
    int *repeat_times_now = repeat_times + bi * length;
    const int64_t *pre_ids_now = pre_ids + bi * length_id;
    for (int i = tid; i < length_id; i += blockDim.x) {
        int64_t id = pre_ids_now[i];
        if (id < 0) break;  // negative id ends this thread's scan -- presumably padding; confirm
        atomicAdd(&repeat_times_now[id], 1);
    }
    __syncthreads();
    // Step 3b: apply repetition (alpha), frequency (beta) and presence (gamma)
    // penalties, then temperature scaling; all math is done in fp32.
    float alpha = static_cast<float>(penalty_scores[bi]);
    float beta = static_cast<float>(frequency_score[bi]);
    float gamma = static_cast<float>(presence_score[bi]);
    for (int i = tid; i < length; i += blockDim.x) {
        int times = repeat_times_now[i];
        float logit_now = static_cast<float>(logits_now[i]);
        if (times != 0) {
            // Scale repeated tokens by alpha (direction depends on sign),
            // then subtract the count-weighted frequency and presence terms.
            logit_now = logit_now < 0 ? logit_now * alpha : logit_now / alpha;
            logit_now = logit_now - times * beta - gamma;
        }
        logits_now[i] = static_cast<T>(logit_now / temperatures[bi]);
    }
    __syncthreads();
    // Step 4: hard-mask bad words; ids outside [0, length) are ignored.
    for (int i = tid; i < bad_words_length; i += blockDim.x) {
        const int64_t bad_words_token_id = bad_words_list[i];
        if (bad_words_token_id >= length || bad_words_token_id < 0) continue;
        logits_now[bad_words_token_id] = -1e10;
    }
}
|
|
||
| template <paddle::DataType D> | ||
| void set_preids_token_penalty_multi_scores(const paddle::Tensor& pre_ids, | ||
| const paddle::Tensor& input_ids, | ||
| const paddle::Tensor& seq_lens_encoder, | ||
| const paddle::Tensor& seq_lens_decoder, | ||
| const paddle::Tensor& step_idx, | ||
| const paddle::Tensor& stop_flags, | ||
| const paddle::Tensor& logits, | ||
| const paddle::Tensor& penalty_scores, | ||
| const paddle::Tensor& frequency_score, | ||
| const paddle::Tensor& presence_score, | ||
| const paddle::Tensor& temperatures, | ||
| const paddle::Tensor& bad_tokens, | ||
| const paddle::Tensor& cur_len, | ||
| const paddle::Tensor& min_len, | ||
| const paddle::Tensor& eos_token_id) { | ||
|
|
||
| typedef PDTraits<D> traits_; | ||
| typedef typename traits_::DataType DataType_; | ||
| typedef typename traits_::data_t data_t; | ||
| auto cu_stream = logits.stream(); | ||
| std::vector<int64_t> shape = logits.shape(); | ||
| auto repeat_times = paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); | ||
| int64_t bs = shape[0]; | ||
| int64_t length = shape[1]; | ||
| int64_t length_id = pre_ids.shape()[1]; | ||
| int64_t length_bad_words = bad_tokens.shape()[0]; | ||
| int64_t length_input_ids = input_ids.shape()[1]; | ||
|
|
||
| int64_t end_length = eos_token_id.shape()[0]; | ||
|
|
||
| set_preids_token_penalty_multi_scores_kernel<DataType_><<<bs, 1024, 0, cu_stream>>>( | ||
| stop_flags.data<bool>(), | ||
| const_cast<int64_t*>(pre_ids.data<int64_t>()), | ||
| input_ids.data<int64_t>(), | ||
| seq_lens_encoder.data<int>(), | ||
| seq_lens_decoder.data<int>(), | ||
| step_idx.data<int64_t>(), | ||
| reinterpret_cast<DataType_*>(const_cast<data_t*>(penalty_scores.data<data_t>())), | ||
| reinterpret_cast<DataType_*>(const_cast<data_t*>(frequency_score.data<data_t>())), | ||
| reinterpret_cast<DataType_*>(const_cast<data_t*>(presence_score.data<data_t>())), | ||
| temperatures.data<float>(), | ||
| cur_len.data<int64_t>(), | ||
| min_len.data<int64_t>(), | ||
| eos_token_id.data<int64_t>(), | ||
| bad_tokens.data<int64_t>(), | ||
| repeat_times.data<int>(), | ||
| reinterpret_cast<DataType_*>(const_cast<data_t*>(logits.data<data_t>())), | ||
| bs, | ||
| length, | ||
| end_length, | ||
| length_id, | ||
| length_bad_words, | ||
| length_input_ids | ||
| ); | ||
| } | ||
|
|
||
| void SetPreidsTokenPenaltyMultiScores(const paddle::Tensor& pre_ids, | ||
| const paddle::Tensor& input_ids, | ||
| const paddle::Tensor& seq_lens_encoder, | ||
| const paddle::Tensor& seq_lens_decoder, | ||
| const paddle::Tensor& step_idx, | ||
| const paddle::Tensor& stop_flags, | ||
| const paddle::Tensor& logits, | ||
| const paddle::Tensor& penalty_scores, | ||
| const paddle::Tensor& frequency_scores, | ||
| const paddle::Tensor& presence_scores, | ||
| const paddle::Tensor& temperatures, | ||
| const paddle::Tensor& bad_tokens, | ||
| const paddle::Tensor& cur_len, | ||
| const paddle::Tensor& min_len, | ||
| const paddle::Tensor& eos_token_id) { | ||
|
|
||
| switch (logits.type()) { | ||
| case paddle::DataType::BFLOAT16: { | ||
| return set_preids_token_penalty_multi_scores<paddle::DataType::BFLOAT16>( | ||
| pre_ids, | ||
| input_ids, | ||
| seq_lens_encoder, | ||
| seq_lens_decoder, | ||
| step_idx, | ||
| stop_flags, | ||
| logits, | ||
| penalty_scores, | ||
| frequency_scores, | ||
| presence_scores, | ||
| temperatures, | ||
| bad_tokens, | ||
| cur_len, | ||
| min_len, | ||
| eos_token_id | ||
| ); | ||
| } | ||
| case paddle::DataType::FLOAT16: { | ||
| return set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT16>( | ||
| pre_ids, | ||
| input_ids, | ||
| seq_lens_encoder, | ||
| seq_lens_decoder, | ||
| step_idx, | ||
| stop_flags, | ||
| logits, | ||
| penalty_scores, | ||
| frequency_scores, | ||
| presence_scores, | ||
| temperatures, | ||
| bad_tokens, | ||
| cur_len, | ||
| min_len, | ||
| eos_token_id | ||
| ); | ||
| } | ||
| case paddle::DataType::FLOAT32: { | ||
| return set_preids_token_penalty_multi_scores<paddle::DataType::FLOAT32>( | ||
| pre_ids, | ||
| input_ids, | ||
| seq_lens_encoder, | ||
| seq_lens_decoder, | ||
| step_idx, | ||
| stop_flags, | ||
| logits, | ||
| penalty_scores, | ||
| frequency_scores, | ||
| presence_scores, | ||
| temperatures, | ||
| bad_tokens, | ||
| cur_len, | ||
| min_len, | ||
| eos_token_id | ||
| ); | ||
| } | ||
| default: { | ||
| PD_THROW( | ||
| "NOT supported data type. " | ||
| "Only float16, bfloat16 and float32 are supported. "); | ||
| break; | ||
| } | ||
| } | ||
| } | ||
|
|
||
// Register the fused custom op with Paddle's custom-operator framework.
// `logits` and `pre_ids` are modified in place by the kernel, so they are
// mapped to the `logits_out` / `pre_ids_out` outputs via SetInplaceMap.
PD_BUILD_OP(set_preids_token_penalty_multi_scores)
    .Inputs({"pre_ids",
             "input_ids",
             "seq_lens_encoder",
             "seq_lens_decoder",
             "step_idx",
             "stop_flags",
             "logits",
             "penalty_scores",
             "frequency_scores",
             "presence_scores",
             "temperatures",
             "bad_tokens",
             "cur_len",
             "min_len",
             "eos_token_id"})
    .Outputs({"logits_out", "pre_ids_out"})
    .SetInplaceMap({{"logits", "logits_out"}, {"pre_ids", "pre_ids_out"}})
    .SetKernelFn(PD_KERNEL(SetPreidsTokenPenaltyMultiScores));
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果这里固定写了-1e10,那TypeName应该只能限定Float32或者Bfloat16,而不能传Float16。但算子注册的时候全都注册了,这存在溢出的风险。虽然目前通过组网强制cast(Float32),但容易被用户用错。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以修改为,根据传入的类型设置不同精度的初始值?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得比较合理的情况是,输入不同的类型都兼容下;但如果简单处理,也可以只考虑注册特定的精度的算子