
Commit e6c92d9

CRZbulabula authored and youkaichao committed
[torch.compile] Adding torch compile annotations to some models (vllm-project#9639)
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: qishuai <[email protected]>
1 parent a432846 commit e6c92d9

File tree: 7 files changed, +13 -3 lines changed
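
Every model file in this commit applies the same pattern: import vLLM's support_torch_compile decorator and place it on the model's core nn.Module so that its forward pass can be handled by torch.compile, per the commit title. A minimal sketch of that pattern follows; ExampleModel, its constructor, and the simplified forward signature are illustrative assumptions and are not code from this commit.

from typing import List

import torch
from torch import nn

from vllm.attention import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile


# Hypothetical stand-in for the classes decorated in this commit (JAISModel,
# MiniCPMModel, MPTModel, NemotronModel, OlmoModel). The decorator is assumed
# to let vLLM capture the module's forward pass with torch.compile when
# compilation is enabled.
@support_torch_compile
class ExampleModel(nn.Module):

    def __init__(self, vocab_size: int = 32000, hidden_size: int = 1024) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Simplified body: a real model would run its decoder layers here.
        return self.embed_tokens(input_ids)

The classes actually annotated are JAISModel, MiniCPMModel, MPTModel, NemotronModel, and OlmoModel, as shown in the diffs below.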

docs/source/models/supported_models.rst

Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ Text Generation
     - ✅︎
   * - :code:`JAISLMHeadModel`
     - Jais
-    - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
+    - :code:`inceptionai/jais-13b`, :code:`inceptionai/jais-13b-chat`, :code:`inceptionai/jais-30b-v3`, :code:`inceptionai/jais-30b-chat-v3`, etc.
     -
     - ✅︎
   * - :code:`JambaForCausalLM`

tests/distributed/test_pipeline_parallel.py

Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@ def iter_params(self, model_name: str):
     # Uses Llama
     # "internlm/internlm-chat-7b": PPTestSettings.fast(),
     "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
-    "core42/jais-13b-chat": PPTestSettings.fast(),
+    "inceptionai/jais-13b-chat": PPTestSettings.fast(),
     # TODO: Implement PP
     # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),

vllm/model_executor/models/jais.py

Lines changed: 3 additions & 1 deletion

@@ -1,6 +1,6 @@
 # coding=utf-8
 # Adapted from
-# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
+# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
 # Copyright 2023 The vLLM team.
 # Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
 # reserved.
@@ -26,6 +26,7 @@
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -212,6 +213,7 @@ def forward(
         return hidden_states


+@support_torch_compile
 class JAISModel(nn.Module):

     def __init__(

vllm/model_executor/models/minicpm.py

Lines changed: 2 additions & 0 deletions

@@ -29,6 +29,7 @@
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -348,6 +349,7 @@ def forward(
         return hidden_states, None


+@support_torch_compile
 class MiniCPMModel(nn.Module):

     def __init__(

vllm/model_executor/models/mpt.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
 import torch.nn as nn

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -204,6 +205,7 @@ def forward(
         return hidden_states


+@support_torch_compile
 class MPTModel(nn.Module):

     def __init__(

vllm/model_executor/models/nemotron.py

Lines changed: 2 additions & 0 deletions

@@ -27,6 +27,7 @@
 from torch import nn

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -290,6 +291,7 @@ def forward(
         return hidden_states, residual


+@support_torch_compile
 class NemotronModel(nn.Module):

     def __init__(

vllm/model_executor/models/olmo.py

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@
 from transformers import OlmoConfig

 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -221,6 +222,7 @@ def forward(
         return hidden_states


+@support_torch_compile
 class OlmoModel(nn.Module):

     def __init__(self,
