
Commit b7a30d9

sangstar authored and Robert Shaw committed
[Frontend] [Core] feat: Add model loading using tensorizer (vllm-project#3476)
1 parent 023060f · commit b7a30d9

20 files changed (+1351, -51 lines)

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 0 deletions
@@ -91,6 +91,9 @@ steps:
   command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
   parallelism: 4

+- label: Tensorizer Test
+  command: apt-get install curl libsodium23 && pytest -v -s tensorizer
+
 - label: Metrics Test
   command: pytest -v -s metrics

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
"vllm._C",
8484
"numpy",
8585
"tqdm",
86+
"tensorizer",
8687
]
8788

8889
for mock_target in autodoc_mock_imports:
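For context on what this change does: Sphinx's `autodoc_mock_imports` stubs out each listed module while importing code to collect docstrings, so the docs build no longer fails when the optional `tensorizer` dependency is absent. Below is a minimal sketch of such a guard loop, assuming a generic `conf.py` rather than vLLM's actual check:

# Hypothetical conf.py excerpt (not vLLM's actual check). Modules listed in
# autodoc_mock_imports are replaced with stubs while Sphinx imports code to
# pull docstrings; a module that is already imported cannot be stubbed, so
# flagging it early makes docs-build failures easier to debug.
import sys

autodoc_mock_imports = ["vllm._C", "numpy", "tqdm", "tensorizer"]

for mock_target in autodoc_mock_imports:
    if mock_target in sys.modules:
        print(f"{mock_target} is already imported; "
              "autodoc_mock_imports will not take effect for it.")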

docs/source/models/engine_args.rst

Lines changed: 2 additions & 1 deletion
@@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM:

   Directory to download and load the weights, default to the default cache dir of huggingface.

-.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
+.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}

   The format of the model weights to load.

@@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
 * "safetensors" will load the weights in the safetensors format.
 * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
 * "dummy" will initialize the weights with random values, mainly for profiling.
+* "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer <https://github.com/coreweave/tensorizer>`_. See `tensorize_vllm_model.py` in the examples folder for how to serialize a vLLM model and for more information. Tensorizer support for vLLM can be installed with `pip install vllm[tensorizer]`.

 .. option:: --dtype {auto,half,float16,bfloat16,float,float32}

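For reference, a tensorized model would then be loaded through the usual `LLM` entry point. The sketch below is illustrative, not verbatim from this commit: `load_format="tensorizer"` is the documented option, while the `tensorizer_uri` keyword mirrors this commit's `TensorizerArgs` and should be treated as an assumption:

# A minimal sketch, assuming the serialize step from the example script has
# already produced model.tensors at this URI. Treat the exact keyword names
# (e.g. tensorizer_uri) as assumptions based on TensorizerArgs.
from vllm import LLM

llm = LLM(
    model="EleutherAI/gpt-j-6B",
    load_format="tensorizer",
    tensorizer_uri="s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors",
)
print(llm.generate("Hello, my name is"))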
examples/tensorize_vllm_model.py

Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type

import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                        TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig

from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs

# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer directly
to the GPU extremely quickly. Tensor encryption and decryption are also
supported, although libsodium must be installed to use them. Install
vllm with tensorizer support using `pip install vllm[tensorizer]`.

To serialize a model, you can run something like this:

python tensorize_vllm_model.py \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   serialize \
   --serialized-directory s3://my-bucket/ \
   --suffix vllm

This downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used.

You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.

To deserialize a model, you can run something like this:

python tensorize_vllm_model.py \
   --model EleutherAI/gpt-j-6B \
   --dtype float16 \
   deserialize \
   --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors

This downloads the model tensors from your S3 bucket and deserializes them.
To provide S3 credentials, you can pass `--s3-access-key-id`,
`--s3-secret-access-key`, and `--s3-endpoint` as CLI args to this script or
to the OpenAI entrypoint, as arguments to LLM(), or as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.

You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.

For more information on the available arguments, run
`python tensorize_vllm_model.py --help`.
"""


def parse_args():
    parser = argparse.ArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models "
        "can be loaded using tensorizer directly to the GPU "
        "extremely quickly. Tensor encryption and decryption are "
        "also supported, although libsodium must be installed to "
        "use them.")
    parser = EngineArgs.add_cli_args(parser)
    subparsers = parser.add_subparsers(dest='command')

    serialize_parser = subparsers.add_parser(
        'serialize', help="Serialize a model to `--serialized-directory`")

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which is "
            "used to construct the location of the serialized model tensors, "
            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
            "`--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."))
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True,
        help="The directory to serialize the model to. "
        "This can be a local directory or S3 URI. The path to where the "
        "tensors are saved is a combination of the supplied `dir` and model "
        "reference ID. For instance, if `dir` is the serialized directory, "
        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.")

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Encrypt the model weights with a randomly-generated binary key,"
              " and save the key at this path"))

    deserialize_parser = subparsers.add_parser(
        'deserialize',
        help=("Deserialize a model from `--path-to-tensors`"
              " to verify it can be loaded and used."))

    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=True,
        help="The local path or S3 URI to the model tensors to deserialize.")

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=("Path to a binary key to use to decrypt the model weights,"
              " if the model was serialized with encryption"))

    return parser.parse_args()


def make_model_contiguous(model):
    # Ensure tensors are saved in memory contiguously
    for param in model.parameters():
        param.data = param.data.contiguous()


def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def serialize():
    eng_args_dict = {f.name: getattr(args, f.name) for f in
                     dataclasses.fields(EngineArgs)}
    engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
    engine = LLMEngine.from_engine_args(engine_args)

    model = (engine.model_executor.driver_worker.
             model_runner.model)

    encryption_params = EncryptionParams.random() if keyfile else None
    if keyfile:
        with _write_stream(keyfile) as stream:
            stream.write(encryption_params.key)

    with _write_stream(model_path) as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()

    print("Serialization complete. Model tensors saved to", model_path)
    if keyfile:
        print("Key saved to", keyfile)


def deserialize():
    config = AutoConfig.from_pretrained(model_ref)

    with no_init_or_tensor():
        model_class = _get_vllm_model_architecture(config)
        model = model_class(config)

    before_mem = get_mem_usage()
    start = time.time()

    if keyfile:
        with _read_stream(keyfile) as stream:
            key = stream.read()
            decryption_params = DecryptionParams.from_key(key)
            tensorizer_args.deserializer_params['encryption'] = \
                decryption_params

    with _read_stream(model_path) as stream, TensorDeserializer(
            stream, **tensorizer_args.deserializer_params) as deserializer:
        deserializer.load_into_module(model)
        end = time.time()

    # Brag about how fast we are.
    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    print(
        f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
    )
    print(f"Memory usage before: {before_mem}")
    print(f"Memory usage after: {after_mem}")

    return model


args = parse_args()

s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
                    or None)
s3_secret_access_key = (args.s3_secret_access_key
                        or os.environ.get("S3_SECRET_ACCESS_KEY") or None)

s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)

_read_stream, _write_stream = (partial(
    stream_io.open_stream,
    mode=mode,
    s3_access_key_id=s3_access_key_id,
    s3_secret_access_key=s3_secret_access_key,
    s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))

model_ref = args.model

model_name = model_ref.split("/")[1]

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"

torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()

keyfile = args.keyfile if args.keyfile else None

if args.command == "serialize":
    input_dir = args.serialized_directory.rstrip('/')
    suffix = args.suffix if args.suffix else uuid.uuid4().hex
    base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
    model_path = f"{base_path}/model.tensors"
    serialize()
elif args.command == "deserialize":
    tensorizer_args = TensorizerArgs.from_cli_args(args)
    model_path = args.path_to_tensors
    deserialize()
else:
    raise ValueError("Either serialize or deserialize must be specified.")
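
To make the script's encryption path concrete, here is a minimal local round trip built only from the tensorizer calls the script itself uses (`EncryptionParams.random()`, `DecryptionParams.from_key()`, `TensorSerializer.write_module()`, `TensorDeserializer.load_into_module()`). The `/tmp` paths, the toy `nn.Linear` module, and the `device="cpu"` argument are illustrative assumptions, not part of the commit:

import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                        TensorSerializer, stream_io)

# Toy module standing in for a vLLM model (assumption for illustration).
model = nn.Linear(4, 4)

# Serialize with encryption, mirroring serialize() above: write the random
# key to one stream and the encrypted tensors to another.
encryption_params = EncryptionParams.random()
with stream_io.open_stream("/tmp/demo.key", mode="wb+") as stream:
    stream.write(encryption_params.key)
with stream_io.open_stream("/tmp/demo.tensors", mode="wb+") as stream:
    serializer = TensorSerializer(stream, encryption=encryption_params)
    serializer.write_module(model)
    serializer.close()

# Deserialize into a fresh module, mirroring deserialize() above: rebuild
# DecryptionParams from the saved key, then load tensors in place.
restored = nn.Linear(4, 4)
with stream_io.open_stream("/tmp/demo.key", mode="rb") as stream:
    decryption_params = DecryptionParams.from_key(stream.read())
with stream_io.open_stream("/tmp/demo.tensors", mode="rb") as stream, \
        TensorDeserializer(stream, device="cpu",
                           encryption=decryption_params) as deserializer:
    deserializer.load_into_module(restored)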

requirements-cpu.txt

Lines changed: 1 addition & 1 deletion
@@ -3,4 +3,4 @@

 # Dependencies for x86_64 CPUs
 torch == 2.2.1+cpu
-triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error.
+triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error.

(The removed and added lines render identically; the change is whitespace-only.)

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ types-setuptools

 # testing
 pytest
+tensorizer==2.9.0a0
 pytest-forked
 pytest-asyncio
 pytest-rerunfailures

setup.py

Lines changed: 3 additions & 0 deletions
@@ -428,6 +428,9 @@ def get_extra_requirements() -> dict:
     install_requires=get_requirements(),
     extras_require=get_extra_requirements(),
     ext_modules=ext_modules,
+    extras_require={
+        "optional": ["tensorizer==2.9.0a1"],
+    },
     cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
     package_data=package_data,
 )

tests/tensorizer/__init__.py

Whitespace-only changes.
