
Commit 6d2891c

IzzyPutterman authored and DarkLight1337 committed
Eagle: MM Cuda Graphs with MRope (vllm-project#28896)
Signed-off-by: Izzy Putterman <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 82265f8 commit 6d2891c

2 files changed: +17 −10 lines


vllm/model_executor/models/llama_eagle3.py

Lines changed: 6 additions & 8 deletions

```diff
@@ -23,7 +23,6 @@
     maybe_remap_kv_scale_name,
 )
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
-from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
 from .utils import (
@@ -121,13 +120,12 @@ def forward(
 
 
 @support_torch_compile(
-    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
-    # violations with dynamic shapes during tensor concatenation operations.
-    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
-    # Non-multimodal EAGLE3 models can still use torch.compile safely.
-    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
-        vllm_config.model_config
-    ),
+    dynamic_arg_dims={
+        "input_ids": 0,
+        "positions": -1,
+        "hidden_states": 0,
+        "input_embeds": 0,
+    }
 )
 class LlamaModel(nn.Module):
     def __init__(
```
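For context on the change: instead of gating compilation off for multimodal models, the new `dynamic_arg_dims` mapping tells `support_torch_compile` which dimension of each forward argument varies across calls (`-1` is the last dimension; M-RoPE positions are `(3, num_tokens)`, so the token axis sits last). A minimal sketch of the idea, assuming the decorator marks those dimensions via `torch._dynamo.mark_dynamic`; the `mark_dynamic_dims` helper below is illustrative, not vLLM's actual implementation:

```python
import torch

# The mapping from the diff: argument name -> dimension that is dynamic.
dynamic_arg_dims = {
    "input_ids": 0,
    "positions": -1,
    "hidden_states": 0,
    "input_embeds": 0,
}

def mark_dynamic_dims(kwargs: dict[str, torch.Tensor]) -> None:
    """Illustrative stand-in for what the decorator does with the mapping."""
    for name, dim in dynamic_arg_dims.items():
        tensor = kwargs.get(name)
        if tensor is not None:
            # Tell the compiler this dimension varies between calls, so one
            # dynamic-shape graph is compiled instead of one per batch size.
            torch._dynamo.mark_dynamic(tensor, dim)
```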

vllm/v1/spec_decode/eagle.py

Lines changed: 11 additions & 2 deletions

```diff
@@ -116,9 +116,18 @@ def __init__(
         )
         self.uses_mrope = self.vllm_config.model_config.uses_mrope
         if self.uses_mrope:
-            # M-RoPE need (3, max_num_tokens)
+            # NOTE: `mrope_positions` is implemented with one additional dummy
+            # position on purpose to make it non-contiguous so that it can work
+            # with torch compile.
+            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
+
+            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+            # the modality of inputs. For text-only inputs, each dimension has
+            # identical position IDs, making M-RoPE functionally equivalent to
+            # 1D-RoPE.
+            # See page 5 of https://arxiv.org/abs/2409.12191
             self.mrope_positions = torch.zeros(
-                (3, self.max_num_tokens), dtype=torch.int64, device=device
+                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
             )
         else:
             # RoPE need (max_num_tokens,)
```
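The `+ 1` padding is what the NOTE in the diff is about: slicing the used token range out of the over-allocated buffer always yields a non-contiguous view, so tensor metadata never flips between full and partial batches under torch.compile. A quick self-contained check of that property (the names and the value of `max_num_tokens` here are local stand-ins, not vLLM code):

```python
import torch

max_num_tokens = 8  # stand-in value; vLLM uses the scheduler's real limit

# Padded buffer, as in the diff: (3, max_num_tokens + 1).
padded = torch.zeros(3, max_num_tokens + 1, dtype=torch.int64)

# Any slice of the used region keeps strides (max_num_tokens + 1, 1), which
# never match the sliced shape, so the view is always non-contiguous.
print(padded[:, :max_num_tokens].is_contiguous())  # False
print(padded[:, :4].is_contiguous())               # False

# Without the padding, contiguity would flip with the slice width, changing
# tensor metadata between calls of the compiled graph.
exact = torch.zeros(3, max_num_tokens, dtype=torch.int64)
print(exact[:, :max_num_tokens].is_contiguous())   # True
print(exact[:, :4].is_contiguous())                # False
```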
