Commit c39ec2a

Merge pull request vllm-project#2 from ri938/add_awq_improvements
Add awq improvements
2 parents 5bd5ed6 + 2f97151 commit c39ec2a

File tree: 7 files changed, +64 -112 lines changed

vllm/awq_quantization/qmodule.py

Lines changed: 6 additions & 53 deletions
@@ -32,67 +32,20 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
+
         # quick sanity check (make sure aligment)
         assert self.in_features % self.group_size == 0
         assert out_features % (32 // self.w_bit) == 0
 
-        self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
-        self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
-        self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
+        self.register_buffer('qweight', torch.empty((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
+        self.register_buffer('qzeros', torch.empty((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
+        self.register_buffer('scales', torch.empty((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
+
         if bias:
-            self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
+            self.register_buffer('bias', torch.empty((out_features), dtype=torch.float16, device=dev))
         else:
             self.bias = None
 
-    @classmethod
-    def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
-        awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
-        if init_only:  # just prepare for loading sd
-            return awq_linear
-
-        # need scales and zeros info for real quantization
-        assert scales is not None and zeros is not None
-        scale_zeros = zeros * scales
-
-        awq_linear.scales = scales.clone().half()
-        if linear.bias is not None:
-            awq_linear.bias = linear.bias.clone().half()
-
-        pack_num = 32 // awq_linear.w_bit
-
-        intweight = []
-        for idx in range(awq_linear.in_features):
-            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
-        intweight = torch.cat(intweight, dim=1)
-        intweight = intweight.t().contiguous()
-        intweight = intweight.to(dtype=torch.int32)
-        qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
-
-        for col in range(intweight.shape[1] // pack_num):
-            if awq_linear.w_bit == 4:
-                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
-            else:
-                raise NotImplementedError("Only 4-bit are supported for now.")
-            for i in range(pack_num):
-                qweight_col = intweight[:, col * pack_num + order_map[i]]
-                qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
-        awq_linear.qweight = qweight
-
-        zeros = zeros.to(dtype=torch.int32)
-        qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
-
-        for col in range(zeros.shape[1] // pack_num):
-            if awq_linear.w_bit == 4:
-                order_map = [0, 2, 4, 6, 1, 3, 5, 7]
-            else:
-                raise NotImplementedError("Only 4-bit are supported for now.")
-            for i in range(pack_num):
-                qzero_col = zeros[:, col * pack_num + order_map[i]]
-                qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
-        awq_linear.qzeros = qzeros
-
-        return awq_linear
-
     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features, )
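The substantive changes above are the switch from torch.zeros to torch.empty for the qweight/qzeros/scales/bias buffers (presumably because they are always overwritten when the quantized checkpoint is loaded, so zero-filling them is wasted work) and the removal of the offline from_linear packing helper from this module. For reference, here is a minimal sketch of the 4-bit packing that helper performed, following the layout the removed code used; the function name is illustrative and not part of this commit:

    import torch

    def pack_int4(intweight: torch.Tensor, w_bit: int = 4) -> torch.Tensor:
        # Eight 4-bit values are packed into each int32 word; columns are
        # interleaved via [0, 2, 4, 6, 1, 3, 5, 7] to match the AWQ kernel layout.
        assert w_bit == 4, "only 4-bit packing is sketched here"
        pack_num = 32 // w_bit
        order_map = [0, 2, 4, 6, 1, 3, 5, 7]
        packed = torch.zeros((intweight.shape[0], intweight.shape[1] // pack_num),
                             dtype=torch.int32, device=intweight.device)
        for col in range(packed.shape[1]):
            for i in range(pack_num):
                val = intweight[:, col * pack_num + order_map[i]]
                packed[:, col] |= val << (i * w_bit)
        return packed

    # One row of eight int4 values collapses into a single int32 word.
    row = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)
    print(pack_int4(row))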

vllm/config.py

Lines changed: 41 additions & 36 deletions
@@ -12,6 +12,34 @@
 _GB = 1 << 30
 
 
+class QuantizationConfig:
+    """Quantization settings
+
+    Args:
+        method: The quantization method to apply
+        bits: How many bits the linear layers are quantized to
+        group_size: What size the weights were quantized in groups of
+    """
+
+    def __init__(
+        self,
+        method: str,
+        bits: Optional[int] = 4,
+        group_size: Optional[int] = 128
+    ) -> None:
+        self.method = method
+        self.bits = bits
+        self.group_size = group_size
+        self._verify()
+
+    def _verify(self) -> None:
+        allowed_methods = ['awq']
+        if self.method not in allowed_methods:
+            raise ValueError(
+                f"Unknown quantization method ({self.method})"
+                f" must be from choice of {allowed_methods}")
+
+
 class ModelConfig:
     """Configuration for the model.
 
@@ -31,6 +59,7 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        quantization_config: Optional quantization settings
     """
 
     def __init__(
@@ -44,6 +73,7 @@ def __init__(
         use_dummy_weights: bool,
         dtype: str,
         seed: int,
+        quantization_config: Optional[QuantizationConfig] = None
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -53,6 +83,7 @@ def __init__(
         self.use_np_weights = use_np_weights
         self.use_dummy_weights = use_dummy_weights
         self.seed = seed
+        self.quantization_config = quantization_config
 
         self.hf_config = get_config(model, trust_remote_code)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
@@ -86,6 +117,9 @@ def verify_with_parallel_config(
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
+        if self.quantization_config and tensor_parallel_size > 1:
+            raise NotImplementedError("Quantization does not currently support tensor parallelism")
+
     def get_hidden_size(self) -> int:
         return self.hf_config.hidden_size
 
@@ -140,6 +174,13 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
 
+    def get_quantization_method(self):
+        if self.quantization_config is None:
+            method = None
+        else:
+            method = self.quantization_config.method
+        return method
+
 
 class CacheConfig:
     """Configuration for the KV cache.
@@ -295,39 +336,3 @@ def _get_and_verify_dtype(
             f"of at least 8.0. Your {gpu_name} GPU has compute capability "
             f"{compute_capability[0]}.{compute_capability[1]}.")
     return torch_dtype
-
-
-class QuantizationConfig:
-    """Quantization settings
-
-    Args:
-        method: The quantization method to apply
-        bits: How many bits the linear layers are quantized to
-        group_size: What size the weights were quantized in groups of
-    """
-
-    def __init__(
-        self,
-        method: str,
-        bits: Optional[int] = 4,
-        group_size: Optional[int] = 128
-    ) -> None:
-        self.method = method
-        self.bits = bits
-        self.group_size = group_size
-
-        self._verify()
-
-    def _verify(self) -> None:
-        allowed_methods = ['awq']
-        if self.method not in allowed_methods:
-            raise ValueError(
-                f"Unknown quantization method ({self.method})"
-                f" must be from choice of {allowed_methods}")
-
-    def verify_with_parallel_config(self, parallel_config: "ParallelConfig") -> None:
-        tensor_parallel_size = parallel_config.tensor_parallel_size
-
-        if self.method is not None and tensor_parallel_size > 1:
-            raise NotImplementedError(
-                "Quantization does not currently support tensor parallelism")

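QuantizationConfig moves to the top of the file and drops its own verify_with_parallel_config method: the tensor-parallel check now lives on ModelConfig.verify_with_parallel_config, and get_quantization_method exposes the method string to callers. A minimal usage sketch, assuming the import path shown in the diff:

    from vllm.config import QuantizationConfig

    cfg = QuantizationConfig(method="awq")       # defaults: bits=4, group_size=128
    print(cfg.method, cfg.bits, cfg.group_size)  # awq 4 128

    try:
        QuantizationConfig(method="gptq")        # anything outside ['awq'] raises
    except ValueError as err:
        print(err)
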
vllm/engine/arg_utils.py

Lines changed: 3 additions & 3 deletions
@@ -152,11 +152,12 @@ def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
         # Initialize the configs.
+        quantization_config = QuantizationConfig(self.quantization) if self.quantization else None
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.use_np_weights,
                                    self.use_dummy_weights, self.dtype,
-                                   self.seed)
+                                   self.seed, quantization_config)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
@@ -166,8 +167,7 @@ def create_engine_configs(
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
                                            model_config.get_max_model_len())
-        quantization_config = QuantizationConfig(self.quantization) if self.quantization else None
-        return model_config, cache_config, parallel_config, scheduler_config, quantization_config
+        return model_config, cache_config, parallel_config, scheduler_config
 
 
 @dataclass
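A rough usage sketch of the resulting flow: the quantization settings now travel inside ModelConfig instead of being returned as a fifth value from create_engine_configs. The model path below is a placeholder, and it is assumed here that EngineArgs exposes the quantization attribute this method reads:

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(model="path/to/awq-model", quantization="awq")
    model_config, cache_config, parallel_config, scheduler_config = (
        args.create_engine_configs())
    print(model_config.get_quantization_method())  # "awq"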

vllm/engine/llm_engine.py

Lines changed: 2 additions & 10 deletions
@@ -4,7 +4,7 @@
 from typing import Any, List, Optional, Tuple, TYPE_CHECKING
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, QuantizationConfig)
+                         SchedulerConfig)
 from vllm.core.scheduler import Scheduler
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
@@ -55,7 +55,6 @@ class LLMEngine:
         stage_devices: The list of devices for each stage. Each stage is a list
             of (rank, node_resource, device) tuples.
         log_stats: Whether to log statistics.
-        quantization_config: Optional settings related to using quantized layers
     """
 
     def __init__(
@@ -64,7 +63,6 @@ def __init__(
         cache_config: CacheConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        quantization_config: Optional[QuantizationConfig],
        distributed_init_method: str,
        placement_group: Optional["PlacementGroup"],
        log_stats: bool,
@@ -80,15 +78,14 @@ def __init__(
             f"download_dir={model_config.download_dir!r}, "
             f"use_np_weights={model_config.use_np_weights}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
-            f"quantization_method={getattr(quantization_config, 'method', None)}, "
+            f"quantization_method={model_config.get_quantization_method()}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
 
         self.model_config = model_config
         self.cache_config = cache_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
-        self.quantization_config = quantization_config
         self.log_stats = log_stats
         self._verify_args()
 
@@ -132,7 +129,6 @@ def _init_workers(self, distributed_init_method: str):
             self.scheduler_config,
             0,
             distributed_init_method,
-            quantization_config=self.quantization_config
         )
         self.workers.append(worker)
         self._run_workers(
@@ -171,7 +167,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup"):
                 scheduler_config,
                 None,
                 None,
-                self.quantization_config
             ))
         self._run_workers(
             "init_model",
@@ -182,9 +177,6 @@ def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
 
-        if self.quantization_config is not None:
-            self.quantization_config.verify_with_parallel_config(self.parallel_config)
-
     def _init_cache(self) -> None:
         """Profiles the memory usage and initializes the KV cache."""
         # Get the maximum number of blocks that can be allocated on GPU and CPU.

vllm/model_executor/model_loader.py

Lines changed: 9 additions & 5 deletions
@@ -5,7 +5,7 @@
 import torch.nn as nn
 from transformers import PretrainedConfig
 
-from vllm.config import ModelConfig, QuantizationConfig
+from vllm.config import ModelConfig
 from vllm.model_executor.models import * # pylint: disable=wildcard-import
 from vllm.model_executor.weight_utils import initialize_dummy_weights
 
@@ -39,16 +39,19 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
         f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
 
 
-def get_model(model_config: ModelConfig, quantization_config: QuantizationConfig) -> nn.Module:
+def _supports_quantization(model_class):
+    return model_class is LlamaForCausalLM
+
+
+def get_model(model_config: ModelConfig) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
     torch.set_default_dtype(model_config.dtype)
 
     # Create a model instance.
     # The weights will be initialized as empty tensors.
 
-    # TODO: better way to do this
-    if model_class is LlamaForCausalLM:
-        model = model_class(model_config.hf_config, quantization_config)
+    if _supports_quantization(model_class):
+        model = model_class(model_config.hf_config, model_config.quantization_config)
     else:
         model = model_class(model_config.hf_config)
 
@@ -62,4 +65,5 @@ def get_model(model_config: ModelConfig, quantization_config: QuantizationConfig
     model.load_weights(model_config.model, model_config.download_dir,
                        model_config.use_np_weights)
     model = model.cuda()
+
     return model.eval()
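The new _supports_quantization helper replaces the inline TODO and keeps the Llama-only special case in one place; get_model now pulls the quantization settings straight from model_config. If more architectures gained quantized variants later, the helper could become a set lookup, as in this hypothetical sketch (only LlamaForCausalLM is actually supported by this commit):

    # LlamaForCausalLM comes from the wildcard import of vllm.model_executor.models.
    _QUANTIZABLE_MODEL_CLASSES = {LlamaForCausalLM}

    def _supports_quantization(model_class) -> bool:
        return model_class in _QUANTIZABLE_MODEL_CLASSES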

vllm/model_executor/models/llama.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def get_quantized_layer(in_features, out_features, quant_config):
         in_features=in_features,
         out_features=out_features,
         bias=None,
-        dev=0 ## TODO: fix this without large spike in memory
+        dev=torch.cuda.current_device()
     )
     return layer
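Replacing the hard-coded dev=0 with torch.cuda.current_device() places each quantized layer's buffers on whichever GPU the calling worker has made current, rather than always allocating on GPU 0, which was the memory-spike concern in the removed TODO. A minimal illustration (the CPU fallback is an addition for machines without CUDA, not part of the diff):

    import torch

    dev = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
    scales = torch.empty((64, 4096), dtype=torch.float16, device=dev)
    print(scales.device)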

vllm/worker/worker.py

Lines changed: 2 additions & 4 deletions
@@ -6,7 +6,7 @@
 import torch.distributed
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, QuantizationConfig)
+                         SchedulerConfig)
 from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
@@ -31,12 +31,10 @@ def __init__(
         scheduler_config: SchedulerConfig,
         rank: Optional[int] = None,
         distributed_init_method: Optional[str] = None,
-        quantization_config: Optional[QuantizationConfig] = None
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
-        self.quantization_config = quantization_config
         self.rank = rank
         self.distributed_init_method = distributed_init_method
 
@@ -66,7 +64,7 @@ def init_model(self):
 
         # Initialize the model.
         set_random_seed(self.model_config.seed)
-        self.model = get_model(self.model_config, self.quantization_config)
+        self.model = get_model(self.model_config)
 
     @torch.inference_mode()
     def profile_num_available_blocks(
