
Commit a0a518f

support multi-step draft-model with cudagraph
1 parent 7aea651 · commit a0a518f


2 files changed: +15 additions, −46 deletions


fastdeploy/spec_decode/mtp.py

Lines changed: 13 additions & 5 deletions
@@ -636,7 +636,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
-    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
         """
         Initialize forward meta and attention meta data
         """
@@ -672,7 +672,12 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
+        # Notes(liuzichang):
+        # 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
+        # 2. In multi-step execution, only the first step should be captured.
+        self.forward_meta.step_use_cudagraph = (
+            step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
+        )
 
     def exist_prefill(self):
         """
@@ -827,7 +832,9 @@ def _propose(self, step_use_cudagraph: bool = False):
         self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
 
         # Initialize forward meta data
-        self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+        self._initialize_forward_meta(
+            step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
+        )
         self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
 
         # Padding inputs for cuda graph
@@ -852,9 +859,10 @@ def _propose(self, step_use_cudagraph: bool = False):
             top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
             share_inputs=self.model_inputs,
         )
-
+        # Note(liuzichang):
+        # paddle.clone would raise error 700 in cudaGraph mode
         if self.num_model_steps > 1:
-            self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
+            self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)
 
         model_output = self.model(
             ids_remove_padding=self.model_inputs["ids_remove_padding"],
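
The clone-to-copy change follows the usual constraint of CUDA Graph capture: tensors used inside the captured region must keep stable device addresses, so the destination buffer is allocated once and updated in place rather than re-created every step. A minimal sketch of the pattern, with placeholder shapes and names rather than the commit's own setup code:

    import paddle

    max_num_seqs = 8  # placeholder capacity for illustration
    seq_lens_this_time = paddle.ones([max_num_seqs], dtype="int32")

    # Allocate the destination once, outside any CUDA Graph capture,
    # so its address stays stable across graph replays.
    last_seq_lens_this_time = paddle.zeros([max_num_seqs], dtype="int32")

    # Inside the captured region: in-place, non-blocking copy instead of
    # rebinding the name. paddle.clone() would allocate a fresh tensor each
    # step, which the commit's note ties to CUDA error 700 in cudaGraph mode.
    last_seq_lens_this_time.copy_(seq_lens_this_time, False)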

fastdeploy/worker/gpu_model_runner.py

Lines changed: 2 additions & 41 deletions
@@ -1950,51 +1950,12 @@ def capture_model(self) -> None:
                     ),
                     batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
                     in_capturing=True,
-                    expected_decode_len=self.speculative_config.num_speculative_tokens,
+                    expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1,
                     accept_all_drafts=True,
                 )
                 logger.info(
-                    f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
+                    f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
                 )
-            if self.graph_opt_config.draft_model_use_cudagraph:
-                # Capture Draft Model without bsz 1
-                # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-                for batch_size in sorted(capture_sizes, reverse=True):
-                    if batch_size == 1:
-                        logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                    else:
-                        assert batch_size % 2 == 0
-                        self._dummy_run(
-                            num_tokens=(
-                                self.scheduler_config.max_num_seqs
-                                if self.scheduler_config.splitwise_role == "decode"
-                                else self.scheduler_config.max_num_batched_tokens
-                            ),
-                            batch_size=int(batch_size / 2),
-                            in_capturing=True,
-                            expected_decode_len=3,
-                            accept_all_drafts=True,
-                        )
-                        logger.info(
-                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                        )
-                # Capture Draft Model with bsz 1
-                if 1 in capture_sizes:
-                    self._dummy_run(
-                        num_tokens=(
-                            self.scheduler_config.max_num_seqs
-                            if self.scheduler_config.splitwise_role == "decode"
-                            else self.scheduler_config.max_num_batched_tokens
-                        ),
-                        batch_size=int(1),
-                        in_capturing=True,
-                        expected_decode_len=3,
-                        accept_all_drafts=False,
-                        reject_all_drafts=True,
-                    )
-                    logger.info(
-                        f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                    )
         else:
             for batch_size in sorted(capture_sizes, reverse=True):
                 self._dummy_run(
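
With the multi-step draft model now able to run under CUDA Graph on its own, the dedicated draft-model capture loop above is removed and the remaining dummy run stands in for both models (its log line changes from "Warm up the Target model" to "Warm up the model") with a longer expected decode length. For orientation, a small sketch of the surviving warm-up arithmetic; the concrete values are placeholders, the real ones come from the scheduler and speculative-decoding configs:

    # Illustrates how capture_model() sizes each warm-up dummy run.
    num_speculative_tokens = 1          # e.g. one MTP draft token per step
    capture_sizes = [256, 128, 64, 2]   # walked in descending order (large -> small)

    for capture_size in sorted(capture_sizes, reverse=True):
        # Each request contributes num_speculative_tokens + 1 tokens per step.
        batch_size = capture_size // (num_speculative_tokens + 1)
        # Long enough to drive the multi-step draft path during capture.
        expected_decode_len = num_speculative_tokens * 2 + 1
        print(
            f"capture_size={capture_size} -> batch_size={batch_size}, "
            f"expected_decode_len={expected_decode_len}"
        )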
