Commit 50428b7

dsikka authored and Alvant committed
[Misc] Update qqq to use vLLMParameters (vllm-project#7805)
Signed-off-by: Alvant <[email protected]>
1 parent ed30706 commit 50428b7

File tree: 3 files changed, +55 -65 lines

tests/weight_loading/models.txt

Lines changed: 3 additions & 1 deletion

@@ -17,4 +17,6 @@ awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
 fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
 marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
-marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
+marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
+qqq, HandH1998/QQQ-Llama-3-8b-g128, main
+qqq, HandH1998/QQQ-Llama-3-8b, main
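
Note: each row in models.txt is a "quant_method, model, revision" triple consumed by the weight-loading tests; the two new qqq rows add a group-quantized and a channel-only QQQ checkpoint. A minimal parser under that assumption (parse_models_txt is a hypothetical helper; the actual harness lives under tests/weight_loading/):

def parse_models_txt(text: str):
    # Split each non-empty, non-comment line into its three comma-separated
    # fields: quantization method, model repo, and revision.
    entries = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        quant_method, model, revision = [p.strip() for p in line.split(",")]
        entries.append((quant_method, model, revision))
    return entries

print(parse_models_txt("qqq, HandH1998/QQQ-Llama-3-8b, main"))
# [('qqq', 'HandH1998/QQQ-Llama-3-8b', 'main')]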

vllm/model_executor/layers/linear.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod"
 ]
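
Note: WEIGHT_LOADER_V2_SUPPORTED is an allow-list keyed on quant-method class names; adding "QQQLinearMethod" opts QQQ layers into the vLLMParameter-based (v2) loading path. A sketch of how such a gate can be consulted (use_v2_loader is illustrative, not vLLM's actual dispatch code):

WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod", "QQQLinearMethod"
]

def use_v2_loader(quant_method) -> bool:
    # Gate on the class name, matching how entries are stored in the list.
    return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED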

vllm/model_executor/layers/quantization/qqq.py

Lines changed: 51 additions & 63 deletions

@@ -8,7 +8,10 @@
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 
 logger = init_logger(__name__)
 
@@ -133,6 +136,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        weight_loader = extra_weight_attrs["weight_loader"]
         if params_dtype != torch.float16:
             raise ValueError(
                 f"The params dtype must be float16, but got {params_dtype}")
@@ -170,90 +174,74 @@ def create_weights(
                 "Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
-
-        s_channel = Parameter(
-            torch.empty(
-                1,
-                output_size_per_partition,
-                device="cuda",
-                dtype=torch.float,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            s_channel,
-            {
-                "input_dim": None,
-                "output_dim": 1,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)
+
+        s_channel = ChannelQuantScaleParameter(data=torch.empty(
+            1,
+            output_size_per_partition,
+            device="cuda",
+            dtype=torch.float,
+        ),
+                                               weight_loader=weight_loader,
+                                               output_dim=1)
 
         if self.quant_config.group_size == -1:
-            s_group = Parameter(
-                torch.tensor(
-                    [],
-                    device="cuda",
-                    dtype=torch.half,
-                ),
-                requires_grad=False,
+            s_group_data = torch.tensor(
+                [],
+                device="cuda",
+                dtype=torch.half,
             )
         else:
-            s_group = Parameter(
-                torch.empty(
-                    input_size_per_partition // self.quant_config.group_size,
-                    output_size_per_partition,
-                    device="cuda",
-                    dtype=torch.half,
-                ),
-                requires_grad=False,
+            s_group_data = torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                device="cuda",
+                dtype=torch.half,
             )
 
-        set_weight_attrs(
-            s_group,
-            {
-                "input_dim": None if self.quant_config.group_size == -1 else 0,
-                "output_dim":
-                None if self.quant_config.group_size == -1 else 1,
-            },
-        )
+        s_group_attr = {"data": s_group_data, "weight_loader": weight_loader}
+
+        if self.quant_config.group_size == -1:
+            s_group = BasevLLMParameter(**s_group_attr)
+        else:
+            s_group = GroupQuantScaleParameter(output_dim=1,
+                                               input_dim=0,
+                                               **s_group_attr)
 
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)
 
         layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("s_channel", s_channel)
-        set_weight_attrs(s_channel, extra_weight_attrs)
         layer.register_parameter("s_group", s_group)
-        set_weight_attrs(s_group, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False)
+        layer.s_group = Parameter(layer.s_group.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
     def apply(
         self,