Skip to content

Commit f016102

Browse files
Isotr0py authored and garg-amit
committed
[Bugfix] Fix Fuyu tensor parallel inference (vllm-project#8986)
Signed-off-by: Amit Garg <[email protected]>
1 parent db644b7 commit f016102

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

tests/distributed/test_pipeline_parallel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
3838
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
3939
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
40-
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
40+
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"),
41+
# TP only models
42+
(2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"),
4143
],
4244
)
4345
@fork_new_process_for_each_test

vllm/model_executor/models/fuyu.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,9 @@ def __init__(self,
237237
self.image_feature_size,
238238
config.hidden_size,
239239
quant_config=quant_config,
240+
gather_output=True,
240241
)
241-
self.language_model = PersimmonForCausalLM(config,
242+
self.language_model = PersimmonForCausalLM(config.text_config,
242243
cache_config=cache_config,
243244
quant_config=quant_config)
244245

vllm/model_executor/models/persimmon.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import torch
2626
from torch import nn
2727
from transformers import PersimmonConfig
28-
from transformers.activations import ReLUSquaredActivation
2928

3029
from vllm.attention import Attention, AttentionMetadata
3130
from vllm.config import CacheConfig
3231
from vllm.distributed import get_tensor_model_parallel_world_size
32+
from vllm.model_executor.layers.activation import get_act_fn
3333
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
3434
QKVParallelLinear,
3535
RowParallelLinear)
@@ -57,7 +57,7 @@ def __init__(self,
5757
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
5858
config.hidden_size,
5959
quant_config=quant_config)
60-
self.act = ReLUSquaredActivation()
60+
self.act = get_act_fn(config.hidden_act, quant_config)
6161

6262
def forward(self, hidden_states) -> torch.Tensor:
6363
hidden_states, _ = self.dense_h_to_4h(hidden_states)
@@ -96,7 +96,7 @@ def __init__(self,
9696
quant_config=quant_config,
9797
)
9898
self.dense = RowParallelLinear(
99-
self.num_heads * self.head_dim,
99+
self.total_num_heads * self.head_dim,
100100
self.hidden_size,
101101
bias=True,
102102
quant_config=quant_config,
@@ -213,10 +213,10 @@ def __init__(self,
213213
cache_config: Optional[CacheConfig] = None,
214214
quant_config: Optional[QuantizationConfig] = None):
215215
super().__init__()
216-
self.vocab_size = config.text_config.vocab_size
216+
self.vocab_size = config.vocab_size
217217

218-
self.embed_tokens = VocabParallelEmbedding(
219-
config.text_config.vocab_size, config.hidden_size)
218+
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
219+
config.hidden_size)
220220
self.layers = nn.ModuleList([
221221
PersimmonDecoderLayer(config,
222222
cache_config=cache_config,
@@ -252,19 +252,19 @@ def forward(
252252
class PersimmonForCausalLM(nn.Module):
253253

254254
def __init__(self,
255-
config,
255+
config: PersimmonConfig,
256256
cache_config: Optional[CacheConfig] = None,
257257
quant_config: Optional[QuantizationConfig] = None):
258258
super().__init__()
259259
self.config = config
260-
self.vocab_size = config.text_config.vocab_size
260+
self.vocab_size = config.vocab_size
261261
self.model = PersimmonModel(config,
262262
cache_config=cache_config,
263263
quant_config=quant_config)
264-
self.lm_head = ParallelLMHead(config.text_config.vocab_size,
264+
self.lm_head = ParallelLMHead(config.vocab_size,
265265
config.hidden_size,
266266
bias=False)
267-
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
267+
self.logits_processor = LogitsProcessor(config.vocab_size)
268268
self.sampler = Sampler()
269269

270270
def forward(

0 commit comments

Comments
 (0)