
Commit 0595d75

CRZbulabula authored and weilong.yu committed
Adding "torch compile" annotations to moe models (vllm-project#9758)
1 parent d330b01 commit 0595d75
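
The change is mechanical: each of the four MoE model files imports vLLM's support_torch_compile class decorator and applies it to the top-level model module (ArcticModel, MixtralModel, OlmoeModel, PhiMoEModel), opting those models into vLLM's torch.compile integration. A minimal sketch of the pattern is below; ToyMoEModel, its constructor, and its forward body are made up for illustration, and the exact signatures the decorator expects depend on the vLLM version. Only the import path and the bare @support_torch_compile usage are taken from this diff.

# Hypothetical sketch of the pattern this commit applies; ToyMoEModel is not a
# real vLLM model. Only the import path and the bare-decorator usage mirror
# the diffs below; everything else is illustrative and version-dependent.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile  # same annotation the commit adds to the four MoE models
class ToyMoEModel(nn.Module):

    def __init__(self, vocab_size: int = 32000, hidden_size: int = 1024) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,  # unused in this toy; kept to mirror real models
    ) -> torch.Tensor:
        # Real models run attention and MoE layers here; an embedding lookup
        # stands in for them. The tensor-typed arguments mirror the real
        # models' forward signature, which (assumption) the decorator inspects
        # to decide which dimensions are dynamic under torch.compile.
        return self.embed_tokens(input_ids)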

File tree

4 files changed: +8 -0 lines changed


vllm/model_executor/models/arctic.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -360,6 +361,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class ArcticModel(nn.Module):
 
     def __init__(

vllm/model_executor/models/mixtral.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from transformers import MixtralConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -245,6 +246,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class MixtralModel(nn.Module):
 
     def __init__(

vllm/model_executor/models/olmoe.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -239,6 +240,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class OlmoeModel(nn.Module):
 
     def __init__(

vllm/model_executor/models/phimoe.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from transformers.configuration_utils import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -429,6 +430,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class PhiMoEModel(nn.Module):
 
     def __init__(

0 commit comments