8 | 8 | from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase |
9 | 9 | from vllm.model_executor.layers.quantization.base_config import ( |
10 | 10 | QuantizationConfig) |
11 | | -from vllm.model_executor.utils import set_weight_attrs |
| 11 | +from vllm.model_executor.parameter import (BasevLLMParameter, |
| 12 | + ChannelQuantScaleParameter, |
| 13 | + GroupQuantScaleParameter, |
| 14 | + PackedvLLMParameter) |
12 | 15 |
13 | 16 | logger = init_logger(__name__) |
14 | 17 |
@@ -133,6 +136,7 @@ def create_weights( |
133 | 136 | params_dtype: torch.dtype, |
134 | 137 | **extra_weight_attrs, |
135 | 138 | ): |
| 139 | + weight_loader = extra_weight_attrs["weight_loader"] |
136 | 140 | if params_dtype != torch.float16: |
137 | 141 | raise ValueError( |
138 | 142 | f"The params dtype must be float16, but got {params_dtype}") |
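Note: the hunk above reads the weight_loader callback out of **extra_weight_attrs up front so it can be handed directly to the new vLLM parameter classes, instead of being attached afterwards with set_weight_attrs. A minimal sketch of that pattern, with a hypothetical helper name and illustrative shapes not taken from this PR:

import torch
from vllm.model_executor.parameter import ChannelQuantScaleParameter

def create_channel_scale(layer: torch.nn.Module, output_size: int,
                         **extra_weight_attrs) -> None:
    # Illustrative helper (not part of the PR): the loader arrives through
    # **extra_weight_attrs and is forwarded straight to the parameter
    # constructor, so no separate set_weight_attrs call is needed.
    weight_loader = extra_weight_attrs["weight_loader"]
    s_channel = ChannelQuantScaleParameter(
        data=torch.empty(1, output_size, dtype=torch.float),
        output_dim=1,
        weight_loader=weight_loader,
    )
    layer.register_parameter("s_channel", s_channel)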
@@ -170,90 +174,74 @@ def create_weights( |
170 | 174 | "Each permutation group must reside on the same gpu") |
171 | 175 |
172 | 176 | # Quantized 4Bit weights packed into Int32. |
173 | | - qweight = Parameter( |
174 | | - torch.empty( |
| 177 | + qweight = PackedvLLMParameter( |
| 178 | + data=torch.empty( |
175 | 179 | input_size_per_partition // self.quant_config.tile_size, |
176 | 180 | output_size_per_partition * self.quant_config.tile_size // |
177 | 181 | self.quant_config.pack_factor, |
178 | 182 | device="cuda", |
179 | 183 | dtype=torch.int32, |
180 | 184 | ), |
181 | | - requires_grad=False, |
182 | | - ) |
183 | | - set_weight_attrs( |
184 | | - qweight, |
185 | | - { |
186 | | - "input_dim": 0, |
187 | | - "output_dim": 1, |
188 | | - "packed_dim": 1, |
189 | | - "pack_factor": self.quant_config.pack_factor, |
190 | | - "marlin_tile_size": self.quant_config.tile_size, |
191 | | - }, |
192 | | - ) |
193 | | - |
194 | | - s_channel = Parameter( |
195 | | - torch.empty( |
196 | | - 1, |
197 | | - output_size_per_partition, |
198 | | - device="cuda", |
199 | | - dtype=torch.float, |
200 | | - ), |
201 | | - requires_grad=False, |
202 | | - ) |
203 | | - set_weight_attrs( |
204 | | - s_channel, |
205 | | - { |
206 | | - "input_dim": None, |
207 | | - "output_dim": 1, |
208 | | - }, |
209 | | - ) |
| 185 | + input_dim=0, |
| 186 | + output_dim=1, |
| 187 | + packed_dim=1, |
| 188 | + packed_factor=self.quant_config.pack_factor, |
| 189 | + marlin_tile_size=self.quant_config.tile_size, |
| 190 | + weight_loader=weight_loader) |
| 191 | + |
| 192 | + s_channel = ChannelQuantScaleParameter(data=torch.empty( |
| 193 | + 1, |
| 194 | + output_size_per_partition, |
| 195 | + device="cuda", |
| 196 | + dtype=torch.float, |
| 197 | + ), |
| 198 | + weight_loader=weight_loader, |
| 199 | + output_dim=1) |
210 | 200 |
211 | 201 | if self.quant_config.group_size == -1: |
212 | | - s_group = Parameter( |
213 | | - torch.tensor( |
214 | | - [], |
215 | | - device="cuda", |
216 | | - dtype=torch.half, |
217 | | - ), |
218 | | - requires_grad=False, |
| 202 | + s_group_data = torch.tensor( |
| 203 | + [], |
| 204 | + device="cuda", |
| 205 | + dtype=torch.half, |
219 | 206 | ) |
220 | 207 | else: |
221 | | - s_group = Parameter( |
222 | | - torch.empty( |
223 | | - input_size_per_partition // self.quant_config.group_size, |
224 | | - output_size_per_partition, |
225 | | - device="cuda", |
226 | | - dtype=torch.half, |
227 | | - ), |
228 | | - requires_grad=False, |
| 208 | + s_group_data = torch.empty( |
| 209 | + input_size_per_partition // self.quant_config.group_size, |
| 210 | + output_size_per_partition, |
| 211 | + device="cuda", |
| 212 | + dtype=torch.half, |
229 | 213 | ) |
230 | 214 |
231 | | - set_weight_attrs( |
232 | | - s_group, |
233 | | - { |
234 | | - "input_dim": None if self.quant_config.group_size == -1 else 0, |
235 | | - "output_dim": |
236 | | - None if self.quant_config.group_size == -1 else 1, |
237 | | - }, |
238 | | - ) |
| 215 | + s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} |
| 216 | + |
| 217 | + if self.quant_config.group_size == -1: |
| 218 | + s_group = BasevLLMParameter(**s_group_attr) |
| 219 | + else: |
| 220 | + s_group = GroupQuantScaleParameter(output_dim=1, |
| 221 | + input_dim=0, |
| 222 | + **s_group_attr) |
239 | 223 |
240 | 224 | # Allocate workspace (Used for internal locking mechanism) |
241 | 225 | max_workspace_size = ( |
242 | 226 | output_size_per_partition // |
243 | 227 | self.quant_config.min_n_threads) * self.quant_config.max_parallel |
244 | | - workspace = Parameter(torch.zeros(max_workspace_size, |
245 | | - device="cuda", |
246 | | - dtype=torch.int), |
247 | | - requires_grad=False) |
| 228 | + |
| 229 | + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, |
| 230 | + device="cuda", |
| 231 | + dtype=torch.int), |
| 232 | + weight_loader=weight_loader) |
248 | 233 |
249 | 234 | layer.register_parameter("B", qweight) |
250 | | - set_weight_attrs(qweight, extra_weight_attrs) |
251 | 235 | layer.register_parameter("s_channel", s_channel) |
252 | | - set_weight_attrs(s_channel, extra_weight_attrs) |
253 | 236 | layer.register_parameter("s_group", s_group) |
254 | | - set_weight_attrs(s_group, extra_weight_attrs) |
255 | 237 | layer.register_parameter("workspace", workspace) |
256 | | - set_weight_attrs(workspace, extra_weight_attrs) |
| 238 | + |
| 239 | + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: |
| 240 | + # required by torch.compile |
| 241 | + layer.B = Parameter(layer.B.data, requires_grad=False) |
| 242 | + layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) |
| 243 | + layer.s_group = Parameter(layer.s_group.data, requires_grad=False) |
| 244 | + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) |
257 | 245 |
258 | 246 | def apply( |
259 | 247 | self, |
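Note: the process_weights_after_loading hook added above re-registers each custom vLLM parameter as a plain torch.nn.Parameter once loading is done, because torch.compile expects standard parameter types on the module. A generic sketch of that pattern, with a hypothetical helper name; the attribute list in the usage comment mirrors the names registered in this diff:

import torch
from torch.nn import Parameter

def strip_custom_parameters(layer: torch.nn.Module, names: list[str]) -> None:
    # Replace each listed attribute with a plain, frozen nn.Parameter so that
    # torch.compile sees standard parameter types after weight loading.
    for name in names:
        old = getattr(layer, name)
        setattr(layer, name, Parameter(old.data, requires_grad=False))

# Usage mirroring the hook in this diff:
# strip_custom_parameters(layer, ["B", "s_channel", "s_group", "workspace"])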