
Commit 0433314

tjohnson31415 authored and DarkLight1337 committed
[Bugfix]: serialize config by value for --trust-remote-code (vllm-project#6751)
Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
1 parent 83225aa commit 0433314

File tree

4 files changed: +103 -28 lines changed


tests/distributed/test_pipeline_parallel.py

Lines changed: 35 additions & 28 deletions
@@ -28,19 +28,25 @@ class ParallelSetup(NamedTuple):
     chunked_prefill: bool
 
 
+class PPTestOptions(NamedTuple):
+    multi_node_only: bool
+    trust_remote_code: bool
+    tokenizer_mode: Optional[str]
+
+
 @dataclass
 class PPTestSettings:
     parallel_setups: List[ParallelSetup]
     distributed_backends: List[str]
     task: TaskOption
-    trust_remote_code: bool
-    tokenizer_mode: Optional[str]
+    test_options: PPTestOptions
 
     @staticmethod
     def detailed(
         *,
         tp_base: int = 1,
         pp_base: int = 2,
+        multi_node_only: bool = False,
         task: TaskOption = "auto",
         trust_remote_code: bool = False,
         tokenizer_mode: Optional[str] = None,
@@ -70,8 +76,9 @@ def detailed(
             ],
             distributed_backends=["mp", "ray"],
             task=task,
-            trust_remote_code=trust_remote_code,
-            tokenizer_mode=tokenizer_mode,
+            test_options=PPTestOptions(multi_node_only=multi_node_only,
+                                       trust_remote_code=trust_remote_code,
+                                       tokenizer_mode=tokenizer_mode),
         )
 
     @staticmethod
@@ -80,6 +87,7 @@ def fast(
         tp_base: int = 1,
         pp_base: int = 2,
         task: TaskOption = "auto",
+        multi_node_only: bool = False,
         trust_remote_code: bool = False,
         tokenizer_mode: Optional[str] = None,
     ):
@@ -92,15 +100,18 @@ def fast(
             ],
             distributed_backends=["mp"],
             task=task,
-            trust_remote_code=trust_remote_code,
-            tokenizer_mode=tokenizer_mode,
+            test_options=PPTestOptions(multi_node_only=multi_node_only,
+                                       trust_remote_code=trust_remote_code,
+                                       tokenizer_mode=tokenizer_mode),
         )
 
     def iter_params(self, model_name: str):
+        opts = self.test_options
+
         for parallel_setup in self.parallel_setups:
             for distributed_backend in self.distributed_backends:
                 yield (model_name, parallel_setup, distributed_backend,
-                       self.task, self.trust_remote_code, self.tokenizer_mode)
+                       self.task, opts)
 
 
 # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
@@ -110,6 +121,7 @@ def iter_params(self, model_name: str):
 GENERATION_MODEL_SETTINGS = {
     # [DETAILED TESTS]
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
+    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True),  # noqa: E501
     # [FAST TESTS]
     # Uses Llama
     # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
@@ -151,10 +163,8 @@ def iter_params(self, model_name: str):
     "facebook/opt-iml-max-1.3b": PPTestSettings.fast(),
     "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True),
     "microsoft/phi-2": PPTestSettings.fast(),
-    "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(),
     "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
-    # FIXME: https://github.com/vllm-project/vllm/issues/8553
-    # "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
+    "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
     "adept/persimmon-8b-chat": PPTestSettings.fast(),
     "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
     "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(),
@@ -205,6 +215,7 @@ def iter_params(self, model_name: str):
     # [LANGUAGE GENERATION]
     "meta-llama/Meta-Llama-3-8B",
     "ibm/PowerLM-3b",
+    "microsoft/Phi-3-mini-4k-instruct",
     # [LANGUAGE EMBEDDING]
     "intfloat/e5-mistral-7b-instruct",
     "BAAI/bge-multilingual-gemma2",
@@ -220,19 +231,21 @@ def _compare_tp(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available: int,
     *,
-    method: Literal["generate", "encode"] = "encode",
+    method: Literal["generate", "encode"],
 ):
     tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
+    multi_node_only, trust_remote_code, tokenizer_mode = test_options
 
     if num_gpus_available < tp_size * pp_size:
         pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
     if VLLM_MULTI_NODE and distributed_backend == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
+    if multi_node_only and not VLLM_MULTI_NODE:
+        pytest.skip("Not in multi-node setting")
 
     common_args = [
         # use half precision for speed and memory savings in CI environment
@@ -307,7 +320,7 @@ def _compare_tp(
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
        for params in settings.iter_params(model_name)
@@ -320,23 +333,21 @@ def test_tp_language_generation(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="generate")
 
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
         for params in settings.iter_params(model_name)
@@ -349,23 +360,21 @@ def test_tp_language_embedding(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="encode")
 
 
 @pytest.mark.parametrize(
     ("model_name", "parallel_setup", "distributed_backend", "task",
-     "trust_remote_code", "tokenizer_mode"),
+     "test_options"),
     [
         params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
         for params in settings.iter_params(model_name)
@@ -378,15 +387,13 @@ def test_tp_multimodal_generation(
     parallel_setup: ParallelSetup,
     distributed_backend: str,
     task: TaskOption,
-    trust_remote_code: bool,
-    tokenizer_mode: Optional[str],
+    test_options: PPTestOptions,
     num_gpus_available,
 ):
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
                 task,
-                trust_remote_code,
-                tokenizer_mode,
+                test_options,
                 num_gpus_available,
                 method="generate")
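Note for readers skimming the test diff: the per-model flags now travel as a single PPTestOptions NamedTuple instead of separate parametrize arguments, and _compare_tp unpacks the tuple positionally. A minimal standalone sketch of that pattern (not part of the commit; the option values here are made up):

    from typing import NamedTuple, Optional

    class PPTestOptions(NamedTuple):
        multi_node_only: bool
        trust_remote_code: bool
        tokenizer_mode: Optional[str]

    # One bundled object per model entry instead of three loose parameters.
    opts = PPTestOptions(multi_node_only=True,
                         trust_remote_code=True,
                         tokenizer_mode=None)

    # _compare_tp() in the diff unpacks it positionally, field order preserved:
    multi_node_only, trust_remote_code, tokenizer_mode = opts
    assert (multi_node_only, trust_remote_code, tokenizer_mode) == (True, True, None)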

vllm/engine/arg_utils.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,8 @@
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.transformers_utils.config import (
+    maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import FlexibleArgumentParser
 
@@ -924,6 +926,8 @@ def create_engine_config(self) -> EngineConfig:
                 "supported for multimodal models and has been disabled.")
             self.enable_prefix_caching = False
 
+        maybe_register_config_serialize_by_value(self.trust_remote_code)
+
         cache_config = CacheConfig(
             # neuron needs block_size = max_model_len
             block_size=self.block_size if self.device != "neuron" else
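The new call is wired into create_engine_config, i.e. before any worker processes are created. The standalone sketch below (no vLLM or HF code; the fake_transformers_modules module and RemoteConfig class are invented for illustration) shows the failure mode it guards against: the standard pickle module serializes classes by reference (module path plus qualified name), so an object whose class lives in a module the receiving process cannot import fails to unpickle.

    import pickle
    import sys
    import types

    # Simulate a dynamically generated module, like the HF remote-code cache.
    mod = types.ModuleType("fake_transformers_modules")
    exec("class RemoteConfig:\n    attention_bias = True", mod.__dict__)
    sys.modules["fake_transformers_modules"] = mod

    # Dumping works: the class is found via sys.modules and stored by reference.
    payload = pickle.dumps(mod.RemoteConfig())

    # The "other process" has no such module on its import path.
    del sys.modules["fake_transformers_modules"]
    try:
        pickle.loads(payload)
    except ImportError as e:
        print("by-reference pickling failed:", e)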

vllm/transformers_utils/config.py

Lines changed: 62 additions & 0 deletions
@@ -232,6 +232,68 @@ def get_config(
     return config
 
 
+def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
+    """Try to register HF model configuration class to serialize by value
+
+    With trust_remote_code, the config class is typically an instance of a
+    custom class imported from the HF modules cache. The class will not be
+    importable in spawned workers by default (and won't exist at all on
+    other nodes), which breaks serialization of the config.
+
+    In this function we tell the cloudpickle serialization library to pass
+    instances of these generated classes by value instead of by reference,
+    i.e. the class definition is serialized along with its data so that the
+    class module does not need to be importable on the receiving end. This
+    registration only works if the modules cache has already been
+    initialized.
+
+    See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
+    """
+    if not trust_remote_code:
+        return
+
+    try:
+        import transformers_modules
+    except ImportError:
+        logger.debug("Could not import transformers_modules used for remote"
+                     " code. If remote code is not needed remove"
+                     " `--trust-remote-code`.")
+        return
+
+    try:
+        import cloudpickle
+        cloudpickle.register_pickle_by_value(transformers_modules)
+
+        # ray vendors its own version of cloudpickle
+        from vllm.executor.ray_utils import ray
+        if ray:
+            ray.cloudpickle.register_pickle_by_value(transformers_modules)
+
+        # multiprocessing uses pickle to serialize arguments when using spawn
+        # Here we get pickle to use cloudpickle to serialize ModelConfig objects
+        # that contain instances of the custom config class to avoid
+        # serialization problems if the generated module (and model) has a `.`
+        # in its name
+        import multiprocessing
+        import pickle
+
+        from vllm.config import ModelConfig
+
+        def _reduce_modelconfig(mc: ModelConfig):
+            return (pickle.loads, (cloudpickle.dumps(mc), ))
+
+        multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)
+
+    except Exception as e:
+        logger.warning(
+            "Unable to register remote classes used by"
+            " trust_remote_code with by-value serialization. This may"
+            " lead to a later error. If remote code is not needed"
+            " remove `--trust-remote-code`",
+            exc_info=e)
+
+
 def load_params_config(model, revision) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format
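A minimal, standalone sketch of the two registrations the new helper performs, assuming only that cloudpickle is installed; the fake_transformers_modules module and the Wrapper class are stand-ins for the HF remote-code cache and vllm.config.ModelConfig, not real vLLM objects:

    import pickle
    import sys
    import types

    import cloudpickle

    # Stand-in for the HF remote-code cache module.
    mod = types.ModuleType("fake_transformers_modules")
    exec("class RemoteConfig:\n    hidden_size = 4096", mod.__dict__)
    sys.modules["fake_transformers_modules"] = mod

    # 1. Serialize everything defined in this module by value, not by reference.
    cloudpickle.register_pickle_by_value(mod)

    payload = cloudpickle.dumps(mod.RemoteConfig())
    del sys.modules["fake_transformers_modules"]
    obj = pickle.loads(payload)           # succeeds: the class travels with the data
    print(type(obj).__name__, obj.hidden_size)

    # 2. For spawn-based multiprocessing, route a chosen wrapper type through
    #    cloudpickle (mirrors the ModelConfig reducer in the diff above).
    import multiprocessing

    class Wrapper:                        # stand-in for vllm.config.ModelConfig
        def __init__(self, hf_config):
            self.hf_config = hf_config

    def _reduce_wrapper(w: Wrapper):
        return (pickle.loads, (cloudpickle.dumps(w), ))

    multiprocessing.reducer.register(Wrapper, _reduce_wrapper)

Routing the reducer through cloudpickle means spawn-based multiprocessing ships the custom config class definition along with the config data, so a child process never needs to import the HF modules cache itself.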

vllm/utils.py

Lines changed: 2 additions & 0 deletions
@@ -968,6 +968,8 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]
 
 
+# TODO: This function can be removed if transformer_modules classes are
+# serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
     """
     Lazy initialization of the Hugging Face modules.
