From e8b1ab60f29718a5aa3791170ef964a995dec3a9 Mon Sep 17 00:00:00 2001
From: qizixi
Date: Wed, 30 Apr 2025 14:21:49 -0700
Subject: [PATCH 1/3] Apply torch.compile & cudagraph to EAGLE3

Signed-off-by: qizixi
---
 vllm/model_executor/models/llama_eagle3.py | 24 +++++++++++------
 vllm/v1/spec_decode/eagle.py               | 30 +++++++++++-----------
 2 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index c42f19fee17d..8c7d98f8618b 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -6,6 +6,7 @@
 import torch.nn as nn
 from transformers import LlamaConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -75,18 +76,19 @@ def forward(
 
         return hidden_states, residual
 
-
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
         self,
         *,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         start_layer_id: int = 0,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
@@ -119,8 +121,7 @@ def forward(
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
-        if (hidden_states.shape[-1] != input_embeds.shape[-1]):
-            hidden_states = self.fc(hidden_states)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
         hidden_states, residual = self.layers[0](
@@ -169,9 +170,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        model_config = vllm_config.speculative_config.draft_model_config
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
                                 start_layer_id=start_layer_id,
                                 prefix="model")
 
@@ -214,6 +215,13 @@ def compute_logits(
         logits_new[:, targets] = logits
         return logits_new
 
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # combine multiple auxiliary hidden states returned by eagle3
+        return self.model.fc(hidden_states)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 81508c2e069b..e81507498caf 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -9,6 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.model_executor.models import ModelRegistry
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -39,11 +40,10 @@ def __init__(
 
         self.hidden_size = vllm_config.model_config.get_hidden_size()
 
-        # TODO: make eagle3 compatible with cudagraph
-        self.use_cuda_graph = self.method != 'eagle3' and \
-            (self.vllm_config.compilation_config.level
-             == CompilationLevel.PIECEWISE and
-             not self.vllm_config.model_config.enforce_eager)
+        self.use_cuda_graph = (
+            self.vllm_config.compilation_config.level
+            == CompilationLevel.PIECEWISE and
+            not self.vllm_config.model_config.enforce_eager)
 
         self.cudagraph_batch_sizes = list(
             reversed(
@@ -90,6 +90,12 @@ def propose(
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
 
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@@ -126,12 +132,8 @@ def propose(
 
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
-        if self.method == 'eagle':
-            self.hidden_states[:num_tokens] = target_hidden_states
-            hidden_states = self.hidden_states
-        else:
-            # TODO: make eagle3 compatible with cuda graph
-            hidden_states = target_hidden_states
+        self.hidden_states[:num_tokens] = target_hidden_states
+        hidden_states = self.hidden_states
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
@@ -209,10 +211,8 @@ def propose(
 
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
-            if self.method == 'eagle':
-                # TODO: make eagle3 compatible with cudagraph.
-                self.hidden_states[:batch_size] = hidden_states
-                hidden_states = self.hidden_states
+            self.hidden_states[:batch_size] = hidden_states
+            hidden_states = self.hidden_states
 
             # Run the model.
             with set_forward_context(attn_metadata,

From 99fc9c9bb7ef658b09f88cb69b011a41d9f69780 Mon Sep 17 00:00:00 2001
From: qizixi
Date: Wed, 30 Apr 2025 14:21:49 -0700
Subject: [PATCH 2/3] Apply torch.compile & cudagraph to EAGLE3

Signed-off-by: qizixi
---
 vllm/model_executor/models/llama_eagle3.py |  3 ++-
 vllm/v1/spec_decode/eagle.py               | 15 ++++++---------
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 8c7d98f8618b..904ff3210943 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -7,7 +7,7 @@
 from transformers import LlamaConfig
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import ModelConfig, VllmConfig
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -76,6 +76,7 @@ def forward(
 
         return hidden_states, residual
 
+
 @support_torch_compile
 class LlamaModel(nn.Module):
 
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index e81507498caf..cc83d39385da 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -9,8 +9,8 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -40,10 +40,9 @@ def __init__(
 
         self.hidden_size = vllm_config.model_config.get_hidden_size()
 
-        self.use_cuda_graph = (
-            self.vllm_config.compilation_config.level
-            == CompilationLevel.PIECEWISE and
-            not self.vllm_config.model_config.enforce_eager)
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+                               == CompilationLevel.PIECEWISE and
+                               not self.vllm_config.model_config.enforce_eager)
 
         self.cudagraph_batch_sizes = list(
             reversed(
@@ -133,7 +132,6 @@ def propose(
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
-        hidden_states = self.hidden_states
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
@@ -141,7 +139,7 @@ def propose(
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
-                hidden_states=hidden_states[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
             )
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -212,7 +210,6 @@ def propose(
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
-            hidden_states = self.hidden_states
 
             # Run the model.
             with set_forward_context(attn_metadata,
@@ -221,7 +218,7 @@ def propose(
                 last_hidden_states, hidden_states = self.model(
                     input_ids=self.input_ids[:input_batch_size],
                     positions=self.positions[:input_batch_size],
-                    hidden_states=hidden_states[:input_batch_size],
+                    hidden_states=self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],

From 0fccfe53da3ae517d7bcc18f5282ea811f58b1d3 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 1 May 2025 12:58:38 -0700
Subject: [PATCH 3/3] fix

Signed-off-by: Woosuk Kwon
---
 vllm/v1/spec_decode/eagle.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index cc83d39385da..07097d7da68f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -311,12 +311,11 @@ def dummy_run(
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
-            if self.method == 'eagle':
-                self.model(
-                    input_ids=self.input_ids[:num_tokens],
-                    positions=self.positions[:num_tokens],
-                    hidden_states=self.hidden_states[:num_tokens],
-                )
+            self.model(
+                input_ids=self.input_ids[:num_tokens],
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+            )
 
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
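
The draft-model call pattern the series establishes is sketched below. This is a condensed, illustrative sketch, not code from the patch: the helper name run_draft_forward is invented here, and the real EagleProposer.propose() also shifts the input ids, builds attention metadata, pads to the captured cudagraph batch sizes, and samples the draft tokens. All attribute and method names follow the diff above.

    # Sketch only: a simplified view of how the patched EagleProposer drives
    # the compiled EAGLE3 draft model. Names mirror the diff; this helper is
    # not part of vLLM.
    import torch

    def run_draft_forward(proposer, target_positions: torch.Tensor,
                          target_hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens = target_positions.shape[0]

        # EAGLE3 hands back concatenated auxiliary hidden states from the
        # target model; combine_hidden_states() projects them to the draft
        # hidden size up front, so the compiled forward always sees a fixed
        # last dimension (the old in-forward fc branch is gone).
        if proposer.method == "eagle3":
            target_hidden_states = proposer.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == proposer.hidden_size

        # Copy inputs into the persistent buffers: a captured CUDA graph
        # replays fixed tensor addresses, so the model must always read from
        # these buffers instead of per-call tensors.
        proposer.positions[:num_tokens] = target_positions
        proposer.hidden_states[:num_tokens] = target_hidden_states

        # LlamaModel is now decorated with @support_torch_compile, so this
        # forward runs through torch.compile and, with PIECEWISE compilation
        # and enforce_eager disabled, through the captured CUDA graph.
        last_hidden_states, _ = proposer.model(
            input_ids=proposer.input_ids[:num_tokens],
            positions=proposer.positions[:num_tokens],
            hidden_states=proposer.hidden_states[:num_tokens],
        )
        return last_hidden_states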