
Commit ad6f780

[torch.compile] expanding support and fix allgather compilation (#9637)
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Parent: 295a061

6 files changed: +16 -1 lines changed

vllm/distributed/parallel_state.py

Lines changed: 6 additions & 1 deletion
@@ -392,15 +392,20 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         # Convert negative dim to positive.
         dim += input_.dim()
         input_size = input_.size()
+        # NOTE: we have to use concat-style all-gather here,
+        # stack-style all-gather has compatibility issues with
+        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
+        output_size = (input_size[0] * world_size, ) + input_size[1:]
         # Allocate output tensor.
-        output_tensor = torch.empty((world_size, ) + input_size,
+        output_tensor = torch.empty(output_size,
                                     dtype=input_.dtype,
                                     device=input_.device)
         # All-gather.
         torch.distributed.all_gather_into_tensor(output_tensor,
                                                  input_,
                                                  group=self.device_group)
         # Reshape
+        output_tensor = output_tensor.reshape((world_size, ) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
         output_tensor = output_tensor.reshape(input_size[:dim] +
                                               (world_size *
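
Why the post-collective reshape is enough: a concat-style gather into a flat (world_size * N, ...) buffer, followed by a reshape to (world_size, N, ...), reproduces exactly the layout the old stack-style allocation produced, so the existing movedim/reshape tail of all_gather is unchanged. A minimal single-process sketch of that identity (plain tensors stand in for the collective; no process group is involved):

import torch

# One shard per rank, each of shape (2, 3); values mark the source rank.
world_size, input_size = 4, (2, 3)
chunks = [torch.full(input_size, float(rank)) for rank in range(world_size)]

# Stack-style layout the old code allocated directly: (world_size, 2, 3).
stacked = torch.stack(chunks)

# Concat-style layout the new code gathers into: (world_size * 2, 3).
flat = torch.cat(chunks)

# The added reshape recovers the stacked layout exactly.
assert torch.equal(flat.reshape((world_size, ) + input_size), stacked)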

vllm/model_executor/models/gpt_bigcode.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
 from transformers import GPTBigCodeConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -187,6 +188,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTBigCodeModel(nn.Module):
 
     def __init__(
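
The five model files in this commit all receive the same two-line change: import support_torch_compile from vllm.compilation.decorators and decorate the inner *Model class (GPTBigCodeModel here; GPTJModel, GPTNeoXModel, GraniteModel, and InternLM2Model below), opting the model's forward pass into vLLM's torch.compile integration. As a rough illustration of what the decorator ultimately enables (plain torch.compile on a hypothetical toy module, not vLLM's actual wrapper machinery):

import torch
from torch import nn

class ToyModel(nn.Module):  # hypothetical stand-in for the decorated classes

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)

model = torch.compile(ToyModel())  # the decorator manages this inside vLLM
out = model(torch.randn(4, 16))    # traced and compiled on first call
assert out.shape == (4, 16)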

vllm/model_executor/models/gpt_j.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from transformers import GPTJConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -174,6 +175,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTJModel(nn.Module):
 
     def __init__(

vllm/model_executor/models/gpt_neox.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from transformers import GPTNeoXConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -187,6 +188,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTNeoXModel(nn.Module):
 
     def __init__(

vllm/model_executor/models/granite.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
 from transformers import GraniteConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -254,6 +255,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GraniteModel(nn.Module):
 
     def __init__(

vllm/model_executor/models/internlm2.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -230,6 +231,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class InternLM2Model(nn.Module):
 
     def __init__(
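
Tying the two halves of the commit together: the concat-style gather matters because the reshape/movedim tail of all_gather must trace cleanly under torch.compile for the newly decorated models. A single-process smoke test of that exact tail (regroup is an illustrative helper, not a vLLM function; fullgraph=True raises on any graph break):

import torch

def regroup(flat: torch.Tensor, world_size: int, input_size: tuple,
            dim: int) -> torch.Tensor:
    # Recover the stacked layout from the concat-style gather output,
    # then move the gathered axis to the requested dim, as in all_gather.
    out = flat.reshape((world_size, ) + input_size)
    out = out.movedim(0, dim)
    return out.reshape(input_size[:dim] +
                       (world_size * input_size[dim], ) +
                       input_size[dim + 1:])

compiled = torch.compile(regroup, fullgraph=True)
flat = torch.randn(4 * 2, 3)  # as if gathered from 4 ranks of (2, 3) shards
out = compiled(flat, 4, (2, 3), dim=1)
assert out.shape == (2, 12)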
