@@ -8,7 +8,10 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
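
The migration replaces untyped set_weight_attrs metadata with typed parameter classes from vllm.model_executor.parameter, which bundle the sharding metadata and the weight_loader callback into the parameter object itself. As a rough mental model, here is a minimal sketch of the pattern (a hypothetical stand-in, not vLLM's actual implementation):

import torch
from torch.nn import Parameter


class SketchvLLMParameter(Parameter):
    """Hypothetical stand-in for BasevLLMParameter (illustration only)."""

    def __new__(cls, data: torch.Tensor, **kwargs):
        # The tensor itself is a frozen nn.Parameter subclass.
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader):
        # Loader callback invoked when a checkpoint shard is copied in.
        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader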
@@ -149,7 +152,7 @@ def create_weights(
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
-
+        weight_loader = extra_weight_attrs["weight_loader"]
         if params_dtype != torch.float16:
             raise ValueError(
                 f"The params dtype must be float16, but got {params_dtype}")
@@ -187,87 +190,80 @@ def create_weights(
                 "Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size // 2,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)
 
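For intuition on the packed shape: eight 4-bit values fit in each int32 word (pack_factor = 8), the weights are laid out in Marlin tiles, and the leading // 2 is the 2:4 sparsity compression, since only half of the values are stored. A worked example with assumed sizes (tile_size = 16 and a 4096 x 4096 partition; both are illustrative, not read from the config):

# Illustrative shape arithmetic for the packed 2:4-sparse weight.
input_size_per_partition = 4096   # assumed
output_size_per_partition = 4096  # assumed
tile_size = 16                    # assumed Marlin tile size
pack_factor = 8                   # eight 4-bit values per int32 word

rows = input_size_per_partition // tile_size // 2            # 128
cols = output_size_per_partition * tile_size // pack_factor  # 8192
print((rows, cols))  # (128, 8192) int32 words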
         # Meta
-        meta = Parameter(
-            torch.empty(
-                input_size_per_partition // 8 // 2 // 2,
-                output_size_per_partition * 2,
-                device="cuda",
-                dtype=torch.int16,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            meta,
-            {
-                "input_dim": 0,
-                "packed_dim": 1,
-                "pack_factor": 1,
-                "output_dim": 1,
-                "marlin_tile_size": 2,
-            },
-        )
+        meta = PackedvLLMParameter(data=torch.empty(
+            input_size_per_partition // 8 // 2 // 2,
+            output_size_per_partition * 2,
+            device="cuda",
+            dtype=torch.int16,
+        ),
+                                   input_dim=0,
+                                   output_dim=1,
+                                   packed_dim=1,
+                                   packed_factor=1,
+                                   marlin_tile_size=2,
+                                   weight_loader=weight_loader)
 
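The meta tensor records the 2:4 sparsity structure, i.e. which two of every four weights were kept, as int16 entries with a Marlin tile size of 2. Continuing the same assumed 4096 x 4096 partition:

# Illustrative shape arithmetic for the sparsity metadata.
meta_rows = 4096 // 8 // 2 // 2  # 128
meta_cols = 4096 * 2             # 8192 int16 entries
print((meta_rows, meta_cols))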
         # Determine if channelwise or not
         input_groups = (1 if self.quant_config.group_size == -1 else
                         input_size_per_partition //
                         self.quant_config.group_size)
 
-        scales = Parameter(
+        weight_scale_args = {
+            "data":
             torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
 
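The branch picks between the two scale layouts: with group_size == -1 there is a single scale row, one scale per output channel, so only output_dim is sharded (ChannelQuantScaleParameter); otherwise the scales also shard along the input dimension, one row per group (GroupQuantScaleParameter). A quick check of the resulting shapes, using assumed K = N = 4096:

# Mirrors the input_groups computation above; 4096 and 128 are
# illustrative values, not config defaults.
for group_size in (-1, 128):
    input_groups = 1 if group_size == -1 else 4096 // group_size
    print(group_size, (input_groups, 4096))
# -1  -> (1, 4096)   channelwise: one scale per output channel
# 128 -> (32, 4096)  grouped: one scale row per 128 input rows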
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)
 
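The workspace is a zero-initialized int buffer the Marlin kernel uses for its internal locking; wrapping it in BasevLLMParameter keeps the parameter handling uniform even though no checkpoint tensor is ever loaded into it. Sizing with assumed values (not the config's actual defaults):

output_size_per_partition = 4096  # assumed
min_n_threads = 64                # assumed
max_parallel = 16                 # assumed
max_workspace_size = (output_size_per_partition //
                      min_n_threads) * max_parallel
print(max_workspace_size)  # 1024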
         layer.register_parameter("B_24", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("B_meta", meta)
-        set_weight_attrs(meta, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
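The re-wrapping step exists because the custom parameter subclasses (and the loader state they carry) are only needed while weights are being copied in; torch.compile expects plain nn.Parameter leaves, so each tensor is re-registered as a vanilla frozen Parameter once loading is done. The same pattern, generically (a sketch, not a vLLM API):

import torch


def strip_custom_parameter(module: torch.nn.Module, name: str) -> None:
    # Re-register the underlying tensor as a plain, frozen nn.Parameter
    # so torch.compile traces over a vanilla parameter type.
    old = getattr(module, name)
    setattr(module, name,
            torch.nn.Parameter(old.data, requires_grad=False))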
     def apply(
         self,