diff --git a/vllm/config.py b/vllm/config.py
index e645103557c1..3ed1674b5f36 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2401,7 +2401,8 @@ def __post_init__(self):
                     pass
                 else:
                     eagle_config = EAGLEConfig(
-                        self.draft_model_config.hf_config)
+                        self.draft_model_config.hf_config,
+                        method=self.method)
                     self.draft_model_config.hf_config = eagle_config
 
             if (self.num_speculative_tokens is not None
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 3a9ad3e0ffc8..586d5c7f5e54 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -15,6 +15,7 @@ class EAGLEConfig(PretrainedConfig):
     def __init__(self,
                  model: Union[PretrainedConfig, dict, None] = None,
                  truncated_vocab_size: Optional[int] = None,
+                 method: Optional[str] = 'eagle',
                  **kwargs):
 
         model_config: Union[PretrainedConfig, DeepseekV2Config, None]
@@ -45,7 +46,23 @@ def __init__(self,
 
         if not envs.VLLM_USE_V1:
             kwargs["architectures"] = ["EAGLEModel"]
         else:
-            kwargs["architectures"] = ["EagleLlamaForCausalLM"]
+            # Eagle model name should follow naming convention of
+            # LlamaForCausalLM -> EagleLlamaForCausalLM
+            if method == "eagle":
+                assert self.model is not None, \
+                    "model should not be None when method is eagle"
+                kwargs["architectures"] = [
+                    f"Eagle{arch}" for arch in self.model.architectures
+                ]
+            elif method == "eagle3":
+                assert self.model is not None, \
+                    "model should not be None when method is eagle3"
+                kwargs["architectures"] = [
+                    f"Eagle3{arch}" for arch in self.model.architectures
+                ]
+            else:
+                raise ValueError(f"Invalid method {method}. \
+                    Supported methods are eagle and eagle3.")
 
         super().__init__(**kwargs)
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 1de14584d396..8c45ca9a319f 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -9,8 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM
-from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.model_executor.models import ModelRegistry
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -225,15 +224,11 @@ def load_model(self, target_model: nn.Module) -> None:
         with set_default_torch_dtype(
                 draft_model_config.dtype), set_current_vllm_config(
                     self.vllm_config):
-            if self.vllm_config.speculative_config.method == "eagle":
-                self.model = EagleLlamaForCausalLM(
-                    model_config=draft_model_config,
-                    start_layer_id=target_layer_num).to(target_device)
-            else:
-                assert self.vllm_config.speculative_config.method == "eagle3"
-                self.model = Eagle3LlamaForCausalLM(
-                    model_config=draft_model_config,
-                    start_layer_id=target_layer_num).to(target_device)
+            draft_model_cls, arch = ModelRegistry.resolve_model_cls(
+                draft_model_config.architectures)
+            self.model = draft_model_cls(
+                model_config=draft_model_config,
+                start_layer_id=target_layer_num).to(target_device)
 
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(
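
For reference, a minimal standalone sketch of the convention this diff introduces: `EAGLEConfig` derives the draft architecture name by prefixing the target model's architecture string ("LlamaForCausalLM" -> "EagleLlamaForCausalLM" / "Eagle3LlamaForCausalLM"), which is what lets `v1/spec_decode/eagle.py` resolve the draft class generically through `ModelRegistry.resolve_model_cls` instead of hardcoding the two Llama variants. The helper name and the tiny registry dict below are hypothetical illustrations, not vLLM APIs.

```python
# Illustrative sketch only: derive_draft_architectures and DRAFT_MODEL_REGISTRY
# are hypothetical names; they mirror the naming logic in EAGLEConfig.__init__
# and the registry-based lookup used in v1/spec_decode/eagle.py.
from typing import List


def derive_draft_architectures(method: str,
                               target_archs: List[str]) -> List[str]:
    """Prefix the target architecture per the EAGLE naming convention."""
    if method == "eagle":
        return [f"Eagle{arch}" for arch in target_archs]
    if method == "eagle3":
        return [f"Eagle3{arch}" for arch in target_archs]
    raise ValueError(f"Invalid method {method}. "
                     "Supported methods are eagle and eagle3.")


# Stand-in for ModelRegistry.resolve_model_cls: map architecture names to
# draft model classes (placeholders here instead of the real vLLM classes).
DRAFT_MODEL_REGISTRY = {
    "EagleLlamaForCausalLM": object,
    "Eagle3LlamaForCausalLM": object,
}

archs = derive_draft_architectures("eagle3", ["LlamaForCausalLM"])
assert archs == ["Eagle3LlamaForCausalLM"]
draft_model_cls = DRAFT_MODEL_REGISTRY[archs[0]]
```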