diff --git a/docs/source/features/spec_decode.md b/docs/source/features/spec_decode.md
index cc8d6fceb7d6..852248e418ca 100644
--- a/docs/source/features/spec_decode.md
+++ b/docs/source/features/spec_decode.md
@@ -162,7 +162,7 @@ A variety of speculative models of this type are available on HF hub:
 ## Speculating using EAGLE based draft models
 
 The following code configures vLLM to use speculative decoding where proposals are generated by
-an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model.
+an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract the request-level acceptance rate, can be found [here]().
 
 ```python
 from vllm import LLM, SamplingParams
diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
new file mode 100644
index 000000000000..baa91b2d0364
--- /dev/null
+++ b/examples/offline_inference/eagle.py
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import json
+import os
+
+from transformers import AutoTokenizer
+
+from vllm import LLM, SamplingParams
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+    "--dataset",
+    type=str,
+    default="./examples/data/gsm8k.jsonl",
+    help="downloaded from the eagle repo " \
+    "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
+)
+parser.add_argument("--max_num_seqs", type=int, default=8)
+parser.add_argument("--num_prompts", type=int, default=80)
+parser.add_argument("--num_spec_tokens", type=int, default=2)
+parser.add_argument("--tp", type=int, default=1)
+parser.add_argument("--draft_tp", type=int, default=1)
+parser.add_argument("--enforce_eager", action='store_true')
+parser.add_argument("--enable_chunked_prefill", action='store_true')
+parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
+parser.add_argument("--temp", type=float, default=0)
+
+args = parser.parse_args()
+
+print(args)
+
+model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
+
+max_model_len = 2048
+
+tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+if os.path.exists(args.dataset):
+    prompts = []
+    num_prompts = args.num_prompts
+    with open(args.dataset) as f:
+        for line in f:
+            data = json.loads(line)
+            prompts.append(data["turns"][0])
+else:
+    prompts = ["The future of AI is", "The president of the United States is"]
+
+prompts = prompts[:args.num_prompts]
+num_prompts = len(prompts)
+
+prompt_ids = [
+    tokenizer.apply_chat_template([{
+        "role": "user",
+        "content": prompt
+    }],
+                                  add_generation_prompt=True)
+    for prompt in prompts
+]
+
+llm = LLM(
+    model=model_dir,
+    trust_remote_code=True,
+    tensor_parallel_size=args.tp,
+    enable_chunked_prefill=args.enable_chunked_prefill,
+    max_num_batched_tokens=args.max_num_batched_tokens,
+    enforce_eager=args.enforce_eager,
+    max_model_len=max_model_len,
+    max_num_seqs=args.max_num_seqs,
+    gpu_memory_utilization=0.8,
+    speculative_model=eagle_dir,
+    num_speculative_tokens=args.num_spec_tokens,
+    speculative_draft_tensor_parallel_size=args.draft_tp,
+    speculative_max_model_len=max_model_len,
+    disable_log_stats=False,
+)
+
+sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
+
+outputs = llm.generate(prompt_token_ids=prompt_ids,
+                       sampling_params=sampling_params)
+
+# calculate the average number of accepted tokens per forward pass, +1 is
+# to account for the token from the target model that's always going to be
+# accepted
+acceptance_counts = [0] * (args.num_spec_tokens + 1)
+for output in outputs:
+    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
+        acceptance_counts[step] += count
+
+print(f"mean acceptance length: \
+    {sum(acceptance_counts) / acceptance_counts[0]:.2f}")
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 783275ab41d2..6517e78b86ee 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -830,6 +830,10 @@ def _create_sequence_group_with_sampling(
             self.generation_config_fields, seq.eos_token_id)
 
         # Create the sequence group.
+        draft_size = 1
+        if self.vllm_config.speculative_config is not None:
+            draft_size = \
+                self.vllm_config.speculative_config.num_speculative_tokens + 1
         seq_group = SequenceGroup(
             request_id=request_id,
             seqs=[seq],
@@ -839,7 +843,8 @@ def _create_sequence_group_with_sampling(
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
             encoder_seq=encoder_seq,
-            priority=priority)
+            priority=priority,
+            draft_size=draft_size)
 
         return seq_group
 
diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py
index 8ceef855e020..4c5d78a43df6 100644
--- a/vllm/engine/output_processor/multi_step.py
+++ b/vllm/engine/output_processor/multi_step.py
@@ -100,6 +100,11 @@ def process_outputs(self,
            seqs = sequence_group.get_seqs(
                status=SequenceStatus.FINISHED_ABORTED)
 
+        for output in outputs:
+            if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
+                sequence_group.metrics.spec_token_acceptance_counts[
+                    output.step_index] += 1
+
         assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
         assert len(seqs) == 1, (
             "Beam search not supported in multi-step decoding.")
diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index f2a2935e6c69..010e51a3b9f2 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -38,7 +38,7 @@ def forward(self, x, residual):
         if residual is None:
             return x
         else:
-            return x, residual
+            return x + residual, None
 
 
 class EAGLE(nn.Module):
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 6a7b1e62a604..61867b025315 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -111,6 +111,13 @@ class RequestMetrics:
         model_execute_time: The time spent in the model execute function.
            This will include model forward, block/sync across workers,
            cpu-gpu sync time and sampling time.
+        spec_token_acceptance_counts: number of accepted speculative tokens at
+                                      each position; the first token is from
+                                      the target model and is always accepted;
+                                      e.g., when it's [10, 8, 4, 2] for a req,
+                                      it means there were 10 forward passes in
+                                      total, and there were 8, 4, 2 accepted
+                                      tokens at 1st, 2nd, 3rd speculation step.
    """
    arrival_time: float
    last_token_time: float
@@ -121,6 +128,7 @@ class RequestMetrics:
    scheduler_time: Optional[float] = None
    model_forward_time: Optional[float] = None
    model_execute_time: Optional[float] = None
+    spec_token_acceptance_counts: Optional[list[int]] = None
 
 
 class SequenceDataDelta(
@@ -639,22 +647,25 @@ class SequenceGroup:
         trace_headers: OpenTelemetry trace headers.
         prompt_adapter_request: Prompt Adapter request.
         priority: User-defined priority of the request.
+        draft_size: The number of speculative tokens plus one from the target
+                    model; equal to max number of tokens a step can generate
+                    for single-draft speculative decoding but larger than
+                    that for multi-draft SD (currently not supported).
    """
 
-    def __init__(
-        self,
-        request_id: str,
-        seqs: list[Sequence],
-        arrival_time: float,
-        sampling_params: Optional[SamplingParams] = None,
-        lora_request: Optional[LoRARequest] = None,
-        pooling_params: Optional[PoolingParams] = None,
-        pooled_data: Optional[torch.Tensor] = None,
-        encoder_seq: Optional[Sequence] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-        priority: int = 0,
-    ) -> None:
+    def __init__(self,
+                 request_id: str,
+                 seqs: list[Sequence],
+                 arrival_time: float,
+                 sampling_params: Optional[SamplingParams] = None,
+                 lora_request: Optional[LoRARequest] = None,
+                 pooling_params: Optional[PoolingParams] = None,
+                 pooled_data: Optional[torch.Tensor] = None,
+                 encoder_seq: Optional[Sequence] = None,
+                 trace_headers: Optional[Mapping[str, str]] = None,
+                 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+                 priority: int = 0,
+                 draft_size: int = 1) -> None:
        self.request_id = request_id
        self.seqs = seqs
        self.first_seq = seqs[0]
@@ -667,7 +678,9 @@ def __init__(
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
-                                      time_in_queue=None)
+                                      time_in_queue=None,
+                                      spec_token_acceptance_counts=[0] *
+                                      draft_size)
        self.last_token_latency = 0.0
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
@@ -1079,6 +1092,7 @@ class CompletionSequenceGroupOutput(
    samples: list[SequenceOutput]
    # Prompt logprob for each prompt query token.
    prompt_logprobs: Optional[PromptLogprobs]
+    step_index: Optional[int] = 0
 
    def __repr__(self) -> str:
        return (f"CompletionSequenceGroupOutput(samples={self.samples}, "
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 8909a41bc99f..5bf4f67d35bd 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -1080,7 +1080,7 @@ def _create_output_sampler_list(
                        [sequence_index][:num_logprobs],
                        topk_logprobs=topk_logprobs_by_step[step_index]
                        [sequence_index][:num_logprobs],
-                    ))
+                        step_index=step_index))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))
 
diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py
index 9c04680a6a7a..466269b2107f 100644
--- a/vllm/spec_decode/util.py
+++ b/vllm/spec_decode/util.py
@@ -93,14 +93,14 @@ def create_logprobs_output(
 
 
 def create_sequence_group_output(
-    token_id: int,
-    token_id_logprob_rank: int,
-    token_id_logprob: float,
-    seq_id: SeqId,
-    topk_token_ids: List[Optional[int]],
-    topk_logprobs: List[Optional[float]],
-    prompt_logprobs: Optional[PromptLogprobs] = None,
-) -> CompletionSequenceGroupOutput:
+        token_id: int,
+        token_id_logprob_rank: int,
+        token_id_logprob: float,
+        seq_id: SeqId,
+        topk_token_ids: List[Optional[int]],
+        topk_logprobs: List[Optional[float]],
+        prompt_logprobs: Optional[PromptLogprobs] = None,
+        step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput:
    """Create a SequenceGroupOutput given the sampling results.
 
    Args:
@@ -110,6 +110,7 @@ def create_sequence_group_output(
        seq_id (int): The sequence id.
        topk_token_ids (List[Optional[int]]): The list of top-k token ids.
        topk_logprobs (List[Optional[float]]): The list of top-k logprobs.
+        step_index (Optional[int]): The index of the speculative token.
    """
 
    logprobs = create_logprobs_output(
@@ -120,14 +121,13 @@ def create_sequence_group_output(
        topk_logprobs,
    )
 
-    return CompletionSequenceGroupOutput(
-        samples=[
-            SequenceOutput(parent_seq_id=seq_id,
-                           output_token=token_id,
-                           logprobs=logprobs)
-        ],
-        prompt_logprobs=prompt_logprobs,
-    )
+    return CompletionSequenceGroupOutput(samples=[
+        SequenceOutput(parent_seq_id=seq_id,
+                       output_token=token_id,
+                       logprobs=logprobs)
+    ],
+                                         prompt_logprobs=prompt_logprobs,
+                                         step_index=step_index)
 
 
 def split_batch_by_proposal_len(
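As a companion to the example script in this patch, the sketch below shows how the aggregated `spec_token_acceptance_counts` could be post-processed into per-position acceptance rates and a mean acceptance length. It is a minimal sketch, not part of the diff: the helper name `summarize_acceptance` is hypothetical, and the input is assumed to be the `acceptance_counts` list built in `examples/offline_inference/eagle.py` by summing `output.metrics.spec_token_acceptance_counts` over all requests.

```python
# Hypothetical post-processing helper (not part of this PR): summarizes the
# aggregated spec_token_acceptance_counts produced by the example script.
from typing import Sequence


def summarize_acceptance(acceptance_counts: Sequence[int]) -> dict:
    """Derive acceptance statistics from aggregated acceptance counts.

    acceptance_counts[0] counts the target-model token emitted on every
    forward pass (always accepted), so it equals the number of forward
    passes; acceptance_counts[k] for k >= 1 counts how many k-th
    speculative tokens were accepted across those passes.
    """
    num_forward_passes = acceptance_counts[0]
    # Fraction of forward passes in which the k-th speculative token
    # was accepted.
    per_position_rate = [
        count / num_forward_passes for count in acceptance_counts[1:]
    ]
    # Average number of tokens emitted per forward pass (target token plus
    # accepted speculative tokens); matches the "mean acceptance length"
    # printed by the example script.
    mean_acceptance_length = sum(acceptance_counts) / num_forward_passes
    return {
        "num_forward_passes": num_forward_passes,
        "per_position_acceptance_rate": per_position_rate,
        "mean_acceptance_length": mean_acceptance_length,
    }


if __name__ == "__main__":
    # The example from the RequestMetrics docstring: 10 forward passes with
    # 8, 4, and 2 accepted tokens at speculation steps 1, 2, and 3.
    print(summarize_acceptance([10, 8, 4, 2]))
    # {'num_forward_passes': 10,
    #  'per_position_acceptance_rate': [0.8, 0.4, 0.2],
    #  'mean_acceptance_length': 2.4}
```

Dividing by `acceptance_counts[0]` works because position 0 counts the bonus token from the target model, which is emitted on every forward pass, so it equals the number of forward passes covered by the aggregation.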