Commit 562d51d

dsikka authored and mgoin committed
[Misc] Update gptq_marlin_24 to use vLLMParameters (vllm-project#7762)
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: LeiWang1999 <[email protected]>
1 parent 5978ec6 commit 562d51d

File tree

2 files changed (+50, -54 lines):
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py

vllm/model_executor/layers/linear.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod", "QQQLinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
 ]
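Note: WEIGHT_LOADER_V2_SUPPORTED is the allow-list that opts a quantization method into the newer vLLMParameter-based loading path, so adding GPTQMarlin24LinearMethod here is what activates the changes below. A minimal sketch of the kind of dispatch this list enables; `pick_weight_loader` and its loader arguments are illustrative, not vLLM's actual API:

```python
# Illustrative sketch only: quant methods named in the allow-list get the
# v2 loader, which understands vLLMParameter metadata; everything else keeps
# the legacy loader that relies on set_weight_attrs.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
]

def pick_weight_loader(quant_method, loader_v1, loader_v2):
    """Return loader_v2 only for quant methods known to use vLLMParameters."""
    name = type(quant_method).__name__
    return loader_v2 if name in WEIGHT_LOADER_V2_SUPPORTED else loader_v1
```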
vllm/model_executor/layers/quantization/gptq_marlin_24.py

Lines changed: 49 additions & 53 deletions

@@ -8,7 +8,10 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
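Note: the import swap is the heart of the change. Instead of attaching loading metadata to plain torch.nn.Parameter objects via set_weight_attrs, create_weights now builds parameter subclasses that carry that metadata themselves. A simplified sketch of the shape of these classes; this is not the real vllm/model_executor/parameter.py, just the idea:

```python
import torch

class BasevLLMParameter(torch.nn.Parameter):
    """Parameter that knows how to load itself from a checkpoint shard."""

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader=None):
        self.weight_loader = weight_loader  # called once per shard

class PackedvLLMParameter(BasevLLMParameter):
    """Adds packing metadata: packed dim, pack factor, Marlin tile size."""

    def __init__(self, data, input_dim, output_dim, packed_dim, packed_factor,
                 marlin_tile_size=None, weight_loader=None):
        super().__init__(data, weight_loader)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.packed_dim = packed_dim
        self.packed_factor = packed_factor
        self.marlin_tile_size = marlin_tile_size
```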
@@ -149,7 +152,7 @@ def create_weights(
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
-
+        weight_loader = extra_weight_attrs["weight_loader"]
         if params_dtype != torch.float16:
             raise ValueError(
                 f"The params dtype must be float16, but got {params_dtype}")
@@ -187,87 +190,80 @@ def create_weights(
                 "Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size // 2,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)
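Note: the constructor keyword is packed_factor, whereas the old set_weight_attrs dict used the key "pack_factor"; the tensor itself is unchanged. A worked example of the shape arithmetic, with hypothetical sizes; tile_size=16 and pack_factor=8 (32-bit words over 4-bit values) are illustrative, and the extra "// 2" reflects the 2:4 sparse format storing half the weights:

```python
input_size_per_partition = 4096
output_size_per_partition = 4096
tile_size, pack_factor = 16, 8  # Marlin tile; 32-bit words // 4-bit values

rows = input_size_per_partition // tile_size // 2            # 128
cols = output_size_per_partition * tile_size // pack_factor  # 8192
print((rows, cols))  # qweight: int32 tensor of shape (128, 8192)
```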
         # Meta
-        meta = Parameter(
-            torch.empty(
-                input_size_per_partition // 8 // 2 // 2,
-                output_size_per_partition * 2,
-                device="cuda",
-                dtype=torch.int16,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            meta,
-            {
-                "input_dim": 0,
-                "packed_dim": 1,
-                "pack_factor": 1,
-                "output_dim": 1,
-                "marlin_tile_size": 2,
-            },
-        )
+        meta = PackedvLLMParameter(data=torch.empty(
+            input_size_per_partition // 8 // 2 // 2,
+            output_size_per_partition * 2,
+            device="cuda",
+            dtype=torch.int16,
+        ),
+                                   input_dim=0,
+                                   output_dim=1,
+                                   packed_dim=1,
+                                   packed_factor=1,
+                                   marlin_tile_size=2,
+                                   weight_loader=weight_loader)
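Note: B_meta holds the 2:4 sparsity index metadata rather than weight values, which is why it keeps packed_factor=1 and a tile size of 2. A hedged reading of the sizing, with the same illustrative dimensions as above; the exact bit layout is internal to the sparse Marlin kernel:

```python
input_size_per_partition = 4096
output_size_per_partition = 4096
meta_rows = input_size_per_partition // 8 // 2 // 2  # 128
meta_cols = output_size_per_partition * 2            # 8192
print((meta_rows, meta_cols))  # B_meta: int16 tensor of shape (128, 8192)
```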
         # Determine if channelwise or not
         input_groups = (1 if self.quant_config.group_size == -1 else
                         input_size_per_partition //
                         self.quant_config.group_size)
 
-        scales = Parameter(
+        weight_scale_args = {
+            "data":
             torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
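Note: this replaces the old "input_dim is None when channelwise" convention with two explicit classes: ChannelQuantScaleParameter for one scale per output channel (group_size == -1), and GroupQuantScaleParameter, whose input_dim=0 tells the loader that scale groups are sharded along the input dimension under tensor parallelism. The shape logic, restated standalone with illustrative sizes:

```python
def scales_shape(input_size_per_partition: int, group_size: int,
                 output_size_per_partition: int) -> tuple:
    input_groups = (1 if group_size == -1
                    else input_size_per_partition // group_size)
    return (input_groups, output_size_per_partition)

print(scales_shape(4096, -1, 4096))   # (1, 4096): channelwise scales
print(scales_shape(4096, 128, 4096))  # (32, 4096): one row per 128-wide group
```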
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)
 
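Note: the workspace is a zero-initialized synchronization buffer for the Marlin kernel, not a checkpoint tensor, so the plain BasevLLMParameter suffices; it takes weight_loader only to keep the parameter API uniform. Sizing with stand-in config values; min_n_threads and max_parallel below are illustrative, not the real GPTQMarlin24Config defaults:

```python
output_size_per_partition = 4096
min_n_threads, max_parallel = 64, 16  # illustrative stand-ins
max_workspace_size = (output_size_per_partition // min_n_threads) * max_parallel
print(max_workspace_size)  # 1024 zero-initialized int32 slots
```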
layer.register_parameter("B_24", qweight)
264-
set_weight_attrs(qweight, extra_weight_attrs)
265257
layer.register_parameter("B_meta", meta)
266-
set_weight_attrs(meta, extra_weight_attrs)
267258
layer.register_parameter("s", scales)
268-
set_weight_attrs(scales, extra_weight_attrs)
269259
layer.register_parameter("workspace", workspace)
270-
set_weight_attrs(workspace, extra_weight_attrs)
260+
261+
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
262+
# required by torch.compile
263+
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
264+
layer.s = Parameter(layer.s.data, requires_grad=False)
265+
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
266+
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
271267

272268
def apply(
273269
self,

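Note: the new process_weights_after_loading hook strips the vLLMParameter subclasses once loading is done, re-wrapping the same storage as plain torch.nn.Parameter; per the in-diff comment, torch.compile traces more reliably over standard parameter types. A standalone sketch of the same idea, where `finalize` and its attribute list are illustrative, not vLLM's API:

```python
import torch
from torch.nn import Parameter

def finalize(layer: torch.nn.Module,
             names=("B_24", "B_meta", "s", "workspace")) -> None:
    # Re-wrap each tensor in place; .data is shared, so no copy happens.
    for name in names:
        setattr(layer, name,
                Parameter(getattr(layer, name).data, requires_grad=False))
```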