From fee1dc42d379441dd1b4cf60d60591c4a78e249e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 10 Apr 2024 10:30:09 -0700 Subject: [PATCH 1/9] Turn `linear_weights` attribute into a dynamic property to avoid double references --- vllm/model_executor/layers/linear.py | 39 +++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3d4d1789db2..3cc1046c4434 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -112,12 +112,14 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( + linear_weights = self.linear_method.create_weights( self.input_size, self.output_size, self.input_size, self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): + self._linear_weights_names = [] + for name, weight in linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) + self._linear_weights_names.append(name) if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -125,6 +127,13 @@ def __init__( else: self.register_parameter("bias", None) + @property + def linear_weights(self) -> Dict[str, torch.Tensor]: + return { + name: getattr(self, name) + for name in self._linear_weights_names + } + def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None output = self.linear_method.apply_weights(self.linear_weights, x, bias) @@ -178,13 +187,15 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( + linear_weights = self.linear_method.create_weights( self.input_size, self.output_size_per_partition, self.input_size, self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): + self._linear_weights_names = [] + for name, weight in linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + self._linear_weights_names.append(name) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -196,6 +207,13 @@ def __init__( else: self.register_parameter("bias", None) + @property + def linear_weights(self) -> Dict[str, torch.Tensor]: + return { + name: getattr(self, name) + for name in self._linear_weights_names + } + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) @@ -524,13 +542,15 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_weights = self.linear_method.create_weights( + linear_weights = self.linear_method.create_weights( self.input_size_per_partition, self.output_size, self.input_size, self.output_size, self.params_dtype) - for name, weight in self.linear_weights.items(): + self._linear_weights_names = [] + for name, weight in linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + self._linear_weights_names.append(name) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " @@ -546,6 +566,13 @@ def __init__( else: self.register_parameter("bias", None) + @property + def linear_weights(self) -> Dict[str, torch.Tensor]: + return { + name: getattr(self, name) + for name in self._linear_weights_names + } + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) From 3b46ae5e73062aefd02c17e033aaef8e9a035140 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 10 Apr 2024 11:17:39 -0700 Subject: [PATCH 2/9] Fix --- vllm/model_executor/layers/linear.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3cc1046c4434..4fd06ab22ced 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -119,7 +119,7 @@ def __init__( for name, weight in linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) - self._linear_weights_names.append(name) + self._linear_weights_names.append(name) if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -195,7 +195,7 @@ def __init__( if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) set_weight_attrs(weight, {"weight_loader": self.weight_loader}) - self._linear_weights_names.append(name) + self._linear_weights_names.append(name) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -550,7 +550,7 @@ def __init__( if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) set_weight_attrs(weight, {"weight_loader": self.weight_loader}) - self._linear_weights_names.append(name) + self._linear_weights_names.append(name) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " From dcadfbd2f86e9840d258c00007ad35dcb4ed44e9 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 10 Apr 2024 11:19:17 -0700 Subject: [PATCH 3/9] Fix --- vllm/model_executor/layers/linear.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4fd06ab22ced..51534837a6bb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -119,6 +119,8 @@ def __init__( for name, weight in linear_weights.items(): if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) + else: + setattr(self, name, weight) self._linear_weights_names.append(name) if bias: self.bias = Parameter( @@ -195,6 +197,8 @@ def __init__( if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + else: + setattr(self, name, weight) self._linear_weights_names.append(name) if bias: self.bias = Parameter( @@ -550,6 +554,8 @@ def __init__( if isinstance(weight, torch.Tensor): self.register_parameter(name, weight) set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + else: + setattr(self, name, weight) self._linear_weights_names.append(name) if not reduce_results and (bias and not skip_bias_add): From 09ef14c266121ff8624ee291236d0f7052dde85a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 10 Apr 2024 14:21:05 -0700 Subject: [PATCH 4/9] Redo --- tests/kernels/test_moe.py | 2 +- vllm/lora/layers.py | 12 +-- vllm/model_executor/layers/linear.py | 98 ++++++------------- .../model_executor/layers/quantization/awq.py | 29 +++--- .../layers/quantization/gptq.py | 47 +++++---- .../layers/quantization/marlin.py | 23 +++-- .../layers/quantization/squeezellm.py | 24 ++--- 7 files changed, 103 insertions(+), 132 deletions(-) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index affbbfb4aa94..046f11d957bd 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype): ).cuda() # Load the weights - vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data + vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 050501475395..99500bc6f36f 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -368,7 +368,7 @@ def set_mapping( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora( x, self.lora_a_stacked, @@ -402,10 +402,6 @@ def forward(self, input_): if self.base_layer.skip_bias_add else None) return output, output_bias - @property - def linear_weights(self): - return self.base_layer.linear_weights - @classmethod def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, @@ -505,7 +501,7 @@ def set_lora( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -746,7 +742,7 @@ def set_lora( def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x, bias) + self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -838,7 +834,7 @@ def set_mapping( def apply_weights(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( - self.base_layer.linear_weights, x) + self.base_layer, x) _apply_lora( x, self.lora_a_stacked, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 51534837a6bb..a95c3bf4e324 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn.functional as F @@ -29,16 +29,17 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): """Create weights for a linear layer.""" raise NotImplementedError @abstractmethod def apply_weights(self, - weights: Dict[str, torch.Tensor], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights to the input tensor.""" @@ -56,22 +57,24 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - return {"weight": weight} + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, torch.Tensor], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - weight = weights["weight"] + weight = layer.weight if self.separate_bias_add: if bias is not None: return F.linear(x, weight) + bias @@ -112,16 +115,9 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.input_size, - self.output_size, self.params_dtype) - self._linear_weights_names = [] - for name, weight in linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - else: - setattr(self, name, weight) - self._linear_weights_names.append(name) + self.linear_method.create_weights(self, self.input_size, + self.output_size, self.input_size, + self.output_size, self.params_dtype) if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -129,16 +125,9 @@ def __init__( else: self.register_parameter("bias", None) - @property - def linear_weights(self) -> Dict[str, torch.Tensor]: - return { - name: getattr(self, name) - for name in self._linear_weights_names - } - def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None - output = self.linear_method.apply_weights(self.linear_weights, x, bias) + output = self.linear_method.apply_weights(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -189,17 +178,13 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.input_size, - self.output_size, self.params_dtype) - self._linear_weights_names = [] - for name, weight in linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) - else: - setattr(self, name, weight) - self._linear_weights_names.append(name) + self.linear_method.create_weights(self, + self.input_size, + self.output_size_per_partition, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -211,13 +196,6 @@ def __init__( else: self.register_parameter("bias", None) - @property - def linear_weights(self) -> Dict[str, torch.Tensor]: - return { - name: getattr(self, name) - for name in self._linear_weights_names - } - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) @@ -234,8 +212,7 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_, bias) + output_parallel = self.linear_method.apply_weights(self, input_, bias) if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -546,17 +523,9 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.input_size, - self.output_size, self.params_dtype) - self._linear_weights_names = [] - for name, weight in linear_weights.items(): - if isinstance(weight, torch.Tensor): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) - else: - setattr(self, name, weight) - self._linear_weights_names.append(name) + self.linear_method.create_weights(self, self.input_size_per_partition, + self.output_size, self.input_size, + self.output_size, self.params_dtype) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " @@ -572,13 +541,6 @@ def __init__( else: self.register_parameter("bias", None) - @property - def linear_weights(self) -> Dict[str, torch.Tensor]: - return { - name: getattr(self, name) - for name in self._linear_weights_names - } - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) @@ -603,7 +565,7 @@ def forward(self, input_): # Matrix multiply. output_parallel = self.linear_method.apply_weights( - self.linear_weights, input_parallel) + self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 2caef5f1ebf5..9648c94da9c0 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -136,19 +137,21 @@ def create_weights(self, input_size_per_partition: int, "input_dim": 0, "output_dim": 1, }) - return { - "qweight": qweight, - "qzeros": qzeros, - "scales": scales, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - scales = weights["scales"] - qzeros = weights["qzeros"] + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) @@ -163,5 +166,5 @@ def apply_weights(self, out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 53baf710ed81..4b0ebb19a576 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( @@ -179,37 +181,40 @@ def create_weights( "input_dim": scale_and_zero_input_dim, "output_dim": 1, }) - return { - "qweight": qweight, - "g_idx": g_idx, - "qzeros": qzeros, - "scales": scales, - "exllama_state": exllama_state, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("g_idx", g_idx) + set_weight_attrs(g_idx, extra_weight_attrs) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + layer.exllama_state = exllama_state def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] + qweight = layer.qweight out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if weights["exllama_state"] == ExllamaState.UNINITIALIZED: + if layer.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: - weights["g_idx"] = torch.argsort(weights["g_idx"]).to( - torch.int) + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - weights["g_idx"] = torch.empty((1, 1), device="meta") - weights["exllama_state"] = ExllamaState.READY - ops.gptq_shuffle(weights["qweight"], weights["g_idx"], + layer.g_idx.data = torch.empty((0, ), + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) - output = ops.gptq_gemm(reshaped_x, weights["qweight"], - weights["qzeros"], weights["scales"], - weights["g_idx"], - weights["exllama_state"] == ExllamaState.READY, + output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, + layer.scales, layer.g_idx, + layer.exllama_state == ExllamaState.READY, self.quant_config.weight_bits) if bias is not None: - output = output + bias + output.add_(bias) return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 784229878edf..ed59c3341531 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig): def create_weights( self, + layer: torch.nn.Module, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, - ) -> Dict[str, Any]: + **extra_weight_attrs, + ): del output_size # Unused. if params_dtype != torch.float16: @@ -187,21 +189,22 @@ def create_weights( dtype=torch.int), requires_grad=False) - return { - "B": qweight, - "s": scales, - "workspace": workspace, - } + layer.register_parameter("B", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("s", scales) + set_weight_attrs(scales, extra_weight_attrs) + layer.register_parameter("workspace", workspace) + set_weight_attrs(workspace, extra_weight_attrs) def apply_weights( self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = weights["B"] - scales = weights["s"] - workspace = weights["workspace"] + qweight = layer.B + scales = layer.s + workspace = layer.workspace x_2d = x.view(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index ed25455e6ec1..444612bbe0aa 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -103,17 +104,18 @@ def create_weights(self, input_size_per_partition: int, set_weight_attrs(lookup_table, { "output_dim": 0, }) - return { - "qweight": qweight, - "lookup_table": lookup_table, - } + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("lookup_table", lookup_table) + set_weight_attrs(lookup_table, extra_weight_attrs) def apply_weights(self, - weights: Dict[str, Any], + layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - lookup_table = weights["lookup_table"] + qweight = layer.qweight + lookup_table = layer.lookup_table out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if is_hip(): @@ -126,5 +128,5 @@ def apply_weights(self, ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: - out = out + bias + out.add_(bias) return out.reshape(out_shape) From 888ba9c43eca9e1ada1f85e1f20619fd982ce97e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 10 Apr 2024 14:23:18 -0700 Subject: [PATCH 5/9] Fix --- vllm/model_executor/layers/linear.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index a95c3bf4e324..fd1ef71058e1 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -523,9 +523,13 @@ def __init__( if linear_method is None: linear_method = UnquantizedLinearMethod() self.linear_method = linear_method - self.linear_method.create_weights(self, self.input_size_per_partition, - self.output_size, self.input_size, - self.output_size, self.params_dtype) + self.linear_method.create_weights(self, + self.input_size_per_partition, + self.output_size, + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " From 46c2f292224d5a2fcc08cb1c87f40bfe36a486da Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 10 Apr 2024 14:25:49 -0700 Subject: [PATCH 6/9] Docstring --- vllm/model_executor/layers/linear.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index fd1ef71058e1..4105a87c1eb7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -34,7 +34,9 @@ def create_weights(self, layer: torch.nn.Module, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - """Create weights for a linear layer.""" + """Create weights for a linear layer. + + The weights will be set as attributes of the layer.""" raise NotImplementedError @abstractmethod @@ -42,7 +44,9 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - """Apply the weights to the input tensor.""" + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" raise NotImplementedError From a740d2b3525ed9257ec52473a32e723a51a37513 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 10 Apr 2024 14:32:27 -0700 Subject: [PATCH 7/9] Meta fix --- csrc/quantization/gptq/q_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 655158e38f55..cc56649917a8 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -2067,7 +2067,7 @@ void gptq_shuffle const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); vllm::gptq::shuffle_exllama_weight( (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), + q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(), q_weight.size(0) * 32 / bit, q_weight.size(1), bit From e20cdc1671830e0bdd295cc24de12868867e1c1e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 11 Apr 2024 11:04:56 -0700 Subject: [PATCH 8/9] Review comment --- vllm/model_executor/layers/quantization/gptq.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 4b0ebb19a576..37163606f9e4 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -80,12 +80,15 @@ class ExllamaState(Enum): class GPTQLinearMethod(LinearMethodBase): """Linear method for GPTQ. + Note this linear method holds its own state. + Args: quant_config: The GPTQ quantization config. """ def __init__(self, quant_config: GPTQConfig): self.quant_config = quant_config + self.exllama_state = ExllamaState.UNINITIALIZED def create_weights( self, @@ -191,7 +194,7 @@ def create_weights( layer.register_parameter("scales", scales) set_weight_attrs(scales, extra_weight_attrs) - layer.exllama_state = exllama_state + self.exllama_state = exllama_state def apply_weights(self, layer: torch.nn.Module, @@ -202,18 +205,18 @@ def apply_weights(self, reshaped_x = x.reshape(-1, x.shape[-1]) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if layer.exllama_state == ExllamaState.UNINITIALIZED: + if self.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: layer.g_idx.data = torch.empty((0, ), device=layer.g_idx.device) - layer.exllama_state = ExllamaState.READY + self.exllama_state = ExllamaState.READY ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, layer.scales, layer.g_idx, - layer.exllama_state == ExllamaState.READY, + self.exllama_state == ExllamaState.READY, self.quant_config.weight_bits) if bias is not None: output.add_(bias) From 59b4fb67735e1f78aaa77ef322222ed05904bdbb Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 11 Apr 2024 12:09:46 -0700 Subject: [PATCH 9/9] Revert "Review comment" This reverts commit e20cdc1671830e0bdd295cc24de12868867e1c1e. --- vllm/model_executor/layers/quantization/gptq.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 37163606f9e4..4b0ebb19a576 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -80,15 +80,12 @@ class ExllamaState(Enum): class GPTQLinearMethod(LinearMethodBase): """Linear method for GPTQ. - Note this linear method holds its own state. - Args: quant_config: The GPTQ quantization config. """ def __init__(self, quant_config: GPTQConfig): self.quant_config = quant_config - self.exllama_state = ExllamaState.UNINITIALIZED def create_weights( self, @@ -194,7 +191,7 @@ def create_weights( layer.register_parameter("scales", scales) set_weight_attrs(scales, extra_weight_attrs) - self.exllama_state = exllama_state + layer.exllama_state = exllama_state def apply_weights(self, layer: torch.nn.Module, @@ -205,18 +202,18 @@ def apply_weights(self, reshaped_x = x.reshape(-1, x.shape[-1]) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if self.exllama_state == ExllamaState.UNINITIALIZED: + if layer.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: layer.g_idx.data = torch.empty((0, ), device=layer.g_idx.device) - self.exllama_state = ExllamaState.READY + layer.exllama_state = ExllamaState.READY ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, layer.scales, layer.g_idx, - self.exllama_state == ExllamaState.READY, + layer.exllama_state == ExllamaState.READY, self.quant_config.weight_bits) if bias is not None: output.add_(bias)