Skip to content
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
4e8f93d
more
fzyzcjy May 27, 2025
86e1736
more
fzyzcjy May 27, 2025
0dc0b6f
more
fzyzcjy May 27, 2025
180909e
more
fzyzcjy May 27, 2025
0e5e6ef
more
fzyzcjy May 27, 2025
7dc2c38
more
fzyzcjy May 27, 2025
171e76d
more
fzyzcjy May 27, 2025
d0ae6ea
fmt
fzyzcjy May 27, 2025
ec9c1b1
more
fzyzcjy May 27, 2025
b1c34f5
more
fzyzcjy May 27, 2025
2fad15c
more
fzyzcjy May 27, 2025
0021673
more
fzyzcjy May 27, 2025
91e64b2
more
fzyzcjy May 27, 2025
086fced
more
fzyzcjy May 27, 2025
20f5fe8
more
fzyzcjy May 27, 2025
81f8658
more
fzyzcjy May 27, 2025
3f21146
more
fzyzcjy May 27, 2025
c336dbe
more
fzyzcjy May 27, 2025
a884462
more
fzyzcjy May 27, 2025
3433cf4
more
fzyzcjy May 27, 2025
7b9fdfd
more
fzyzcjy May 27, 2025
b96e538
more
fzyzcjy May 27, 2025
d3ccef8
more
fzyzcjy May 27, 2025
3bbfce8
more
fzyzcjy May 27, 2025
f09e892
more
fzyzcjy May 27, 2025
a922bf9
more
fzyzcjy May 27, 2025
31767ba
more
fzyzcjy May 27, 2025
8225f4b
more
fzyzcjy May 27, 2025
0bee404
more
fzyzcjy May 27, 2025
492062b
more
fzyzcjy May 27, 2025
ced1a1c
more
fzyzcjy May 27, 2025
98a9603
more
fzyzcjy May 27, 2025
7adde6e
more
fzyzcjy May 27, 2025
6271e7e
more
fzyzcjy May 27, 2025
302fc35
fmt
fzyzcjy May 27, 2025
a1b2bd6
more
fzyzcjy May 27, 2025
193ab7d
more
fzyzcjy May 27, 2025
12d63fb
more
fzyzcjy May 27, 2025
9a94b9b
more
fzyzcjy May 27, 2025
a6a78a0
more
fzyzcjy May 27, 2025
1a7a201
more
fzyzcjy May 27, 2025
8845aa7
more
fzyzcjy May 27, 2025
80b2cca
fix
fzyzcjy May 27, 2025
fc7ca7a
more
fzyzcjy May 27, 2025
6268da7
more
fzyzcjy May 27, 2025
acbed64
more
fzyzcjy May 27, 2025
7e85afe
more
fzyzcjy May 27, 2025
9964c08
more
fzyzcjy May 27, 2025
c138e14
Revert "more"
fzyzcjy May 27, 2025
933aee6
more
fzyzcjy May 27, 2025
cfbd755
more
fzyzcjy May 27, 2025
3a0f1c6
more
fzyzcjy May 27, 2025
8fa970b
more
fzyzcjy May 27, 2025
0f90f35
fmt
fzyzcjy May 27, 2025
9a17949
more
fzyzcjy May 27, 2025
919e863
ci
fzyzcjy May 27, 2025
f6664d3
Merge branch 'main' into feat/tbo_padding
fzyzcjy May 27, 2025
f4df233
Merge branch 'main' into feat/tbo_padding
fzyzcjy May 27, 2025
64ceb3a
Merge branch 'main' into feat/tbo_padding
fzyzcjy May 28, 2025
789a621
more
fzyzcjy May 28, 2025
2a9c8ad
Merge branch 'feat/tbo_padding' of https://github.com/fzyzcjy/sglang …
fzyzcjy May 28, 2025
607a6e7
Merge branch 'main' into feat/tbo_padding
ch-wan May 28, 2025
846e532
Merge branch 'main' into feat/tbo_padding
ch-wan May 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import (
TboCudaGraphRunnerUtils,
TboCudaGraphRunnerPlugin,
TboForwardBatchPreparer,
)
from sglang.srt.utils import (
Expand Down Expand Up @@ -256,6 +256,7 @@ def __init__(self, model_runner: ModelRunner):
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
self.tbo_plugin = TboCudaGraphRunnerPlugin()

# pipeline parallelism
if self.pp_size > 1:
Expand Down Expand Up @@ -481,12 +482,9 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
capture_hidden_mode=self.capture_hidden_mode,
lora_paths=lora_paths,
num_token_non_padded=self.num_token_non_padded,
tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
self, num_tokens
),
global_forward_mode=self.capture_forward_mode,
)
TboForwardBatchPreparer.prepare(forward_batch)
self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

