[Cherry-Pick][CI] Support multi-step mtp with cudagraph (#5886) #5897
Changes from all commits
```diff
@@ -1950,51 +1950,12 @@ def capture_model(self) -> None:
                     ),
                     batch_size=int(capture_size / (self.speculative_config.num_speculative_tokens + 1)),
                     in_capturing=True,
-                    expected_decode_len=self.speculative_config.num_speculative_tokens,
+                    expected_decode_len=self.speculative_config.num_speculative_tokens * 2 + 1,
                     accept_all_drafts=True,
                 )
                 logger.info(
-                    f"Warm up the Target model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
+                    f"Warm up the model with the num_tokens:{capture_size}, expected_decode_len:{self.speculative_config.num_speculative_tokens}"
                 )
-            if self.graph_opt_config.draft_model_use_cudagraph:
-                # Capture Draft Model without bsz 1
-                # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
-                for batch_size in sorted(capture_sizes, reverse=True):
-                    if batch_size == 1:
-                        logger.info("Skip token_num = 1, when capture Draft model for mtp")
-                    else:
-                        assert batch_size % 2 == 0
-                        self._dummy_run(
-                            num_tokens=(
-                                self.scheduler_config.max_num_seqs
-                                if self.scheduler_config.splitwise_role == "decode"
-                                else self.scheduler_config.max_num_batched_tokens
-                            ),
-                            batch_size=int(batch_size / 2),
-                            in_capturing=True,
-                            expected_decode_len=3,
-                            accept_all_drafts=True,
-                        )
-                        logger.info(
-                            f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                        )
-                # Capture Draft Model with bsz 1
-                if 1 in capture_sizes:
-                    self._dummy_run(
-                        num_tokens=(
-                            self.scheduler_config.max_num_seqs
-                            if self.scheduler_config.splitwise_role == "decode"
-                            else self.scheduler_config.max_num_batched_tokens
-                        ),
-                        batch_size=int(1),
-                        in_capturing=True,
-                        expected_decode_len=3,
-                        accept_all_drafts=False,
-                        reject_all_drafts=True,
-                    )
-                    logger.info(
-                        f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}"
-                    )
         else:
             for batch_size in sorted(capture_sizes, reverse=True):
                 self._dummy_run(
```
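For reference, the warm-up arithmetic introduced by this hunk can be checked in isolation. The sketch below is illustrative only: the helper name and the sample values are hypothetical, and only the two formulas (`capture_size / (num_speculative_tokens + 1)` for the batch size and `num_speculative_tokens * 2 + 1` for the decode length) come from the diff itself.

```python
# Illustrative sketch: reproduces the warm-up arithmetic from the hunk above.
# Only the two formulas are taken from the diff; the helper name and the
# sample values are hypothetical stand-ins for the runner's real config.

def warmup_params(capture_size: int, num_speculative_tokens: int) -> dict:
    """Derive the _dummy_run arguments the new code computes per capture size."""
    return {
        # Mirrors int(capture_size / (num_speculative_tokens + 1)) in the diff.
        "batch_size": int(capture_size / (num_speculative_tokens + 1)),
        # Old code passed num_speculative_tokens here; the new code doubles it
        # and adds one so the single warm-up covers the multi-step mtp path.
        "expected_decode_len": num_speculative_tokens * 2 + 1,
        "accept_all_drafts": True,
    }

if __name__ == "__main__":
    for k in (1, 2):
        print(f"num_speculative_tokens={k}: {warmup_params(capture_size=8, num_speculative_tokens=k)}")
```

For `num_speculative_tokens == 1` the new formula gives `expected_decode_len == 3`, which is the value the removed draft-model capture loop used to hard-code; presumably this is what lets one warm-up pass replace the separate draft-model capture branch.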
Review comment (on lines 1956 to 1957):
The variable 'is_dummy_run' is used but not defined in the _propose method. This will cause a NameError at runtime when _initialize_forward_meta is called. The _propose method signature only includes 'step_use_cudagraph' as a parameter, but 'is_dummy_run' is being passed to _initialize_forward_meta. You need to either add 'is_dummy_run' as a parameter to the _propose method or determine it from existing state/attributes.
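A minimal sketch of the first option the reviewer suggests (adding `is_dummy_run` as a parameter). Only `_propose`, `_initialize_forward_meta`, and `step_use_cudagraph` are named in the comment; the stub class, the default values, and the keyword-argument form of the call are assumptions:

```python
class _ModelRunnerSketch:
    """Hypothetical stand-in for the runner class; only the relevant methods are shown."""

    def _initialize_forward_meta(self, step_use_cudagraph: bool, is_dummy_run: bool) -> None:
        # Stub: the real method builds forward metadata; this just shows the call shape.
        print(f"forward meta: cudagraph={step_use_cudagraph}, dummy_run={is_dummy_run}")

    def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False) -> None:
        # is_dummy_run is now a defined parameter rather than an undefined name,
        # so this call no longer raises NameError at runtime.
        self._initialize_forward_meta(
            step_use_cudagraph=step_use_cudagraph,
            is_dummy_run=is_dummy_run,
        )
        # ... run the draft (mtp) model and return proposed tokens ...
```

With this shape, warm-up callers such as `_dummy_run` would pass `is_dummy_run=True` while the normal decode path keeps the default; the reviewer's second option, reading the flag from existing runner state instead of adding a parameter, would avoid touching every call site.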