diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index c42f19fee17d..904ff3210943 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -6,7 +6,8 @@
 import torch.nn as nn
 from transformers import LlamaConfig
 
-from vllm.config import ModelConfig, VllmConfig
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -76,17 +77,19 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
         self,
         *,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         start_layer_id: int = 0,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
@@ -119,8 +122,7 @@ def forward(
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
-        if (hidden_states.shape[-1] != input_embeds.shape[-1]):
-            hidden_states = self.fc(hidden_states)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
         hidden_states, residual = self.layers[0](
@@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        model_config = vllm_config.speculative_config.draft_model_config
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
                                 start_layer_id=start_layer_id,
                                 prefix="model")
 
@@ -214,6 +216,13 @@ def compute_logits(
         logits_new[:, targets] = logits
         return logits_new
 
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # combine multiple auxiliary hidden states returned by eagle3
+        return self.model.fc(hidden_states)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 81508c2e069b..07097d7da68f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -10,6 +10,7 @@
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -39,11 +40,9 @@ def __init__(
 
         self.hidden_size = vllm_config.model_config.get_hidden_size()
 
-        # TODO: make eagle3 compatible with cudagraph
-        self.use_cuda_graph = self.method != 'eagle3' and \
-            (self.vllm_config.compilation_config.level
-             == CompilationLevel.PIECEWISE and
-             not self.vllm_config.model_config.enforce_eager)
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+                               == CompilationLevel.PIECEWISE and
+                               not self.vllm_config.model_config.enforce_eager)
 
         self.cudagraph_batch_sizes = list(
             reversed(
@@ -90,6 +89,12 @@ def propose(
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
 
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@@ -126,12 +131,7 @@ def propose(
 
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
-        if self.method == 'eagle':
-            self.hidden_states[:num_tokens] = target_hidden_states
-            hidden_states = self.hidden_states
-        else:
-            # TODO: make eagle3 compatible with cuda graph
-            hidden_states = target_hidden_states
+        self.hidden_states[:num_tokens] = target_hidden_states
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
@@ -139,7 +139,7 @@
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
-                hidden_states=hidden_states[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
             )
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -209,10 +209,7 @@
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
 
-            if self.method == 'eagle':
-                # TODO: make eagle3 compatible with cudagraph.
-                self.hidden_states[:batch_size] = hidden_states
-                hidden_states = self.hidden_states
+            self.hidden_states[:batch_size] = hidden_states
 
             # Run the model.
             with set_forward_context(attn_metadata,
@@ -221,7 +218,7 @@
                 last_hidden_states, hidden_states = self.model(
                     input_ids=self.input_ids[:input_batch_size],
                     positions=self.positions[:input_batch_size],
-                    hidden_states=hidden_states[:input_batch_size],
+                    hidden_states=self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -314,12 +311,11 @@ def dummy_run(
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
-            if self.method == 'eagle':
-                self.model(
-                    input_ids=self.input_ids[:num_tokens],
-                    positions=self.positions[:num_tokens],
-                    hidden_states=self.hidden_states[:num_tokens],
-                )
+            self.model(
+                input_ids=self.input_ids[:num_tokens],
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+            )
 
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
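
For readers following the llama_eagle3.py change: the `fc` projection that used to run conditionally inside `LlamaModel.forward` is now exposed as `combine_hidden_states`, and the proposer applies it once up front, so the draft model's forward always sees inputs of a fixed hidden size. A minimal self-contained sketch of that idea follows; the sizes, the 3-layer tap count, and the `DraftSketch` class are illustrative assumptions, not code from this diff.

# Sketch: why propose() can now assert a fixed hidden size for eagle3.
# All names and sizes here are illustrative assumptions, not vLLM API.
import torch
import torch.nn as nn

HIDDEN_SIZE = 64      # draft model hidden size (hypothetical)
NUM_AUX_LAYERS = 3    # eagle3 taps hidden states from several target layers

class DraftSketch(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Plays the role of self.model.fc: project the concatenated
        # auxiliary hidden states down to the draft hidden size.
        self.fc = nn.Linear(NUM_AUX_LAYERS * HIDDEN_SIZE, HIDDEN_SIZE,
                            bias=False)

    def combine_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # [num_tokens, NUM_AUX_LAYERS * HIDDEN_SIZE] -> [num_tokens, HIDDEN_SIZE]
        return self.fc(hidden_states)

num_tokens = 8
aux = torch.cat([torch.randn(num_tokens, HIDDEN_SIZE)
                 for _ in range(NUM_AUX_LAYERS)], dim=-1)
combined = DraftSketch().combine_hidden_states(aux)
assert combined.shape[-1] == HIDDEN_SIZE  # mirrors the assert in propose()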
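
The eagle.py edits all serve one requirement of cudagraph capture/replay: a replayed graph reads its inputs from fixed device addresses, so per-step data must be staged into preallocated buffers rather than passed in as fresh tensors. With the width mismatch handled by `combine_hidden_states`, eagle3 can share the same `self.hidden_states` buffer as eagle, and the method-specific branches disappear. Below is a schematic sketch of that staging pattern, with hypothetical buffer sizes and a stand-in model callable; it is not the EagleProposer implementation itself.

# Sketch of the persistent-buffer pattern used for cudagraph compatibility.
# Buffer sizes and the `model` callable are hypothetical stand-ins.
import torch

MAX_NUM_TOKENS = 16
HIDDEN_SIZE = 64

# Allocated once, analogous to self.input_ids / self.positions /
# self.hidden_states on the proposer.
input_ids = torch.zeros(MAX_NUM_TOKENS, dtype=torch.long)
positions = torch.zeros(MAX_NUM_TOKENS, dtype=torch.long)
hidden_states = torch.zeros(MAX_NUM_TOKENS, HIDDEN_SIZE)

def model(ids, pos, hs):
    # Stand-in for the draft model forward.
    return hs

def run_step(step_token_ids, step_positions, step_hidden_states):
    num_tokens = step_token_ids.shape[0]
    # Stage the varying inputs *into* the fixed buffers; a captured cuda
    # graph replays against these same addresses every step.
    input_ids[:num_tokens] = step_token_ids
    positions[:num_tokens] = step_positions
    hidden_states[:num_tokens] = step_hidden_states
    # Always pass slices of the persistent buffers, never the fresh tensors,
    # matching the now-unconditional copies in propose() and dummy_run().
    return model(input_ids[:num_tokens],
                 positions[:num_tokens],
                 hidden_states[:num_tokens])

out = run_step(torch.arange(4), torch.arange(4), torch.randn(4, HIDDEN_SIZE))
assert out.shape == (4, HIDDEN_SIZE)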