if lora_paths is not None:
self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
Expand Down Expand Up @@ -582,6 +580,7 @@ def replay_prepare(
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
self.positions[:raw_num_token].copy_(forward_batch.positions)
self.num_token_non_padded[...] = len(forward_batch.input_ids)
self.tbo_plugin.replay_prepare(forward_batch)
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(1)
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def op_select_experts(self, state):
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
Expand Down
98 changes: 73 additions & 25 deletions python/sglang/srt/two_batch_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,36 @@ def compute_split_token_index(
# -------------------------------- Preparation ---------------------------------------


class TboCudaGraphRunnerUtils:
    """Static helpers the CUDA graph runner uses for two-batch overlap (TBO)."""

    @staticmethod
    def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
        """Return the sequence index at which a captured batch is split.

        Args:
            that: The graph runner; its ``model_runner.server_args`` and
                ``capture_forward_mode`` attributes are read.
            num_tokens: Token count of the batch being captured.

        Returns:
            ``None`` when ``enable_two_batch_overlap`` is off; otherwise the
            value produced by ``compute_split_seq_index`` (never ``None``).
        """
        if not that.model_runner.server_args.enable_two_batch_overlap:
            return None

        split_seq_index = compute_split_seq_index(
            forward_mode=that.capture_forward_mode,
            num_tokens=num_tokens,
            extend_lens=None,
        )
        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
        assert (
            split_seq_index is not None
        ), f"{that.capture_forward_mode=} {num_tokens=}"
        return split_seq_index
class TboCudaGraphRunnerPlugin:
def __init__(self):
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)

def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
if not global_server_args_dict["enable_two_batch_overlap"]:
return

batch.tbo_split_seq_index = compute_split_seq_index(
forward_mode=batch.forward_mode,
num_tokens=num_tokens,
extend_lens=None,
)
# For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The assertion message here, `f"{num_tokens=}"`, is less informative than the original one in `TboCudaGraphRunnerUtils`, which was `f"{that.capture_forward_mode=} {num_tokens=}"`.

To aid debugging if this assertion fails, could we add `batch.forward_mode` to the message? That would provide the same context as the previous version.

Suggested change
assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
assert batch.tbo_split_seq_index is not None, f"{batch.forward_mode=} {num_tokens=}"


self._fill_tensor_content(batch)

TboForwardBatchPreparer.prepare_raw(
batch,
tbo_children_num_token_non_padded=self._tbo_children_num_token_non_padded,
)

def replay_prepare(self, batch: ForwardBatch):
    """Refresh the per-child non-padded token counts ahead of a graph replay.

    Does nothing when ``enable_two_batch_overlap`` is off, mirroring the guard
    in ``capture_one_batch_size``: in that configuration the buffer is never
    handed to a captured batch, so there is nothing to keep up to date.
    NOTE(review): with TBO off, ``batch.tbo_split_seq_index`` is presumably
    unset, which would make the split computation invalid -- confirm.

    Args:
        batch: The forward batch about to be replayed.
    """
    if not global_server_args_dict["enable_two_batch_overlap"]:
        return
    self._fill_tensor_content(batch)

def _fill_tensor_content(self, batch: ForwardBatch):
    """Overwrite the shared counts buffer with the values computed for ``batch``."""
    child_counts = TboForwardBatchPreparer.compute_tbo_children_num_token_non_padded(
        batch
    )
    # `[...] =` writes into the existing tensor rather than rebinding the
    # attribute, so the object passed to `prepare_raw` during capture is the
    # one updated (presumably so captured graphs observe the new values).
    self._tbo_children_num_token_non_padded[...] = child_counts


class TboDPAttentionPreparer:
Expand Down Expand Up @@ -178,17 +192,24 @@ def _is_all_same(x):
class TboForwardBatchPreparer:
@classmethod
def prepare(cls, batch: ForwardBatch):
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

if batch.tbo_split_seq_index is None:
return

tbo_split_token_index = compute_split_token_index(
split_seq_index=batch.tbo_split_seq_index,
forward_mode=batch.forward_mode,
extend_seq_lens=batch.extend_seq_lens_cpu,
cls.prepare_raw(
batch,
tbo_children_num_token_non_padded=cls.compute_tbo_children_num_token_non_padded(
batch
),
)

@classmethod
def prepare_raw(
cls, batch: ForwardBatch, tbo_children_num_token_non_padded: torch.Tensor
):
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

tbo_split_token_index = cls._compute_split_token_index(batch)

if _tbo_debug:
logger.info(
f"TboForwardBatchPreparer.prepare "
Expand All @@ -200,13 +221,18 @@ def prepare(cls, batch: ForwardBatch):
assert isinstance(batch.attn_backend, TboAttnBackend)
attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children

[out_num_token_non_padded_a, out_num_token_non_padded_b] = (
tbo_children_num_token_non_padded
)

child_a = cls.filter_batch(
batch,
start_token_index=0,
end_token_index=tbo_split_token_index,
start_seq_index=0,
end_seq_index=batch.tbo_split_seq_index,
output_attn_backend=attn_backend_child_a,
out_num_token_non_padded=out_num_token_non_padded_a,
)
child_b = cls.filter_batch(
batch,
Expand All @@ -215,6 +241,7 @@ def prepare(cls, batch: ForwardBatch):
start_seq_index=batch.tbo_split_seq_index,
end_seq_index=batch.batch_size,
output_attn_backend=attn_backend_child_b,
out_num_token_non_padded=out_num_token_non_padded_b,
)

assert batch.tbo_children is None
Expand All @@ -230,9 +257,8 @@ def filter_batch(
start_seq_index: int,
end_seq_index: int,
output_attn_backend: AttentionBackend,
out_num_token_non_padded: torch.Tensor,
):
from sglang.srt.managers.schedule_batch import global_server_args_dict

num_tokens = batch.input_ids.shape[0]
num_seqs = batch.batch_size

Expand Down Expand Up @@ -313,6 +339,7 @@ def filter_batch(
),
extend_num_tokens=extend_num_tokens,
attn_backend=output_attn_backend,
num_token_non_padded=out_num_token_non_padded,
tbo_split_seq_index=None,
tbo_parent_token_range=(start_token_index, end_token_index),
tbo_children=None,
Expand All @@ -328,7 +355,6 @@ def filter_batch(
top_p_normalized_logprobs=False,
top_p=None,
mm_inputs=None,
num_token_non_padded=None,
)
)

Expand All @@ -343,6 +369,28 @@ def filter_batch(

return ForwardBatch(**output_dict)

@classmethod
def compute_tbo_children_num_token_non_padded(cls, batch: ForwardBatch):
tbo_split_token_index = cls._compute_split_token_index(batch)
num_token_non_padded = len(batch.input_ids)

# TODO we may make padding on both sub-batches to make it slightly more balanced
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

I notice the TODO comment: "we may make padding on both sub-batches to make it slightly more balanced".

Could you clarify if this is planned as a follow-up improvement, or if it's considered out of scope for the current effort? Understanding the intent here would be helpful for context.

value_a = min(tbo_split_token_index, num_token_non_padded)
value_b = max(0, num_token_non_padded - tbo_split_token_index)
return torch.tensor(
[value_a, value_b],
device=global_server_args_dict["device"],
dtype=torch.int32,
)

@classmethod
def _compute_split_token_index(cls, batch: ForwardBatch):
    """Return the token index at which ``batch`` is split between its two children.

    Forwards the batch's ``tbo_split_seq_index``, ``forward_mode`` and
    ``extend_seq_lens_cpu`` to the module-level ``compute_split_token_index``.
    """
    split_kwargs = dict(
        split_seq_index=batch.tbo_split_seq_index,
        forward_mode=batch.forward_mode,
        extend_seq_lens=batch.extend_seq_lens_cpu,
    )
    return compute_split_token_index(**split_kwargs)


def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
if forward_mode.is_extend():
Expand Down
Loading