[Cherry-Pick][CI]Support multi-step mtp with cudagraph(#5886) #5898
Changes from all commits
```diff
@@ -708,7 +708,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

-    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
         """
         Initialize forward meta and attention meta data
         """
```
```diff
@@ -744,7 +744,12 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

-        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
+        # Notes(liuzichang):
+        # 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
+        # 2. In multi-step execution, only the first step should be captured.
+        self.forward_meta.step_use_cudagraph = (
+            step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
+        )

     def _initialize_forward_meta_xpu(self):
```
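For readers unfamiliar with the new gating term, here is a minimal standalone sketch (illustrative values only, not FastDeploy code) of how the condition behaves across substeps during a dummy-run capture pass:

```python
# Hypothetical values for illustration; in FastDeploy these come from the
# graph-optimization config and the multi-step MTP loop.
step_use_cudagraph = True
draft_model_use_cudagraph = True

for substep in range(3):
    for is_dummy_run in (True, False):
        capture_this_step = (
            step_use_cudagraph
            and draft_model_use_cudagraph
            and not (substep > 0 and is_dummy_run)
        )
        # During warm-up (is_dummy_run=True) only substep 0 is captured;
        # real decode steps (is_dummy_run=False) are unaffected by the new term.
        print(f"substep={substep} is_dummy_run={is_dummy_run} -> {capture_this_step}")
```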
```diff
@@ -929,7 +934,9 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F
             self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)

         # Initialize forward meta data
-        self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+        self._initialize_forward_meta(
+            step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
+        )
         self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)

         # Padding inputs for cuda graph
```
```diff
@@ -954,9 +961,10 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F
                 top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                 share_inputs=self.model_inputs,
             )

+            # Note(liuzichang):
+            # paddle.clone would raise error 700 in cudaGraph mode
             if self.num_model_steps > 1:
-                self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
+                self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)

             model_output = self.model(
                 ids_remove_padding=self.model_inputs["ids_remove_padding"],
```
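The `paddle.clone` to `copy_` change follows the usual CUDA Graph rule that a captured region should not allocate new tensors: the buffer is created once up front and then refreshed in place. A minimal sketch of that pattern, assuming only the Paddle tensor APIs already visible in the diff (`paddle.clone`, `Tensor.copy_`), with illustrative shapes and values:

```python
import paddle

# Stand-in for model_inputs["seq_lens_this_time"]; shape and values are illustrative.
seq_lens_this_time = paddle.to_tensor([4, 4, 4], dtype="int32")

# One-time allocation, done outside any CUDA Graph capture/replay.
last_seq_lens_this_time = paddle.clone(seq_lens_this_time)

def per_step_update():
    # Inside the (potentially graph-captured) step: update in place only.
    # Calling paddle.clone() here would allocate a fresh tensor every step,
    # which the PR's note reports as failing with error 700 under cudaGraph mode.
    last_seq_lens_this_time.copy_(seq_lens_this_time, False)

per_step_update()
```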
```diff
@@ -2105,51 +2105,12 @@ def capture_model(self) -> None:
                     ),
                     batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
                     in_capturing=True,
-                    expected_decode_len=self.speculative_config.num_speculative_tokens,
+                    expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1,
                     accept_all_drafts=True,
                 )
                 logger.info(
-                    f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
+                    f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
                 )
-                if self.graph_opt_config.draft_model_use_cudagraph:
-                    # Capture Draft Model without bsz 1
-                    # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-                    for batch_size in sorted(capture_sizes, reverse=True):
-                        if batch_size == 1:
-                            logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                        else:
-                            assert batch_size % 2 == 0
-                            self._dummy_run(
-                                num_tokens=(
-                                    self.scheduler_config.max_num_seqs
-                                    if self.scheduler_config.splitwise_role == "decode"
-                                    else self.scheduler_config.max_num_batched_tokens
-                                ),
-                                batch_size=int(batch_size / 2),
-                                in_capturing=True,
-                                expected_decode_len=3,
-                                accept_all_drafts=True,
-                            )
-                            logger.info(
-                                f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                            )
-                    # Capture Draft Model with bsz 1
-                    if 1 in capture_sizes:
-                        self._dummy_run(
-                            num_tokens=(
-                                self.scheduler_config.max_num_seqs
-                                if self.scheduler_config.splitwise_role == "decode"
-                                else self.scheduler_config.max_num_batched_tokens
-                            ),
-                            batch_size=int(1),
-                            in_capturing=True,
-                            expected_decode_len=3,
-                            accept_all_drafts=False,
-                            reject_all_drafts=True,
-                        )
-                        logger.info(
-                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                        )
             else:
                 for batch_size in sorted(capture_sizes, reverse=True):
                     self._dummy_run(
```
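As a worked example of the new warm-up length (taking `num_speculative_tokens = 1` purely for illustration), the longer target-side dummy run appears to stand in for the separate draft-model warm-up that previously hard-coded `expected_decode_len=3`:

```python
num_speculative_tokens = 1  # illustrative value, not a FastDeploy default claim
old_expected_decode_len = num_speculative_tokens          # 1
new_expected_decode_len = num_speculative_tokens * 2 + 1  # 3, matching the hard-coded
                                                          # value in the removed block
```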
Review comment: The new parameters `is_dummy_run` and `substep` added to the function signature lack documentation in the docstring. The docstring should describe what these parameters represent, their expected types, and when they should be used, to maintain consistency with Python documentation best practices.
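A possible docstring along the lines the reviewer suggests; the wording and parameter descriptions below are illustrative, inferred from how the diff uses the arguments, not taken from the PR:

```python
def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
    """
    Initialize forward meta and attention meta data.

    Args:
        step_use_cudagraph (bool): Whether this step is eligible to run under CUDA Graph.
        is_dummy_run (bool): True when called from a warm-up/capture dummy run rather than
            real decoding; combined with `substep` to skip capturing later substeps.
        substep (int): Index of the current multi-step (MTP) substep; only substep 0 is
            captured during dummy runs.
    """
    ...
```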