2 files changed, +17 −10 lines changed
```diff
@@ -23,7 +23,6 @@
     maybe_remap_kv_scale_name,
 )
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
-from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
 from .utils import (
@@ -121,13 +120,12 @@ def forward(
 
 
 @support_torch_compile(
-    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
-    # violations with dynamic shapes during tensor concatenation operations.
-    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
-    # Non-multimodal EAGLE3 models can still use torch.compile safely.
-    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
-        vllm_config.model_config
-    ),
+    dynamic_arg_dims={
+        "input_ids": 0,
+        "positions": -1,
+        "hidden_states": 0,
+        "input_embeds": 0,
+    }
 )
 class LlamaModel(nn.Module):
     def __init__(
```
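The first file's change replaces the `enable_if` opt-out (which disabled `torch.compile` for multimodal EAGLE3 models outright) with `dynamic_arg_dims`, telling the compiler which dimension of each `forward` argument is allowed to vary between calls; `-1` selects the last dimension, which matches the `(3, num_tokens)` layout that M-RoPE positions use. Below is a minimal sketch of the mechanism this configures, assuming the decorator ultimately marks each tensor argument with `torch._dynamo.mark_dynamic`; the toy `forward` and its tensors are hypothetical, not vLLM code:

```python
import torch

# Hypothetical stand-in for the model's forward pass; only the
# dynamic-dim marking below mirrors what dynamic_arg_dims configures.
def forward(input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    return input_ids.unsqueeze(0).float() + positions.float()

dynamic_arg_dims = {"input_ids": 0, "positions": -1}

input_ids = torch.arange(8)                                     # (num_tokens,)
positions = torch.arange(8).unsqueeze(0).expand(3, -1).clone()  # (3, num_tokens)

# Mark the token dimension of each argument as dynamic, so a change in
# num_tokens reuses the compiled graph instead of forcing a recompile.
for name, tensor in {"input_ids": input_ids, "positions": positions}.items():
    dim = dynamic_arg_dims[name]
    if dim < 0:
        dim += tensor.ndim  # normalize -1 to the last dimension
    torch._dynamo.mark_dynamic(tensor, dim)

compiled = torch.compile(forward)
out = compiled(input_ids, positions)  # (3, num_tokens) float tensor
```

With the opt-out gone, multimodal EAGLE3 models are compiled again; the per-argument dynamic dims are what should prevent the dynamic-shape constraint violations cited in the removed comment.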
```diff
@@ -116,9 +116,18 @@ def __init__(
         )
         self.uses_mrope = self.vllm_config.model_config.uses_mrope
         if self.uses_mrope:
-            # M-RoPE need (3, max_num_tokens)
+            # NOTE: `mrope_positions` is implemented with one additional dummy
+            # position on purpose to make it non-contiguous so that it can work
+            # with torch compile.
+            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
+
+            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+            # the modality of inputs. For text-only inputs, each dimension has
+            # identical position IDs, making M-RoPE functionally equivalent to
+            # 1D-RoPE.
+            # See page 5 of https://arxiv.org/abs/2409.12191
             self.mrope_positions = torch.zeros(
-                (3, self.max_num_tokens), dtype=torch.int64, device=device
+                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
             )
         else:
             # RoPE need (max_num_tokens,)
```
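The `+ 1` is the entire trick: with one extra dummy column, any `[:, :num_tokens]` slice of the buffer (even the full-width one) is a non-contiguous view, which per the linked discussion keeps `torch.compile` from specializing on contiguity. The sketch below, using an illustrative `max_num_tokens` rather than vLLM's runtime value, demonstrates both NOTEs:

```python
import torch

max_num_tokens = 8
# One extra dummy column, mirroring (3, max_num_tokens + 1) above.
mrope_positions = torch.zeros((3, max_num_tokens + 1), dtype=torch.int64)

num_tokens = 5
view = mrope_positions[:, :num_tokens]
# The row stride is max_num_tokens + 1 while each row spans only
# num_tokens elements, so the view is non-contiguous.
assert not view.is_contiguous()

# Without the dummy column, the full-width slice would stay contiguous.
flat = torch.zeros((3, max_num_tokens), dtype=torch.int64)
assert flat[:, :max_num_tokens].is_contiguous()

# Text-only inputs (second NOTE): all three M-RoPE rows hold the same
# 1D positions, so M-RoPE degenerates to ordinary 1D RoPE.
view.copy_(torch.arange(num_tokens).unsqueeze(0).expand(3, -1))
assert (view[0] == view[1]).all() and (view[1] == view[2]).all()
```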