-
Notifications
You must be signed in to change notification settings - Fork 743
[RL] R3 Support Overlap Schedule #7674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
gongshaotian
wants to merge
7
commits into
PaddlePaddle:release/2.6
Choose a base branch
from
gongshaotian:r3_overlap_schedule_2.6
base: release/2.6
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
9547f3c
Correct the semantics of max_num_batched_tokens with multi mode
gongshaotian 10d451c
fix D2H bug
gongshaotian ee9b1e0
Merge branch 'release/2.6' of https://github.com/PaddlePaddle/FastDep…
gongshaotian adb9597
R3 support Overlap Schedule
gongshaotian 2a22b15
Merge branch 'release/2.6' of https://github.com/PaddlePaddle/FastDep…
gongshaotian 2c554f9
merge release/2.6
gongshaotian bed58de
rewrite get_position_id kernel
gongshaotian File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "helper.h" | ||
| #include "paddle/extension.h" | ||
|
|
||
| __global__ void GetPositionIdsKernel(const int* __restrict__ seq_lens_encoder, | ||
| const int* __restrict__ seq_lens_decoder, | ||
| const int* __restrict__ seq_lens_this_time, | ||
| int* __restrict__ position_ids, | ||
| const int bsz) { | ||
| int current_bid = threadIdx.x; | ||
| if (current_bid >= bsz) return; | ||
|
|
||
| // Caculate the offset of current batch in the position_ids buffer | ||
| int buffer_offset = 0; | ||
| for (int i = 0; i < current_bid; i++) { | ||
| buffer_offset += seq_lens_this_time[i]; | ||
| } | ||
|
|
||
| // Caculate the token offset in the current batch | ||
| int token_offset = seq_lens_decoder[current_bid]; | ||
| int token_num_this_batch = seq_lens_this_time[current_bid]; | ||
| if (token_num_this_batch == 0) return; | ||
|
|
||
| // Write position ids for current batch | ||
| #pragma unroll | ||
| for (int i = 0; i < token_num_this_batch; i++) { | ||
| position_ids[buffer_offset + i] = token_offset + i; | ||
| } | ||
| } | ||
|
|
||
| void GetPositionIds(const paddle::Tensor& seq_lens_encoder, | ||
| const paddle::Tensor& seq_lens_decoder, | ||
| const paddle::Tensor& seq_lens_this_time, | ||
| const paddle::Tensor& position_ids) { | ||
| const int bsz = seq_lens_this_time.shape()[0]; | ||
|
|
||
| GetPositionIdsKernel<<<1, bsz, 0, position_ids.stream()>>>( | ||
| seq_lens_encoder.data<int>(), | ||
| seq_lens_decoder.data<int>(), | ||
| seq_lens_this_time.data<int>(), | ||
| const_cast<int*>(position_ids.data<int>()), | ||
| bsz); | ||
| } | ||
|
|
||
| PD_BUILD_STATIC_OP(get_position_ids) | ||
| .Inputs({ | ||
| "seq_lens_encoder", | ||
| "seq_lens_decoder", | ||
| "seq_lens_this_time", | ||
| "position_ids", | ||
| }) | ||
| .Outputs({"position_ids_out"}) | ||
| .SetInplaceMap({{"position_ids", "position_ids_out"}}) | ||
| .SetKernelFn(PD_KERNEL(GetPositionIds)); | ||
79 changes: 0 additions & 79 deletions
79
custom_ops/gpu_ops/get_position_ids_and_mask_encoder_batch.cu
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 建议
seq_lens_encoder参数传入GetPositionIdsKernel但在函数体中从未使用(内核中所有偏移量均由seq_lens_this_time和seq_lens_decoder计算)。建议删除该参数,或在注释中说明保留原因(如 ABI/API 兼容性)。若保留,建议加
(void)seq_lens_encoder;以避免编译器警告。