Skip to content

Commit aadb350

Browse files
NickLucche authored and mzusman committed
[Bugfix] Fix RobertaModel loading (vllm-project#11940)
Signed-off-by: NickLucche <[email protected]>
1 parent accb858 commit aadb350

File tree

3 files changed

+67
-12
lines changed

3 files changed

+67
-12
lines changed

tests/model_executor/test_model_load_with_params.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44

5-
from vllm.model_executor.layers.pooler import PoolingType
5+
from vllm.model_executor.layers.pooler import CLSPool, PoolingType
66
from vllm.model_executor.models.bert import BertEmbeddingModel
77
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
88
from vllm.platforms import current_platform
@@ -92,3 +92,28 @@ def test_roberta_model_loading_with_params(vllm_runner):
9292

9393
# assert output
9494
assert output
95+
96+
97+
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_facebook_roberta_model_loading_with_params(vllm_runner):
    """
    Test loading roberta-base model with no lm_head.

    Encodes one prompt with ``FacebookAI/roberta-base`` and then checks the
    loaded objects: the tokenizer id matches the requested model name, the
    underlying model has no ``lm_head`` attribute, is a
    ``RobertaEmbeddingModel`` and pools with ``CLSPool``.
    """
    model_name = "FacebookAI/roberta-base"
    with vllm_runner(model_name=model_name,
                     dtype="float16",
                     max_model_len=MAX_MODEL_LEN) as runner:
        output = runner.encode("Write a short story about a robot that"
                               " dreams for the first time.\n")

        engine = runner.model.llm_engine
        assert engine.tokenizer.tokenizer_id == model_name

        # Reach into the engine for the module that was actually loaded.
        # A separate name is used so the runner handle is not shadowed.
        loaded_model = engine.model_executor.driver_worker.model_runner.model
        assert not hasattr(loaded_model, "lm_head")
        assert isinstance(loaded_model, RobertaEmbeddingModel)
        assert isinstance(loaded_model._pooler, CLSPool)

        assert output

tests/models/embedding/language/test_embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
2626
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
2727
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
28+
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
2829
],
2930
)
3031
@pytest.mark.parametrize("dtype", ["half"])

vllm/model_executor/models/roberta.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
from typing import Iterable, List, Optional, Tuple
23

34
import torch
@@ -20,6 +21,30 @@
2021
from .interfaces import SupportsCrossEncoding
2122

2223

24+
def roberta_task_weights_filter(
    all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
                                                              torch.Tensor]]]:
    """
    Split a weight stream into base-model and task-specific weights.

    Returns a pair of lazy iterators over ``all_weights``:

    * the first yields every entry whose name starts with ``"roberta."``,
      with that prefix removed so the names are loadable from a vanilla
      BertModel;
    * the second yields every remaining entry (the task-specific weights
      applied on top of the base model) with its name unchanged.

    NOTE: both iterators are fed from a single ``itertools.tee`` of the
    input, so the input is consumed only once. ``tee`` is not free of
    memory overhead, though: it buffers every item that one iterator has
    already passed and the other has not reached yet. Fully draining one
    iterator before starting the other keeps references to all items alive
    until the second iterator catches up.
    """
    prefix = "roberta."
    base_stream, task_stream = itertools.tee(all_weights)

    def base_model_weights():
        for name, weight in base_stream:
            if name.startswith(prefix):
                yield (name[len(prefix):], weight)

    def task_specific_weights():
        for name, weight in task_stream:
            if not name.startswith(prefix):
                yield (name, weight)

    return base_model_weights(), task_specific_weights()
46+
47+
2348
class RobertaEmbedding(nn.Module):
2449

2550
def __init__(self, config: RobertaConfig):
@@ -152,6 +177,18 @@ def _build_model(self,
152177
prefix=prefix,
153178
embedding_class=RobertaEmbedding)
154179

180+
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """
    Load checkpoint weights into ``self.model``.

    The names are first remapped through ``self.hf_to_vllm_mapper`` and
    then split by ``roberta_task_weights_filter``. The "roberta."-prefixed
    weights (prefix stripped) are loaded first; if that loads nothing, the
    remaining un-prefixed weights are loaded instead.

    Raises:
        AssertionError: if neither attempt loaded any weight.
    """
    weights = self.hf_to_vllm_mapper.apply(weights)
    # Separate weights in "roberta"-prefixed and all else.
    # For use with models like FacebookAI/roberta-base.
    bert_weights, task_weights = roberta_task_weights_filter(weights)
    loaded = self.model.load_weights(bert_weights)
    if not loaded:
        # Fix for models like `sentence-transformers/stsb-roberta-base-v2`
        # which use the same architecture, but have no "roberta" prefix.
        loaded = self.model.load_weights(task_weights)
    if not loaded:
        # Raised explicitly (instead of `assert`) so the check still runs
        # under `python -O`; the exception type and message are unchanged.
        raise AssertionError("Unable to load RobertaEmbeddingModel")
191+
155192

156193
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
157194
"""A model that uses Roberta to provide embedding functionalities.
@@ -181,20 +218,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
181218

182219
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
183220

184-
self_weights = []
185-
186-
def weight_filter():
187-
for name, weight in weights:
188-
if name.startswith("roberta."):
189-
yield (name[len("roberta."):], weight)
190-
else:
191-
self_weights.append((name, weight))
192-
193-
self.roberta.load_weights(weight_filter())
221+
bert_weights, task_weights = roberta_task_weights_filter(weights)
222+
self.roberta.load_weights(bert_weights)
194223

195224
params_dict = dict(self.named_parameters())
196225

197-
for name, loaded_weight in self_weights:
226+
for name, loaded_weight in task_weights:
198227
if name.startswith("classifier"):
199228
param = params_dict[name]
200229
weight_loader = getattr(param, "weight_loader",

0 commit comments

Comments
 (0)