-
Notifications
You must be signed in to change notification settings - Fork 0
Repurpose Scheduler Spec Dec metric for testing correctness #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,7 +48,8 @@ def main(): | |
| args = parser.parse_args() | ||
|
|
||
| model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" | ||
| eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm" | ||
| # eagle_dir = "yuhuili/EAGLE-LLaMA3-Instruct-8B" | ||
| eagle_dir = "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B" | ||
|
|
||
| max_model_len = 2048 | ||
|
|
||
|
|
@@ -86,22 +87,36 @@ def main(): | |
|
|
||
| sampling_params = SamplingParams(temperature=args.temp, max_tokens=256) | ||
|
|
||
| outputs = llm.generate(prompt_token_ids=prompt_ids, | ||
| outputs, scheduler_stats = llm.generate(prompt_token_ids=prompt_ids, | ||
| sampling_params=sampling_params) | ||
|
|
||
| # calculate the average number of accepted tokens per forward pass, +1 is | ||
| # to account for the token from the target model that's always going to be | ||
| # accepted | ||
| acceptance_counts = [0] * (args.num_spec_tokens + 1) | ||
| for output in outputs: | ||
| for step, count in enumerate( | ||
| output.metrics.spec_token_acceptance_counts): | ||
| acceptance_counts[step] += count | ||
|
|
||
| print("-" * 50) | ||
| print(f"mean acceptance length: \ | ||
| {sum(acceptance_counts) / acceptance_counts[0]:.2f}") | ||
| print("-" * 50) | ||
| # import pdb; pdb.set_trace() # REMOVE | ||
| if scheduler_stats is None: | ||
| acceptance_counts = [0] * (args.num_spec_tokens + 1) | ||
| for output in outputs: | ||
| for step, count in enumerate( | ||
| output.metrics.spec_token_acceptance_counts): | ||
| acceptance_counts[step] += count | ||
|
|
||
| print("-" * 50) | ||
| print(f"mean acceptance length: \ | ||
| {sum(acceptance_counts) / acceptance_counts[0]:.2f}") | ||
| print("-" * 50) | ||
| elif scheduler_stats.spec_decoding_stats is not None: | ||
| num_draft_tokens = scheduler_stats.spec_decoding_stats.num_draft_tokens | ||
| num_accepted_tokens = scheduler_stats.spec_decoding_stats.num_accepted_tokens | ||
| num_spec_proposal = num_draft_tokens / args.num_spec_tokens | ||
| mean_accepted_tokens = 1 + num_accepted_tokens / num_spec_proposal | ||
|
Comment on lines
+109
to
+112
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
mean_accepted_tokens = |
||
| print("-" * 50) | ||
| print(f"mean acceptance length: {mean_accepted_tokens:.2f}, \ | ||
| num_draft_tokens: {num_draft_tokens}, \ | ||
| num_accepted_tokens: {num_accepted_tokens} \ | ||
| num_spec_proposal: {num_spec_proposal}") | ||
| print("-" * 50) | ||
|
|
||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -112,6 +112,8 @@ def __init__( | |
| self.encoder_cache_manager = EncoderCacheManager( | ||
| cache_size=encoder_cache_size) | ||
|
|
||
| self.spec_decoding_stats = SpecDecodingStats() | ||
|
|
||
| def schedule(self) -> SchedulerOutput: | ||
| # NOTE(woosuk) on the scheduling algorithm: | ||
| # There's no "decoding phase" nor "prefill phase" in the scheduler. | ||
|
|
@@ -568,7 +570,8 @@ def update_from_output( | |
|
|
||
| new_running: list[Request] = [] | ||
| outputs: list[EngineCoreOutput] = [] | ||
| spec_decoding_stats: Optional[SpecDecodingStats] = None | ||
| # spec_decoding_stats: Optional[SpecDecodingStats] = None | ||
| spec_decoding_stats = self.spec_decoding_stats | ||
|
Comment on lines
+573
to
+574
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cache the |
||
|
|
||
| # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below | ||
| # loop can be a performance bottleneck. We should do our best to avoid | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
using sglang model so that the prev SGL bench is comparable: https://docs.google.com/document/d/18ETJLsnxR88Qq3VDk5Mq-Hb7vuE9o3VNZ-hhz-OqAXk/edit?usp=sharing