Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vllm/model_executor/models/arctic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -360,6 +361,7 @@ def forward(
return hidden_states


@support_torch_compile
class ArcticModel(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to run this model successfully on H100, I have to change the config:

"hidden_size": 512,
"intermediate_size": 512,
"num_key_value_heads": 8,
"num_attention_heads": 8,
"num_local_experts": 4,

initially, I want to simply change "num_hidden_layers": 35, to "num_hidden_layers": 2, , but I met various random illegal memory access error. might be caused by fused moe kernel, with extremely large input sizes.


def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from transformers import MixtralConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
Expand Down Expand Up @@ -245,6 +246,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class MixtralModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
Expand Down Expand Up @@ -239,6 +240,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class OlmoeModel(nn.Module):

def __init__(
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/phimoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from transformers.configuration_utils import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
Expand Down Expand Up @@ -429,6 +430,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
class PhiMoEModel(nn.Module):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for this model, it seems directly running it with -tp=2 will fail. the error is:

Failed: Cuda error /workspace/csrc/custom_all_reduce.cuh:336 'invalid argument'

need to investigate it later.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this is unrelated to torch.compile


def __init__(
Expand Down