|
| 1 | +import argparse |
| 2 | +import dataclasses |
| 3 | +import os |
| 4 | +import time |
| 5 | +import uuid |
| 6 | +from functools import partial |
| 7 | +from typing import Type |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer, |
| 12 | + TensorSerializer, stream_io) |
| 13 | +from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor |
| 14 | +from transformers import AutoConfig, PretrainedConfig |
| 15 | + |
| 16 | +from vllm.distributed import initialize_model_parallel |
| 17 | +from vllm.engine.arg_utils import EngineArgs |
| 18 | +from vllm.engine.llm_engine import LLMEngine |
| 19 | +from vllm.model_executor.models import ModelRegistry |
| 20 | +from vllm.model_executor.tensorizer_loader import TensorizerArgs |
| 21 | + |
| 22 | +# yapf conflicts with isort for this docstring |
| 23 | +# yapf: disable |
| 24 | +""" |
| 25 | +tensorize_vllm_model.py is a script that can be used to serialize and |
| 26 | +deserialize vLLM models. These models can be loaded using tensorizer directly |
| 27 | +to the GPU extremely quickly. Tensor encryption and decryption is also |
| 28 | +supported, although libsodium must be installed to use it. Install |
| 29 | +vllm with tensorizer support using `pip install vllm[tensorizer]`. |
| 30 | +
|
| 31 | +To serialize a model, you can run something like this: |
| 32 | +
|
| 33 | +python tensorize_vllm_model.py \ |
| 34 | + --model EleutherAI/gpt-j-6B \ |
| 35 | + --dtype float16 \ |
| 36 | + serialize \ |
| 37 | + --serialized-directory s3://my-bucket/ \ |
| 38 | + --suffix vllm |
| 39 | + |
| 40 | +Which downloads the model from HuggingFace, loads it into vLLM, serializes it, |
| 41 | +and saves it to your S3 bucket. A local directory can also be used. |
| 42 | +
|
| 43 | +You can also encrypt the model weights with a randomly-generated key by |
| 44 | +providing a `--keyfile` argument. |
| 45 | +
|
| 46 | +To deserialize a model, you can run something like this: |
| 47 | +
|
| 48 | +python tensorize_vllm_model.py \ |
| 49 | + --model EleutherAI/gpt-j-6B \ |
| 50 | + --dtype float16 \ |
| 51 | + deserialize \ |
| 52 | + --path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors |
| 53 | +
|
| 54 | +Which downloads the model tensors from your S3 bucket and deserializes them. |
| 55 | +To provide S3 credentials, you can provide `--s3-access-key-id` and |
| 56 | +`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script, |
| 57 | +the OpenAI entrypoint, as arguments for LLM(), or as environment variables |
| 58 | +in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`. |
| 59 | +
|
| 60 | +
|
| 61 | +You can also provide a `--keyfile` argument to decrypt the model weights if |
| 62 | +they were serialized with encryption. |
| 63 | +
|
| 64 | +For more information on the available arguments, run |
| 65 | +`python tensorize_vllm_model.py --help`. |
| 66 | +""" |
| 67 | + |
| 68 | + |
| 69 | +def parse_args(): |
| 70 | + parser = argparse.ArgumentParser( |
| 71 | + description="An example script that can be used to serialize and " |
| 72 | + "deserialize vLLM models. These models " |
| 73 | + "can be loaded using tensorizer directly to the GPU " |
| 74 | + "extremely quickly. Tensor encryption and decryption is " |
| 75 | + "also supported, although libsodium must be installed to " |
| 76 | + "use it.") |
| 77 | + parser = EngineArgs.add_cli_args(parser) |
| 78 | + subparsers = parser.add_subparsers(dest='command') |
| 79 | + |
| 80 | + serialize_parser = subparsers.add_parser( |
| 81 | + 'serialize', help="Serialize a model to `--serialized-directory`") |
| 82 | + |
| 83 | + serialize_parser.add_argument( |
| 84 | + "--suffix", |
| 85 | + type=str, |
| 86 | + required=False, |
| 87 | + help=( |
| 88 | + "The suffix to append to the serialized model directory, which is " |
| 89 | + "used to construct the location of the serialized model tensors, " |
| 90 | + "e.g. if `--serialized-directory` is `s3://my-bucket/` and " |
| 91 | + "`--suffix` is `v1`, the serialized model tensors will be " |
| 92 | + "saved to " |
| 93 | + "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. " |
| 94 | + "If none is provided, a random UUID will be used.")) |
| 95 | + serialize_parser.add_argument( |
| 96 | + "--serialized-directory", |
| 97 | + type=str, |
| 98 | + required=True, |
| 99 | + help="The directory to serialize the model to. " |
| 100 | + "This can be a local directory or S3 URI. The path to where the " |
| 101 | + "tensors are saved is a combination of the supplied `dir` and model " |
| 102 | + "reference ID. For instance, if `dir` is the serialized directory, " |
| 103 | + "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will " |
| 104 | + "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, " |
| 105 | + "where `suffix` is given by `--suffix` or a random UUID if not " |
| 106 | + "provided.") |
| 107 | + |
| 108 | + serialize_parser.add_argument( |
| 109 | + "--keyfile", |
| 110 | + type=str, |
| 111 | + required=False, |
| 112 | + help=("Encrypt the model weights with a randomly-generated binary key," |
| 113 | + " and save the key at this path")) |
| 114 | + |
| 115 | + deserialize_parser = subparsers.add_parser( |
| 116 | + 'deserialize', |
| 117 | + help=("Deserialize a model from `--path-to-tensors`" |
| 118 | + " to verify it can be loaded and used.")) |
| 119 | + |
| 120 | + deserialize_parser.add_argument( |
| 121 | + "--path-to-tensors", |
| 122 | + type=str, |
| 123 | + required=True, |
| 124 | + help="The local path or S3 URI to the model tensors to deserialize. ") |
| 125 | + |
| 126 | + deserialize_parser.add_argument( |
| 127 | + "--keyfile", |
| 128 | + type=str, |
| 129 | + required=False, |
| 130 | + help=("Path to a binary key to use to decrypt the model weights," |
| 131 | + " if the model was serialized with encryption")) |
| 132 | + |
| 133 | + return parser.parse_args() |
| 134 | + |
| 135 | + |
| 136 | +def make_model_contiguous(model): |
| 137 | + # Ensure tensors are saved in memory contiguously |
| 138 | + for param in model.parameters(): |
| 139 | + param.data = param.data.contiguous() |
| 140 | + |
| 141 | + |
| 142 | +def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: |
| 143 | + architectures = getattr(config, "architectures", []) |
| 144 | + for arch in architectures: |
| 145 | + model_cls = ModelRegistry.load_model_cls(arch) |
| 146 | + if model_cls is not None: |
| 147 | + return model_cls |
| 148 | + raise ValueError( |
| 149 | + f"Model architectures {architectures} are not supported for now. " |
| 150 | + f"Supported architectures: {ModelRegistry.get_supported_archs()}") |
| 151 | + |
| 152 | + |
| 153 | +def serialize(): |
| 154 | + |
| 155 | + eng_args_dict = {f.name: getattr(args, f.name) for f in |
| 156 | + dataclasses.fields(EngineArgs)} |
| 157 | + engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict)) |
| 158 | + engine = LLMEngine.from_engine_args(engine_args) |
| 159 | + |
| 160 | + model = (engine.model_executor.driver_worker. |
| 161 | + model_runner.model) |
| 162 | + |
| 163 | + encryption_params = EncryptionParams.random() if keyfile else None |
| 164 | + if keyfile: |
| 165 | + with _write_stream(keyfile) as stream: |
| 166 | + stream.write(encryption_params.key) |
| 167 | + |
| 168 | + with _write_stream(model_path) as stream: |
| 169 | + serializer = TensorSerializer(stream, encryption=encryption_params) |
| 170 | + serializer.write_module(model) |
| 171 | + serializer.close() |
| 172 | + |
| 173 | + print("Serialization complete. Model tensors saved to", model_path) |
| 174 | + if keyfile: |
| 175 | + print("Key saved to", keyfile) |
| 176 | + |
| 177 | + |
| 178 | +def deserialize(): |
| 179 | + config = AutoConfig.from_pretrained(model_ref) |
| 180 | + |
| 181 | + with no_init_or_tensor(): |
| 182 | + model_class = _get_vllm_model_architecture(config) |
| 183 | + model = model_class(config) |
| 184 | + |
| 185 | + before_mem = get_mem_usage() |
| 186 | + start = time.time() |
| 187 | + |
| 188 | + if keyfile: |
| 189 | + with _read_stream(keyfile) as stream: |
| 190 | + key = stream.read() |
| 191 | + decryption_params = DecryptionParams.from_key(key) |
| 192 | + tensorizer_args.deserializer_params['encryption'] = \ |
| 193 | + decryption_params |
| 194 | + |
| 195 | + with (_read_stream(model_path)) as stream, TensorDeserializer( |
| 196 | + stream, **tensorizer_args.deserializer_params) as deserializer: |
| 197 | + deserializer.load_into_module(model) |
| 198 | + end = time.time() |
| 199 | + |
| 200 | + # Brag about how fast we are. |
| 201 | + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) |
| 202 | + duration = end - start |
| 203 | + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) |
| 204 | + after_mem = get_mem_usage() |
| 205 | + print( |
| 206 | + f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s" |
| 207 | + ) |
| 208 | + print(f"Memory usage before: {before_mem}") |
| 209 | + print(f"Memory usage after: {after_mem}") |
| 210 | + |
| 211 | + return model |
| 212 | + |
| 213 | + |
| 214 | +args = parse_args() |
| 215 | + |
| 216 | +s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID") |
| 217 | + or None) |
| 218 | +s3_secret_access_key = (args.s3_secret_access_key |
| 219 | + or os.environ.get("S3_SECRET_ACCESS_KEY") or None) |
| 220 | + |
| 221 | +s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None) |
| 222 | + |
| 223 | +_read_stream, _write_stream = (partial( |
| 224 | + stream_io.open_stream, |
| 225 | + mode=mode, |
| 226 | + s3_access_key_id=s3_access_key_id, |
| 227 | + s3_secret_access_key=s3_secret_access_key, |
| 228 | + s3_endpoint=s3_endpoint, |
| 229 | +) for mode in ("rb", "wb+")) |
| 230 | + |
| 231 | +model_ref = args.model |
| 232 | + |
| 233 | +model_name = model_ref.split("/")[1] |
| 234 | + |
| 235 | +os.environ["MASTER_ADDR"] = "127.0.0.1" |
| 236 | +os.environ["MASTER_PORT"] = "8080" |
| 237 | + |
| 238 | +torch.distributed.init_process_group(world_size=1, rank=0) |
| 239 | +initialize_model_parallel() |
| 240 | + |
| 241 | +keyfile = args.keyfile if args.keyfile else None |
| 242 | + |
| 243 | +if args.command == "serialize": |
| 244 | + input_dir = args.serialized_directory.rstrip('/') |
| 245 | + suffix = args.suffix if args.suffix else uuid.uuid4().hex |
| 246 | + base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" |
| 247 | + model_path = f"{base_path}/model.tensors" |
| 248 | + serialize() |
| 249 | +elif args.command == "deserialize": |
| 250 | + tensorizer_args = TensorizerArgs.from_cli_args(args) |
| 251 | + model_path = args.path_to_tensors |
| 252 | + deserialize() |
| 253 | +else: |
| 254 | + raise ValueError("Either serialize or deserialize must be specified.") |
0 commit comments