Skip to content

Commit 9e1b6d2

Browse files
CRZbulabula authored and youkaichao committed
[torch.compile] Adding torch compile annotations to some models (vllm-project#9876)
Signed-off-by: youkaichao <[email protected]> Co-authored-by: youkaichao <[email protected]> Signed-off-by: s.kochetkov <[email protected]>
1 parent f7cc665 commit 9e1b6d2

File tree

7 files changed

+12
-2
lines changed

7 files changed

+12
-2
lines changed

docs/source/models/supported_models.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ Text Generation
281281
- ✅︎
282282
* - :code:`Qwen2ForCausalLM`
283283
- Qwen2
284-
- :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
284+
- :code:`Qwen/Qwen2-7B-Instruct`, :code:`Qwen/Qwen2-7B`, etc.
285285
- ✅︎
286286
- ✅︎
287287
* - :code:`Qwen2MoeForCausalLM`

tests/distributed/test_pipeline_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def iter_params(self, model_name: str):
166166
"microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
167167
"adept/persimmon-8b-chat": PPTestSettings.fast(),
168168
"Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
169-
"Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
169+
"Qwen/Qwen2-7B-Instruct": PPTestSettings.fast(),
170170
"Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
171171
"stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
172172
"bigcode/starcoder2-3b": PPTestSettings.fast(),

vllm/model_executor/models/falcon.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from transformers import FalconConfig as HF_FalconConfig
2828

2929
from vllm.attention import Attention, AttentionMetadata
30+
from vllm.compilation.decorators import support_torch_compile
3031
from vllm.config import CacheConfig
3132
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
3233
get_tensor_model_parallel_world_size,
@@ -329,6 +330,7 @@ def forward(
329330
return output
330331

331332

333+
@support_torch_compile
332334
class FalconModel(nn.Module):
333335

334336
def __init__(

vllm/model_executor/models/phi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from transformers import PhiConfig
4343

4444
from vllm.attention import Attention, AttentionMetadata
45+
from vllm.compilation.decorators import support_torch_compile
4546
from vllm.config import CacheConfig, LoRAConfig
4647
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
4748
from vllm.model_executor.layers.activation import get_act_fn
@@ -193,6 +194,7 @@ def forward(
193194
return hidden_states
194195

195196

197+
@support_torch_compile
196198
class PhiModel(nn.Module):
197199

198200
def __init__(self,

vllm/model_executor/models/qwen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from transformers import PretrainedConfig
2121

2222
from vllm.attention import Attention, AttentionMetadata
23+
from vllm.compilation.decorators import support_torch_compile
2324
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
2425
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
2526
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
@@ -549,6 +550,7 @@ def forward(
549550
return hidden_states, residual
550551

551552

553+
@support_torch_compile
552554
class QWenModel(nn.Module):
553555

554556
def __init__(

vllm/model_executor/models/qwen2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from transformers import Qwen2Config
3030

3131
from vllm.attention import Attention, AttentionMetadata
32+
from vllm.compilation.decorators import support_torch_compile
3233
from vllm.config import CacheConfig, LoRAConfig
3334
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
3435
from vllm.model_executor.layers.activation import SiluAndMul
@@ -237,6 +238,7 @@ def forward(
237238
return hidden_states, residual
238239

239240

241+
@support_torch_compile
240242
class Qwen2Model(nn.Module):
241243

242244
def __init__(

vllm/model_executor/models/qwen2_moe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from transformers import PretrainedConfig
3131

3232
from vllm.attention import Attention, AttentionMetadata
33+
from vllm.compilation.decorators import support_torch_compile
3334
from vllm.config import CacheConfig
3435
from vllm.distributed import (get_pp_group,
3536
get_tensor_model_parallel_world_size,
@@ -312,6 +313,7 @@ def forward(
312313
return hidden_states, residual
313314

314315

316+
@support_torch_compile
315317
class Qwen2MoeModel(nn.Module):
316318

317319
def __init__(

0 commit comments

Comments (0)