39 changes: 27 additions & 12 deletions examples/offline_inference/eagle.py
@@ -48,7 +48,8 @@ def main():
args = parser.parse_args()

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
# eagle_dir = "yuhuili/EAGLE-LLaMA3-Instruct-8B"
eagle_dir = "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B"
max_model_len = 2048

@@ -86,22 +87,36 @@ def main():

sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)

outputs = llm.generate(prompt_token_ids=prompt_ids,
outputs, scheduler_stats = llm.generate(prompt_token_ids=prompt_ids,
sampling_params=sampling_params)

# calculate the average number of accepted tokens per forward pass, +1 is
# to account for the token from the target model that's always going to be
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
for step, count in enumerate(
output.metrics.spec_token_acceptance_counts):
acceptance_counts[step] += count

print("-" * 50)
print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
print("-" * 50)
# import pdb; pdb.set_trace() # REMOVE
if scheduler_stats is None:
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
for step, count in enumerate(
output.metrics.spec_token_acceptance_counts):
acceptance_counts[step] += count

print("-" * 50)
print(f"mean acceptance length: \
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
print("-" * 50)
elif scheduler_stats.spec_decoding_stats is not None:
num_draft_tokens = scheduler_stats.spec_decoding_stats.num_draft_tokens
num_accepted_tokens = scheduler_stats.spec_decoding_stats.num_accepted_tokens
num_spec_proposal = num_draft_tokens / args.num_spec_tokens
mean_accepted_tokens = 1 + num_accepted_tokens / num_spec_proposal
Comment on lines +109 to +112
@ekagra-ranjan (Owner, Author) commented on Apr 9, 2025:
num_spec_proposal is the number of times a speculative-decoding (SD) call was made.

mean_accepted_tokens = (total tokens generated across all SD calls) / num_spec_proposal
                     = (num_spec_proposal + total accepted draft tokens) / num_spec_proposal
                     = 1 + num_accepted_tokens / num_spec_proposal
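For reference, the same arithmetic as a standalone Python sketch; the counts below are made-up illustration values, and only the formula mirrors the example code above.

# Illustration of the acceptance-length arithmetic (hypothetical numbers).
num_spec_tokens = 3        # draft tokens proposed per SD call (args.num_spec_tokens)
num_draft_tokens = 300     # total draft tokens proposed over the whole run
num_accepted_tokens = 180  # total draft tokens accepted by the target model

# Every SD call proposes exactly num_spec_tokens drafts:
num_spec_proposal = num_draft_tokens / num_spec_tokens  # 100 SD calls

# Each SD call always yields one target-model token (the "+1"), plus the
# drafts accepted on that call:
mean_accepted_tokens = 1 + num_accepted_tokens / num_spec_proposal  # 1 + 180 / 100 = 2.8
print(f"mean acceptance length: {mean_accepted_tokens:.2f}")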

print("-" * 50)
print(f"mean acceptance length: {mean_accepted_tokens:.2f}, \
num_draft_tokens: {num_draft_tokens}, \
num_accepted_tokens: {num_accepted_tokens} \
num_spec_proposal: {num_spec_proposal}")
print("-" * 50)



if __name__ == "__main__":
10 changes: 7 additions & 3 deletions vllm/engine/llm_engine.py
@@ -8,7 +8,7 @@
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Literal, Mapping, NamedTuple, Optional)
Iterable, List, Literal, Mapping, NamedTuple, Optional, Tuple)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload

@@ -1289,7 +1289,11 @@ def _advance_to_next_step(
else:
seq.append_token_id(sample.output_token, sample.logprobs)

def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
def step(self) -> Tuple[List[Union[RequestOutput, PoolingRequestOutput]],
# for compatibility with the V1
# step API, which returns SchedulerStats
None
]:
"""Performs one decoding iteration and returns newly generated results.

.. figure:: https://i.imgur.com/sv2HssD.png
@@ -1516,7 +1520,7 @@ def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()

return ctx.request_outputs
return ctx.request_outputs, None

def _abort_and_cache_schedule(
self, request_id: str, virtual_engine: int,
19 changes: 10 additions & 9 deletions vllm/entrypoints/llm.py
@@ -4,7 +4,7 @@
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload, Tuple

import cloudpickle
import torch.nn as nn
@@ -44,6 +44,7 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
is_list_of)
from vllm.v1.metrics.stats import SchedulerStats

logger = init_logger(__name__)

@@ -389,7 +390,7 @@ def generate(
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None,
priority: Optional[list[int]] = None,
) -> list[RequestOutput]:
) -> Tuple[list[RequestOutput], Optional[SchedulerStats]]:
"""Generates the completions for the input prompts.

This class automatically batches the given prompts, considering
@@ -467,8 +468,8 @@ def generate(
guided_options=guided_options_request,
priority=priority)

outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)
outputs, scheduler_stats = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput), scheduler_stats
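With this change, callers unpack a 2-tuple from LLM.generate. A minimal caller-side sketch, under the assumption that scheduler_stats is None on the V0 engine path and populated on V1 when log_stats is enabled:

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompts = ["The capital of France is"]
# New return shape from this PR: (request outputs, scheduler stats or None).
outputs, scheduler_stats = llm.generate(prompts, SamplingParams(max_tokens=32))
for out in outputs:
    print(out.outputs[0].text)
if scheduler_stats is not None and scheduler_stats.spec_decoding_stats is not None:
    sd = scheduler_stats.spec_decoding_stats
    print(sd.num_draft_tokens, sd.num_accepted_tokens)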

def collective_rpc(self,
method: Union[str, Callable[..., _R]],
@@ -929,7 +930,7 @@ def encode(
prompt_adapter_request=prompt_adapter_request,
)

outputs = self._run_engine(use_tqdm=use_tqdm)
outputs, _ = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)

@@ -1093,7 +1094,7 @@ def _cross_encoding_score(
prompt_adapter_request=prompt_adapter_request,
)

outputs = self._run_engine(use_tqdm=use_tqdm)
outputs, _ = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)

@@ -1379,7 +1380,7 @@ def _add_guided_params(

def _run_engine(
self, *, use_tqdm: bool
) -> list[Union[RequestOutput, PoolingRequestOutput]]:
) -> Tuple[list[Union[RequestOutput, PoolingRequestOutput]], Optional[SchedulerStats]]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
@@ -1396,7 +1397,7 @@ def _run_engine(
total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
step_outputs, scheduler_stats = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
@@ -1423,4 +1424,4 @@ def _run_engine(
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
return sorted(outputs, key=lambda x: int(x.request_id)), scheduler_stats
5 changes: 4 additions & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -112,6 +112,8 @@ def __init__(
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)

self.spec_decoding_stats = SpecDecodingStats()

def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -568,7 +570,8 @@ def update_from_output(

new_running: list[Request] = []
outputs: list[EngineCoreOutput] = []
spec_decoding_stats: Optional[SpecDecodingStats] = None
# spec_decoding_stats: Optional[SpecDecodingStats] = None
spec_decoding_stats = self.spec_decoding_stats
Comment on lines +573 to +574
@ekagra-ranjan (Owner, Author):
Cache the spec_decoding_stats so that it keeps a running metric instead of being reinitialized on every engine step.
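A minimal sketch of that running-metric pattern, using a simplified stand-in class rather than the real vllm SpecDecodingStats (its exact fields and update method are assumptions here, beyond the two counters the example script reads):

from dataclasses import dataclass

@dataclass
class RunningSpecDecodingStats:
    # Simplified stand-in for vllm's SpecDecodingStats; only the two counters
    # read by the eagle.py example are modeled here.
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0

    def observe(self, num_draft: int, num_accepted: int) -> None:
        # Accumulate across engine steps instead of re-initializing per step.
        self.num_draft_tokens += num_draft
        self.num_accepted_tokens += num_accepted

stats = RunningSpecDecodingStats()          # created once in Scheduler.__init__
stats.observe(num_draft=3, num_accepted=2)  # engine step 1
stats.observe(num_draft=3, num_accepted=1)  # engine step 2
print(stats.num_draft_tokens, stats.num_accepted_tokens)  # -> 6 3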


# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
10 changes: 6 additions & 4 deletions vllm/v1/engine/llm_engine.py
@@ -2,7 +2,7 @@

from collections.abc import Mapping
from copy import copy
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, Tuple

from typing_extensions import TypeVar

@@ -28,6 +28,7 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.stats import SchedulerStats

logger = init_logger(__name__)

@@ -92,7 +93,7 @@ def __init__(
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False, # FIXME: implement
log_stats=True,
)

if not multiprocess_mode:
@@ -211,14 +212,15 @@ def add_request(
# Add the request to EngineCore.
self.engine_core.add_request(child_request)

def step(self) -> list[RequestOutput]:
def step(self) -> Tuple[list[RequestOutput], Optional[SchedulerStats]]:

if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
self.engine_core.execute_dummy_batch()
return [], None

# 1) Get EngineCoreOutput from the EngineCore.
# import pdb; pdb.set_trace() # REMOVE
outputs = self.engine_core.get_output()

# 2) Process EngineCoreOutputs.
@@ -228,7 +230,7 @@ def step(
# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

return processed_outputs.request_outputs
return processed_outputs.request_outputs, outputs.scheduler_stats

def get_model_config(self):
return self.model_config