Merged

66 commits
054ed81
feat: mtp support dp-attention with cuda-graph (#6080)
May 12, 2025
a602a29
fix dp+mtp bugs
May 27, 2025
ed6b060
Merge branch 'main' into feature_mtp_support_dp_attention
u4lr451 May 28, 2025
6cc38e7
fix: MTP+cudagraph+DPAtten and fa3
TianQiLin666666 May 30, 2025
a526032
Merge remote-tracking branch 'github/main' into feature_mtp_support_d…
May 31, 2025
672d6be
feat:Enable CUDA Graph for draft_extend while supporting dp-attention…
May 31, 2025
b130867
fix: Adjust the init_cuda_graph_state and fixbug (#6081)
May 31, 2025
35fe3df
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 1, 2025
3ceedbe
Performance: Eliminate performance impact in non-dp-attention+mtp sce…
Jun 3, 2025
04ede24
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 3, 2025
96b7209
fix bugs for mtp (#6081)
Jun 4, 2025
54dd1f7
fix enable cuda graph for draft_extend stage while supporting dp-atte…
Jun 6, 2025
b01de94
Merge branch 'main' into feature_mtp_support_dp_attention
u4lr451 Jun 6, 2025
5805662
Merge branch 'main' into feature_mtp_support_dp_attention
ch-wan Jun 6, 2025
990fe38
Merge branch 'main' into feature_mtp_support_dp_attention
Qiaolin-Yu Jun 7, 2025
658fd39
Added test cases for dp-attention + mtp (#6081)
Jun 7, 2025
8e47432
Merge commit '60fdad7cf343333e956a3889c12956396a1516bf' into u4lr451:…
Jun 9, 2025
57e8f1c
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 9, 2025
e15db54
Update mtp+dp-attention test cases (#6081)
Jun 9, 2025
64cc457
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 9, 2025
ed7d4e2
Merge branch 'main' into feature_mtp_support_dp_attention
u4lr451 Jun 9, 2025
5cba657
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 10, 2025
b54f934
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 10, 2025
c336c53
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 10, 2025
76f6cde
compatibility for fa3 (#6081)
Jun 10, 2025
23f82db
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 10, 2025
9dff016
fix
Qiaolin-Yu Jun 11, 2025
cc124fb
fix
Qiaolin-Yu Jun 11, 2025
55aefb7
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 11, 2025
7d44df1
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 11, 2025
767ff45
Merge branch 'main' into feature_mtp_support_dp_attention
ch-wan Jun 11, 2025
4982404
Merge branch 'main' into feature_mtp_support_dp_attention
u4lr451 Jun 11, 2025
9be85b7
Remove redundant code (#6081)
Jun 11, 2025
6690410
Merge branch 'main' into feature_mtp_support_dp_attention
Qiaolin-Yu Jun 11, 2025
d4ec8c8
nit update
ch-wan Jun 12, 2025
1218312
nit fix (#6081)
Jun 12, 2025
42d2403
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 12, 2025
6f9478a
Merge branch 'main' into feature_mtp_support_dp_attention
ch-wan Jun 12, 2025
4e54751
Merge branch 'main' into feature_mtp_support_dp_attention
zhyncs Jun 13, 2025
ec987fc
update scheduler and eagle worker
ch-wan Jun 13, 2025
9c86afe
update eagle_worker (#6081)
Jun 13, 2025
b0cb235
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 13, 2025
973edde
update forward_batch_speculative_generation
Jun 14, 2025
4f299ae
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 14, 2025
d2a162f
Merge commit '55561e25533f195e6d6b11e1c3d2449bc9908495' into pr/u4lr4…
ch-wan Jun 15, 2025
6e7c69e
polish global sync
ch-wan Jun 15, 2025
37af1a2
refactor eagle_worker.py
ch-wan Jun 15, 2025
64cc292
fix
ch-wan Jun 15, 2025
5c6b93e
Merge branch 'main' into feature_mtp_support_dp_attention
ch-wan Jun 15, 2025
3744a72
Merge remote-tracking branch 'origin/HEAD' into pr/u4lr451/6081
ch-wan Jun 15, 2025
ab26c11
format
ch-wan Jun 15, 2025
c07ba77
fix refactor bug
Jun 15, 2025
ff07187
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 15, 2025
97f531b
fix enable_dp_lm_head when dp-size == tp-size
Jun 16, 2025
3f686b1
Performance: Support enabling CUDA graph when idle batches exist
Jun 16, 2025
5ae3c3d
Merge remote-tracking branch 'github/main' into u4lrssh.feature_mtp_s…
Jun 16, 2025
f3854ee
Merge remote-tracking branch 'github/main' into u4lr451:feature_mtp_s…
Jun 16, 2025
841defa
refine code for dp lm head
ch-wan Jun 16, 2025
2f64ad7
Merge branch 'main' into feature_mtp_support_dp_attention
zhyncs Jun 16, 2025
a279680
Revert "Performance: Support enabling CUDA graph when idle batches ex…
ch-wan Jun 16, 2025
038ca0f
add a note
ch-wan Jun 17, 2025
3bc16e4
Merge commit '873ae12cee348dcb579a4c7456d789ef4441f3bf' into pr/u4lr4…
ch-wan Jun 17, 2025
16f8a63
Merge branch 'main' into feature_mtp_support_dp_attention
zhyncs Jun 17, 2025
e4bf571
fix merge error
ch-wan Jun 17, 2025
3a5b9d5
clean code and add comments
ch-wan Jun 17, 2025
a2effc0
Merge branch 'main' into feature_mtp_support_dp_attention
zhyncs Jun 17, 2025
7 changes: 5 additions & 2 deletions python/sglang/srt/layers/attention/aiter_backend.py
@@ -324,7 +324,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         )
 
     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
@@ -338,7 +341,7 @@ def init_cuda_graph_state(
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/attention/base_attn_backend.py
@@ -19,7 +19,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Init the global shared states for cuda graph."""
         raise NotImplementedError()
 
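The base-class change above is the contract the rest of this PR implements: every attention backend's init_cuda_graph_state now receives both the maximum batch size and the maximum number of tokens, because with MTP/EAGLE speculative decoding a decode batch of max_bs requests can carry up to max_bs * speculative_num_draft_tokens query tokens. A minimal sketch of a backend honoring the new signature (class and attribute names here are illustrative, not SGLang's actual code):

from typing import Optional

import torch


class ToyAttnBackend:
    """Illustrative backend following the new two-argument contract."""

    def __init__(self, max_context_len: int, device: str = "cpu"):
        self.max_context_len = max_context_len
        self.device = device

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        # Per-request state still scales with max_bs ...
        self.kv_last_page_len = torch.ones(
            max_bs, dtype=torch.int32, device=self.device
        )
        # ... but per-token state scales with max_num_tokens, which under
        # speculative decoding is max_bs * speculative_num_draft_tokens.
        if kv_indices_buf is None:
            self.kv_indices = torch.zeros(
                max_num_tokens * self.max_context_len,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            self.kv_indices = kv_indices_buf


# E.g. 32 requests with 4 draft tokens verified per request per decode step:
backend = ToyAttnBackend(max_context_len=1024)
backend.init_cuda_graph_state(max_bs=32, max_num_tokens=32 * 4)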
@@ -122,6 +122,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
@@ -1120,7 +1120,7 @@ def forward_decode(
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.
 
         Args:
@@ -1999,9 +1999,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata(forward_batch)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
13 changes: 8 additions & 5 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -262,11 +262,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         )
 
     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
+                (max_num_tokens * self.max_context_len,),
                 dtype=torch.int32,
                 device="cuda",
             )
@@ -285,7 +288,7 @@ def init_cuda_graph_state(
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device="cuda",
             )
@@ -1096,7 +1099,7 @@ def call_fn(i, forward_batch):
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -1105,7 +1108,7 @@ def init_cuda_graph_state(self, max_bs: int):
 
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -199,7 +199,10 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         )
 
     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
@@ -852,7 +855,7 @@ def call_fn(i, forward_batch):
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -861,7 +864,7 @@ def init_cuda_graph_state(self, max_bs: int):
 
         for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
7 changes: 5 additions & 2 deletions python/sglang/srt/layers/attention/flashmla_backend.py
@@ -148,6 +148,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
@@ -502,9 +503,11 @@ def call_fn(i, forward_batch):
 
         self.common_template(forward_batch, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
         def call_fn(i, forward_batch):
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/attention/tbo_backend.py
@@ -32,11 +32,11 @@ def init_forward_metadata(self, forward_batch: "ForwardBatch"):
             if forward_batch_child.batch_size > 0:
                 child.init_forward_metadata(forward_batch=forward_batch_child)
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
         for item in self.children:
             # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
30 changes: 19 additions & 11 deletions python/sglang/srt/layers/attention/triton_backend.py
@@ -261,6 +261,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             num_kv_splits = None
             attn_logits = None
             attn_lse = None
+
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -335,24 +336,27 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         )
 
     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_attn_lse = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits),
+            (max_num_tokens, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
            device=self.device,
         )
         self.cuda_graph_num_kv_splits = torch.full(
-            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.int32,
                 device=self.device,
             )
@@ -361,23 +365,26 @@ def init_cuda_graph_state(
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
 
         if self.sliding_window_size is not None and self.sliding_window_size > 0:
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
-                    (max_bs * self.sliding_window_size),
+                    (max_num_tokens * self.sliding_window_size),
                     dtype=torch.int32,
                     device=self.device,
                 )
             else:
                 self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
 
             self.cuda_graph_window_num_kv_splits = torch.full(
-                (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
             )
 
     def init_forward_metadata_capture_cuda_graph(
@@ -458,6 +465,7 @@ def init_forward_metadata_capture_cuda_graph(
             )
 
             custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -821,15 +829,15 @@ def call_fn(i, forward_batch):
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps, max_bs * self.max_context_len),
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
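A quick way to see why the Triton buffers above moved from max_bs to max_num_tokens in their leading dimension: during speculative verify, decode attention runs over every draft token, not one token per request, so intermediate buffers need one row per token. A toy sizing calculation (all numbers are made up for illustration; runs on CPU):

import torch

max_bs = 32                       # max requests captured in one CUDA graph
speculative_num_draft_tokens = 4  # tokens per request during MTP/EAGLE verify
max_num_tokens = max_bs * speculative_num_draft_tokens

num_head, max_kv_splits, v_head_dim = 16, 8, 128

# Before this PR the first dim was max_bs (one row per *request*);
# with several query tokens per request the buffers must be token-sized:
attn_logits = torch.zeros(
    (max_num_tokens, num_head, max_kv_splits, v_head_dim), dtype=torch.float32
)
attn_lse = torch.zeros(
    (max_num_tokens, num_head, max_kv_splits), dtype=torch.float32
)
print(attn_logits.shape, attn_lse.shape)  # token-sized leading dimension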
8 changes: 8 additions & 0 deletions python/sglang/srt/layers/dp_attention.py
@@ -238,6 +238,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -288,6 +292,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
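Both guards above bound the copy length during the draft-extend pass: the recorded local_num_tokens can exceed what local_tokens actually holds once CUDA-graph padding is involved, so it is clamped to the tensor's real first dimension before memcpy_triton runs. The clamp stays a tensor op (no .item() host sync), which we read as keeping it safe inside captured regions. A standalone illustration of the idiom (values are made up; only the torch semantics are shown):

import torch

# Pretend the gathered metadata says 12 tokens, but the local buffer
# only holds 8 rows.
local_num_tokens = torch.tensor(12, dtype=torch.int64)
local_tokens = torch.randn(8, 16)

# new_full((), n) builds a 0-dim tensor with the same device/dtype as
# local_num_tokens, so the comparison never leaves the device.
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)

assert int(local_num_tokens) == 8  # copy length now bounded by the buffer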
8 changes: 7 additions & 1 deletion python/sglang/srt/managers/schedule_batch.py
@@ -862,6 +862,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
 
@@ -1760,11 +1761,15 @@ def copy(self):
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
             enable_custom_logit_processor=self.enable_custom_logit_processor,
+            global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            is_extend_in_batch=self.is_extend_in_batch,
         )
 
     def __str__(self):
         return (
-            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
             f"#req={(len(self.reqs))})"
         )
 
@@ -1833,6 +1838,7 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
+    spec_num_draft_tokens: Optional[int] = None
 
     # Overlap event
     launch_done: Optional[threading.Event] = None
36 changes: 33 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -1350,6 +1350,29 @@ def check_memory(self):
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()
 
+    def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
+        """Coordinate the DP attention batch."""
+
+        local_info = torch.tensor(
+            [
+                (new_batch is not None),
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 1),
+            dtype=torch.int64,
+        )
+        torch.distributed.all_gather_into_tensor(
+            global_info.flatten(),
+            local_info,
+            group=self.tp_cpu_group,
+        )
+        any_new_batch = any(
+            global_info[:, 0, 0].tolist()
+        )  # Any DP worker has forward batch
+        return any_new_batch
+
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
@@ -1383,7 +1406,14 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
             self.running_batch.merge_batch(self.last_batch)
 
         new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
+
+        # TODO(ch-wan): minor refactor is needed here to improve readability
+        any_new_batch = (
+            self.server_args.enable_dp_attention
+            and not self.spec_algorithm.is_none()
+            and self.coordinate_spec_dp_attn_batch(new_batch)
+        )
+        if new_batch is not None or any_new_batch:
             # Run prefill first if possible
             ret = new_batch
         else:
@@ -1732,8 +1762,6 @@ def prepare_dp_attn_batch_raw(
             num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1809,13 +1837,15 @@ def prepare_dp_attn_batch_raw(
             local_batch.global_num_tokens_for_logprob = (
                 global_num_tokens_for_logprob
             )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
             local_batch.tbo_split_seq_index = tbo_split_seq_index
             local_batch.global_forward_mode = global_forward_mode
 
         # Check forward mode for cuda graph
         if not disable_cuda_graph:
             local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
        return local_batch, any(is_extend_in_batch)
 
    def get_idle_batch(self):
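The new coordinate_spec_dp_attn_batch above is an all-gather rendezvous: with DP attention plus speculative decoding, every attention-DP rank must agree on whether any rank has a new prefill batch, so ranks without one still enter the forward pass (with an idle batch) instead of stalling the group's collectives. The same pattern in isolation, as a standalone sketch over the gloo backend (the helper name and addresses are ours, not SGLang's):

import torch
import torch.distributed as dist


def any_rank_has_batch(has_batch: bool, group=None) -> bool:
    """All-gather one flag per rank; True if any rank has work."""
    local = torch.tensor([int(has_batch)], dtype=torch.int64)
    flags = [
        torch.zeros(1, dtype=torch.int64)
        for _ in range(dist.get_world_size(group))
    ]
    dist.all_gather(flags, local, group=group)
    return any(bool(f.item()) for f in flags)


if __name__ == "__main__":
    # Single-process demo; in the PR this runs over the scheduler's CPU group.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    print(any_rank_has_batch(True))  # -> True: some rank has a prefill batch
    dist.destroy_process_group()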