Commit 98caed8

rahul-tuli authored and devpatelio committed
Add: Support for multiple hidden layers in Eagle3 (vllm-project#26164)
Signed-off-by: Rahul Tuli <[email protected]>
1 parent 63631e4 commit 98caed8

File tree: 2 files changed (+29 −13 lines)


tests/speculative_decoding/speculators/test_eagle3.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@
             "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
             id="qwen3-eagle3-speculator-w4a16-verifier",
         ),
+        pytest.param(
+            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
+            id="llama3-eagl3-multiple-layers",
+        ),
     ],
 )
 def test_eagle3_speculators_model(
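
The test body itself is not part of this hunk. As a rough sketch only (the model_path argument name and the assertion are assumptions, not the real test), each pytest.param entry above becomes one extra parametrized case:

    import pytest

    MODEL_PATHS = [
        pytest.param(
            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
            id="llama3-eagl3-multiple-layers",
        ),
    ]

    @pytest.mark.parametrize("model_path", MODEL_PATHS)
    def test_eagle3_speculators_model(model_path: str):
        # Hypothetical check; the real test loads the speculator with vLLM
        # and validates generated output instead of inspecting the string.
        assert model_path.startswith("nm-testing/")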

vllm/model_executor/models/llama_eagle3.py

Lines changed: 25 additions & 13 deletions
@@ -35,15 +35,20 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)
 
         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)
 
+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
@@ -53,6 +58,7 @@ def __init__(
         )
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx
 
         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
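
To illustrate the qkv_input_size comment above, here is a standalone shape check (hidden_size=4096 is just an example value; plain tensors stand in for the real modules): only the first draft layer receives the embeds/hidden_states concatenation, so only its QKV projection needs 2 * hidden_size input features.

    import torch

    hidden_size = 4096  # example value, not read from any config
    embeds = torch.randn(2, 8, hidden_size)         # draft-token embeddings
    hidden_states = torch.randn(2, 8, hidden_size)  # hidden states from the target model

    # Layer 0 input: embeds and hidden_states concatenated on the feature dim.
    layer0_in = torch.cat([embeds, hidden_states], dim=-1)
    assert layer0_in.shape[-1] == 2 * hidden_size

    # Layers 1..N-1 input: only the running hidden_states.
    assert hidden_states.shape[-1] == hidden_size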
@@ -91,11 +97,15 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
-
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
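
A simplified, self-contained sketch of this branch (LayerNorm stands in for vLLM's fused RMSNorm, and the attention/MLP body is omitted). It only shows how embeds are consumed by layer 0 and ignored by later layers:

    from typing import Optional

    import torch
    import torch.nn.functional as F

    def pre_attention(layer_idx, embeds, hidden_states, residual: Optional[torch.Tensor]):
        # Toy stand-in for the pre-attention part of LlamaDecoderLayer.forward.
        if layer_idx == 0:
            # First layer: normalize embeds and hidden_states, then concatenate.
            embeds = F.layer_norm(embeds, embeds.shape[-1:])
            residual = hidden_states
            hidden_states = F.layer_norm(hidden_states, hidden_states.shape[-1:])
            return torch.cat([embeds, hidden_states], dim=-1), residual
        # Later layers: no embeds; fold the residual in, then normalize
        # (approximating the fused RMSNorm(hidden_states, residual) call).
        hidden_states = hidden_states + residual
        residual = hidden_states
        return F.layer_norm(hidden_states, hidden_states.shape[-1:]), residual

    e, h = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
    x0, r0 = pre_attention(0, e, h, None)
    assert x0.shape[-1] == 128          # 2 * hidden_size for the first layer
    x1, r1 = pre_attention(1, e, torch.randn(2, 8, 64), r0)
    assert x1.shape[-1] == 64           # hidden_size for subsequent layers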
@@ -134,9 +144,11 @@ def __init__(
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
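
A tiny worked example of the new prefix arithmetic, with assumed values (a 2-layer draft on top of a 32-layer Llama-3.1-8B verifier, so start_layer_id would be 32): each draft layer now gets its own parameter prefix instead of every layer reusing layers.{start_layer_id}.

    num_hidden_layers = 2  # assumed draft config value
    start_layer_id = 32    # assumed verifier layer count

    prefixes = [
        f"layers.{layer_idx + start_layer_id}"
        for layer_idx in range(num_hidden_layers)
    ]
    print(prefixes)  # ['layers.32', 'layers.33']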
@@ -167,13 +179,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm
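
Finally, a condensed toy version of the new driver loop (identity-style layers, no attention). It only shows how hidden_states and residual now thread through every draft layer before the final norm, where previously only self.layers[0] ran:

    import torch
    import torch.nn.functional as F

    def toy_layer(positions, embeds, hidden_states, residual):
        # Stand-in for LlamaDecoderLayer.forward: returns (hidden_states, residual).
        residual = hidden_states if residual is None else hidden_states + residual
        return torch.tanh(residual), residual

    def toy_forward(positions, input_embeds, hidden_states, num_layers=2):
        residual = None
        for _ in range(num_layers):  # previously: a single layers[0] call
            hidden_states, residual = toy_layer(
                positions, input_embeds, hidden_states, residual
            )
        # Final norm consumes the last residual, as in self.norm(hidden_states, residual).
        return F.layer_norm(hidden_states + residual, hidden_states.shape[-1:])

    out = toy_forward(None, torch.randn(2, 8, 64), torch.randn(2, 8, 64))
    assert out.shape == (2, 8, 64)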
