-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Speed up when having padding tokens two-batch overlap #6668
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 42 commits
Commits
Show all changes
63 commits
Select commit
Hold shift + click to select a range
4e8f93d
more
fzyzcjy 86e1736
more
fzyzcjy 0dc0b6f
more
fzyzcjy 180909e
more
fzyzcjy 0e5e6ef
more
fzyzcjy 7dc2c38
more
fzyzcjy 171e76d
more
fzyzcjy d0ae6ea
fmt
fzyzcjy ec9c1b1
more
fzyzcjy b1c34f5
more
fzyzcjy 2fad15c
more
fzyzcjy 0021673
more
fzyzcjy 91e64b2
more
fzyzcjy 086fced
more
fzyzcjy 20f5fe8
more
fzyzcjy 81f8658
more
fzyzcjy 3f21146
more
fzyzcjy c336dbe
more
fzyzcjy a884462
more
fzyzcjy 3433cf4
more
fzyzcjy 7b9fdfd
more
fzyzcjy b96e538
more
fzyzcjy d3ccef8
more
fzyzcjy 3bbfce8
more
fzyzcjy f09e892
more
fzyzcjy a922bf9
more
fzyzcjy 31767ba
more
fzyzcjy 8225f4b
more
fzyzcjy 0bee404
more
fzyzcjy 492062b
more
fzyzcjy ced1a1c
more
fzyzcjy 98a9603
more
fzyzcjy 7adde6e
more
fzyzcjy 6271e7e
more
fzyzcjy 302fc35
fmt
fzyzcjy a1b2bd6
more
fzyzcjy 193ab7d
more
fzyzcjy 12d63fb
more
fzyzcjy 9a94b9b
more
fzyzcjy a6a78a0
more
fzyzcjy 1a7a201
more
fzyzcjy 8845aa7
more
fzyzcjy 80b2cca
fix
fzyzcjy fc7ca7a
more
fzyzcjy 6268da7
more
fzyzcjy acbed64
more
fzyzcjy 7e85afe
more
fzyzcjy 9964c08
more
fzyzcjy c138e14
Revert "more"
fzyzcjy 933aee6
more
fzyzcjy cfbd755
more
fzyzcjy 3a0f1c6
more
fzyzcjy 8fa970b
more
fzyzcjy 0f90f35
fmt
fzyzcjy 9a17949
more
fzyzcjy 919e863
ci
fzyzcjy f6664d3
Merge branch 'main' into feat/tbo_padding
fzyzcjy f4df233
Merge branch 'main' into feat/tbo_padding
fzyzcjy 64ceb3a
Merge branch 'main' into feat/tbo_padding
fzyzcjy 789a621
more
fzyzcjy 2a9c8ad
Merge branch 'feat/tbo_padding' of https://github.com/fzyzcjy/sglang …
fzyzcjy 607a6e7
Merge branch 'main' into feat/tbo_padding
ch-wan 846e532
Merge branch 'main' into feat/tbo_padding
ch-wan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,22 +88,36 @@ def compute_split_token_index( | |
| # -------------------------------- Preparation --------------------------------------- | ||
|
|
||
|
|
||
| class TboCudaGraphRunnerUtils: | ||
| @staticmethod | ||
| def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int): | ||
| if that.model_runner.server_args.enable_two_batch_overlap: | ||
| tbo_split_seq_index = compute_split_seq_index( | ||
| forward_mode=that.capture_forward_mode, | ||
| num_tokens=num_tokens, | ||
| extend_lens=None, | ||
| ) | ||
| # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true | ||
| assert ( | ||
| tbo_split_seq_index is not None | ||
| ), f"{that.capture_forward_mode=} {num_tokens=}" | ||
| else: | ||
| tbo_split_seq_index = None | ||
| return tbo_split_seq_index | ||
| class TboCudaGraphRunnerPlugin: | ||
| def __init__(self): | ||
| self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32) | ||
|
|
||
| def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int): | ||
| if not global_server_args_dict["enable_two_batch_overlap"]: | ||
| return | ||
|
|
||
| batch.tbo_split_seq_index = compute_split_seq_index( | ||
| forward_mode=batch.forward_mode, | ||
| num_tokens=num_tokens, | ||
| extend_lens=None, | ||
| ) | ||
| # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true | ||
| assert batch.tbo_split_seq_index is not None, f"{num_tokens=}" | ||
|
|
||
| self._fill_tensor_content(batch) | ||
|
|
||
| TboForwardBatchPreparer.prepare_raw( | ||
| batch, | ||
| tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded, | ||
| ) | ||
|
|
||
| def replay_prepare(self, batch: ForwardBatch): | ||
| self._fill_tensor_content(batch) | ||
|
|
||
| def _fill_tensor_content(self, batch: ForwardBatch): | ||
| self._tbo_children_num_token_non_padded[...] = ( | ||
| TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(batch) | ||
| ) | ||
|
|
||
|
|
||
| class TboDPAttentionPreparer: | ||
|
|
@@ -178,17 +192,24 @@ def _is_all_same(x): | |
| class TboForwardBatchPreparer: | ||
| @classmethod | ||
| def prepare(cls, batch: ForwardBatch): | ||
| from sglang.srt.layers.attention.tbo_backend import TboAttnBackend | ||
|
|
||
| if batch.tbo_split_seq_index is None: | ||
| return | ||
|
|
||
| tbo_split_token_index = compute_split_token_index( | ||
| split_seq_index=batch.tbo_split_seq_index, | ||
| forward_mode=batch.forward_mode, | ||
| extend_seq_lens=batch.extend_seq_lens_cpu, | ||
| cls.prepare_raw( | ||
| batch, | ||
| tbo_children_num_token_non_padded=cls.compute_tbo_children_num_token_non_padded( | ||
| batch | ||
| ), | ||
| ) | ||
|
|
||
| @classmethod | ||
| def prepare_raw( | ||
| cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor | ||
| ): | ||
| from sglang.srt.layers.attention.tbo_backend import TboAttnBackend | ||
|
|
||
| tbo_split_token_index = cls._compute_split_token_index(batch) | ||
|
|
||
| if _tbo_debug: | ||
| logger.info( | ||
| f"TboForwardBatchPreparer.prepare " | ||
|
|
@@ -200,13 +221,18 @@ def prepare(cls, batch: ForwardBatch): | |
| assert isinstance(batch.attn_backend, TboAttnBackend) | ||
| attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children | ||
|
|
||
| [out_num_token_non_padded_a, out_num_token_non_padded_b] = ( | ||
| tbo_children_num_token_non_padded | ||
| ) | ||
|
|
||
| child_a = cls.filter_batch( | ||
| batch, | ||
| start_token_index=0, | ||
| end_token_index=tbo_split_token_index, | ||
| start_seq_index=0, | ||
| end_seq_index=batch.tbo_split_seq_index, | ||
| output_attn_backend=attn_backend_child_a, | ||
| out_num_token_non_padded=out_num_token_non_padded_a, | ||
| ) | ||
| child_b = cls.filter_batch( | ||
| batch, | ||
|
|
@@ -215,6 +241,7 @@ def prepare(cls, batch: ForwardBatch): | |
| start_seq_index=batch.tbo_split_seq_index, | ||
| end_seq_index=batch.batch_size, | ||
| output_attn_backend=attn_backend_child_b, | ||
| out_num_token_non_padded=out_num_token_non_padded_b, | ||
| ) | ||
|
|
||
| assert batch.tbo_children is None | ||
|
|
@@ -230,9 +257,8 @@ def filter_batch( | |
| start_seq_index: int, | ||
| end_seq_index: int, | ||
| output_attn_backend: AttentionBackend, | ||
| out_num_token_non_padded: torch.Tensor, | ||
| ): | ||
| from sglang.srt.managers.schedule_batch import global_server_args_dict | ||
|
|
||
| num_tokens = batch.input_ids.shape[0] | ||
| num_seqs = batch.batch_size | ||
|
|
||
|
|
@@ -313,6 +339,7 @@ def filter_batch( | |
| ), | ||
| extend_num_tokens=extend_num_tokens, | ||
| attn_backend=output_attn_backend, | ||
| num_token_non_padded=out_num_token_non_padded, | ||
| tbo_split_seq_index=None, | ||
| tbo_parent_token_range=(start_token_index, end_token_index), | ||
| tbo_children=None, | ||
|
|
@@ -328,7 +355,6 @@ def filter_batch( | |
| top_p_normalized_logprobs=False, | ||
| top_p=None, | ||
| mm_inputs=None, | ||
| num_token_non_padded=None, | ||
| ) | ||
| ) | ||
|
|
||
|
|
@@ -343,6 +369,28 @@ def filter_batch( | |
|
|
||
| return ForwardBatch(**output_dict) | ||
|
|
||
| @classmethod | ||
| def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch): | ||
| tbo_split_token_index = cls._compute_split_token_index(batch) | ||
| num_token_non_padded = len(batch.input_ids) | ||
|
|
||
| # TODO we may make padding on both sub-batches to make it slightly more balanced | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| value_a = min(tbo_split_token_index, num_token_non_padded) | ||
| value_b = max(0, num_token_non_padded - tbo_split_token_index) | ||
| return torch.tensor( | ||
| [value_a, value_b], | ||
| device=global_server_args_dict["device"], | ||
| dtype=torch.int32, | ||
| ) | ||
|
|
||
| @classmethod | ||
| def _compute_split_token_index(cls, batch: ForwardBatch): | ||
| return compute_split_token_index( | ||
| split_seq_index=batch.tbo_split_seq_index, | ||
| forward_mode=batch.forward_mode, | ||
| extend_seq_lens=batch.extend_seq_lens_cpu, | ||
| ) | ||
|
|
||
|
|
||
| def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode): | ||
| if forward_mode.is_extend(): | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The assertion message here,
`f"{num_tokens=}"`, is a bit less informative than the original one in `TboCudaGraphRunnerUtils` (which was `f"{that.capture_forward_mode=} {num_tokens=}"`). To aid in debugging if this assertion fails, could we consider adding
`batch.forward_mode` to this message? This would provide more context, similar to the previous version.