
Commit a61c3fc

support multi-step draft-model with cudagraph
1 parent: 7aea651

File tree

2 files changed: +20 -49 lines

fastdeploy/spec_decode/mtp.py
fastdeploy/worker/gpu_model_runner.py


fastdeploy/spec_decode/mtp.py

Lines changed: 18 additions & 8 deletions
@@ -636,7 +636,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

-    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
         """
         Initialize forward meta and attention meta data
         """
@@ -672,7 +672,12 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
+        # Notes(liuzichang):
+        # 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
+        # 2. In multi-step execution, only the first step should be captured.
+        self.forward_meta.step_use_cudagraph = (
+            step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
+        )

     def exist_prefill(self):
         """
@@ -774,7 +779,7 @@ def _post_process(self, sampled_token_ids):
             self.model_inputs["step_idx"],
         )

-    def _propose(self, step_use_cudagraph: bool = False):
+    def _propose(self, step_use_cudagraph: bool = False, is_dummy_run=False):
         """
         Main process for MTP inference.
         Args:
@@ -827,7 +832,9 @@ def _propose(self, step_use_cudagraph: bool = False):
             self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)

             # Initialize forward meta data
-            self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+            self._initialize_forward_meta(
+                step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
+            )
             self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)

             # Padding inputs for cuda graph
@@ -852,9 +859,10 @@ def _propose(self, step_use_cudagraph: bool = False):
                 top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                 share_inputs=self.model_inputs,
             )
-
+            # Note(liuzichang):
+            # paddle.clone would raise error 700 in cudaGraph mode
             if self.num_model_steps > 1:
-                self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
+                self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)

             model_output = self.model(
                 ids_remove_padding=self.model_inputs["ids_remove_padding"],
@@ -1017,10 +1025,12 @@ def _extend_draft_token_with_ngram_match(self):
         self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
         self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()

-    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+    def _run_impl(
+        self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
+    ):
         """Execute Draft Model"""
         self._prepare_inputs(full_hidden_states)
-        self._propose(step_use_cudagraph=step_use_cudagraph)
+        self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
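
The gist of the mtp.py change: _initialize_forward_meta now knows which draft substep it is serving and whether the call comes from a warm-up (dummy) run, so during CUDA Graph capture only the first substep is recorded and later substeps of the dummy run fall back to eager execution. A minimal sketch of that gating follows; the loop shape and config stand-ins are assumptions for illustration, not FastDeploy APIs.

    # Assumed values for illustration only.
    NUM_MODEL_STEPS = 3          # stand-in for self.num_model_steps
    DRAFT_USE_CUDAGRAPH = True   # stand-in for self.draft_model_use_cudagraph

    def step_should_use_cudagraph(step_use_cudagraph: bool, is_dummy_run: bool, substep: int) -> bool:
        # Mirrors the new condition: in a dummy (capture) run only substep 0 keeps
        # the CUDA Graph flag; real decode steps keep it for every substep.
        return step_use_cudagraph and DRAFT_USE_CUDAGRAPH and not (substep > 0 and is_dummy_run)

    for is_dummy_run in (True, False):
        flags = [step_should_use_cudagraph(True, is_dummy_run, s) for s in range(NUM_MODEL_STEPS)]
        print(f"is_dummy_run={is_dummy_run}: {flags}")
    # is_dummy_run=True  -> [True, False, False]  (only the first substep is captured)
    # is_dummy_run=False -> [True, True, True]

The clone-to-copy_ switch follows the same constraint: tensors touched inside a captured region must keep stable addresses, so last_seq_lens_this_time is refreshed in place rather than re-allocated each step. A rough sketch of the pattern, with shapes and dtype made up:

    import paddle

    seq_lens_this_time = paddle.to_tensor([3, 1, 1, 1], dtype="int32")
    last_seq_lens_this_time = paddle.zeros_like(seq_lens_this_time)  # allocated once, reused every step

    # In-place, non-blocking copy keeps the destination address fixed, which is what
    # CUDA Graph replay requires; paddle.clone would allocate a new tensor instead.
    last_seq_lens_this_time.copy_(seq_lens_this_time, False)
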

fastdeploy/worker/gpu_model_runner.py

Lines changed: 2 additions & 41 deletions
@@ -1950,51 +1950,12 @@ def capture_model(self) -> None:
                     ),
                     batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
                     in_capturing=True,
-                    expected_decode_len=self.speculative_config.num_speculative_tokens,
+                    expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1,
                     accept_all_drafts=True,
                 )
                 logger.info(
-                    f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
+                    f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
                 )
-            if self.graph_opt_config.draft_model_use_cudagraph:
-                # Capture Draft Model without bsz 1
-                # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-                for batch_size in sorted(capture_sizes, reverse=True):
-                    if batch_size == 1:
-                        logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                    else:
-                        assert batch_size % 2 == 0
-                        self._dummy_run(
-                            num_tokens=(
-                                self.scheduler_config.max_num_seqs
-                                if self.scheduler_config.splitwise_role == "decode"
-                                else self.scheduler_config.max_num_batched_tokens
-                            ),
-                            batch_size=int(batch_size / 2),
-                            in_capturing=True,
-                            expected_decode_len=3,
-                            accept_all_drafts=True,
-                        )
-                        logger.info(
-                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                        )
-                # Capture Draft Model with bsz 1
-                if 1 in capture_sizes:
-                    self._dummy_run(
-                        num_tokens=(
-                            self.scheduler_config.max_num_seqs
-                            if self.scheduler_config.splitwise_role == "decode"
-                            else self.scheduler_config.max_num_batched_tokens
-                        ),
-                        batch_size=int(1),
-                        in_capturing=True,
-                        expected_decode_len=3,
-                        accept_all_drafts=False,
-                        reject_all_drafts=True,
-                    )
-                    logger.info(
-                        f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                    )
         else:
             for batch_size in sorted(capture_sizes, reverse=True):
                 self._dummy_run(
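
With the mtp.py gating in place, the separate draft-model capture passes in capture_model become unnecessary: a single dummy run per capture size warms up target and draft model together, provided the dummy decode is long enough for the multi-step draft model to visit every substep. A rough sketch of the surviving loop's arithmetic; the values and the loop itself are assumptions for illustration, not the actual config or _dummy_run call.

    # Hypothetical values; the real ones come from the speculative and graph-opt configs.
    num_speculative_tokens = 1
    capture_sizes = [64, 32, 16, 8, 4, 2, 1]  # walked largest-first (see the note in mtp.py)

    for capture_size in sorted(capture_sizes, reverse=True):
        batch_size = capture_size // (num_speculative_tokens + 1)
        # 2 * k + 1 decode steps give the draft model room to run all of its
        # substeps during the single combined warm-up run.
        expected_decode_len = num_speculative_tokens * 2 + 1
        print(f"capture_size={capture_size:3d}  batch_size={batch_size:2d}  expected_decode_len={expected_decode_len}")
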
