
Commit a10d305

[Core] Set linear_weights directly on the layer (#3977)
1 parent 8afca50

8 files changed, +114 −102 lines

csrc/quantization/gptq/q_gemm.cu

Lines changed: 1 addition & 1 deletion

@@ -2067,7 +2067,7 @@ void gptq_shuffle
   const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
   vllm::gptq::shuffle_exllama_weight(
     (uint32_t*) q_weight.data_ptr(),
-    q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
+    q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
     q_weight.size(0) * 32 / bit,
     q_weight.size(1),
     bit
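
The extra `numel() == 0` check pairs with the gptq.py change further down: the no-reorder sentinel for the permutation tensor moves from a meta tensor to an empty tensor on a real device, presumably because a registered Parameter must keep real storage, so the kernel has to map both sentinels to a NULL pointer. A minimal standalone sketch of the two sentinel forms (hypothetical snippet, not from the diff):

    import torch

    # Old sentinel: a meta tensor, caught by is_meta.
    old_sentinel = torch.empty((1, 1), device="meta")
    assert old_sentinel.is_meta

    # New sentinel: an empty tensor on a real device. is_meta is False,
    # so the kernel additionally checks numel() == 0 before passing NULL.
    new_sentinel = torch.empty((0, ))
    assert not new_sentinel.is_meta and new_sentinel.numel() == 0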

tests/kernels/test_moe.py

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
     ).cuda()
 
     # Load the weights
-    vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
+    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
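
Since `create_weights` now registers `weight` directly on the gate layer, the test addresses it as an ordinary `nn.Parameter` attribute. A self-contained stand-in showing the same loading pattern (hypothetical shapes, not from the diff):

    import torch

    # After the refactor, loading a weight is a plain parameter copy.
    gate_weight = torch.nn.Parameter(torch.empty(8, 16), requires_grad=False)
    hf_weight = torch.randn(8, 16)
    gate_weight.data.copy_(hf_weight)  # same effect as `.data[:] = ...`
    assert torch.equal(gate_weight.data, hf_weight)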

vllm/lora/layers.py

Lines changed: 4 additions & 8 deletions

@@ -368,7 +368,7 @@ def set_mapping(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora(
             x,
             self.lora_a_stacked,
@@ -402,10 +402,6 @@ def forward(self, input_):
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
 
-    @property
-    def linear_weights(self):
-        return self.base_layer.linear_weights
-
     @classmethod
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
@@ -505,7 +501,7 @@ def set_lora(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -746,7 +742,7 @@ def set_lora(
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -838,7 +834,7 @@ def set_mapping(
 
     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x)
+            self.base_layer, x)
         _apply_lora(
             x,
             self.lora_a_stacked,
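
With parameters living on the wrapped module, the LoRA wrappers no longer need to re-export a `linear_weights` dict; the linear method finds everything on `base_layer` by attribute, so the delegating property is deleted. A runnable sketch of the new call shape (hedged; assumes `ReplicatedLinear` constructs standalone at this revision, since it uses no tensor parallelism):

    import torch
    from vllm.model_executor.layers.linear import (ReplicatedLinear,
                                                   UnquantizedLinearMethod)

    base = ReplicatedLinear(16, 32, bias=False,
                            linear_method=UnquantizedLinearMethod())
    x = torch.randn(2, 16)
    # A wrapper now hands over the module itself; `weight` is found on it.
    out = base.linear_method.apply_weights(base, x, bias=None)
    assert out.shape == (2, 32)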

vllm/model_executor/layers/linear.py

Lines changed: 40 additions & 37 deletions

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
-        """Create weights for a linear layer."""
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for a linear layer.
+
+        The weights will be set as attributes of the layer."""
         raise NotImplementedError
 
     @abstractmethod
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply the weights to the input tensor."""
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
 
 
@@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        return {"weight": weight}
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        weight = weights["weight"]
+        weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
                 return F.linear(x, weight) + bias
@@ -111,12 +118,9 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
+        self.linear_method.create_weights(self, self.input_size,
+                                          self.output_size, self.input_size,
+                                          self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
@@ -126,7 +130,7 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
+        output = self.linear_method.apply_weights(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
@@ -177,13 +181,13 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size,
+                                          self.output_size_per_partition,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -211,8 +215,7 @@ def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_, bias)
+        output_parallel = self.linear_method.apply_weights(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -523,13 +526,13 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size_per_partition,
+                                          self.output_size,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
@@ -569,7 +572,7 @@ def forward(self, input_):
 
         # Matrix multiply.
         output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_parallel)
+            self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
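
The new contract, in short: `create_weights` registers its parameters on the module it is handed, forwarding `**extra_weight_attrs` (e.g. `weight_loader`) onto each one, and `apply_weights` reads them back off that module. A minimal sketch of a custom method against this interface (hypothetical class; assumes `set_weight_attrs` lives in `vllm.model_executor.utils` at this revision):

    import torch
    import torch.nn.functional as F
    from torch.nn import Parameter

    from vllm.model_executor.layers.linear import LinearMethodBase
    from vllm.model_executor.utils import set_weight_attrs


    class NoopQuantLinearMethod(LinearMethodBase):
        """Stores a plain fp weight but follows the layer-based contract."""

        def create_weights(self, layer: torch.nn.Module,
                           input_size_per_partition: int,
                           output_size_per_partition: int, input_size: int,
                           output_size: int, params_dtype: torch.dtype,
                           **extra_weight_attrs):
            weight = Parameter(torch.empty(output_size_per_partition,
                                           input_size_per_partition,
                                           dtype=params_dtype),
                               requires_grad=False)
            # Sharding metadata consumed by the weight loaders.
            set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
            # Register on the layer so it appears in state_dict and
            # named_parameters under its attribute name.
            layer.register_parameter("weight", weight)
            # Forward per-layer attrs (e.g. weight_loader) to the parameter.
            set_weight_attrs(weight, extra_weight_attrs)

        def apply_weights(self,
                          layer: torch.nn.Module,
                          x: torch.Tensor,
                          bias: torch.Tensor = None) -> torch.Tensor:
            # Parameters are plain attributes of the layer now.
            return F.linear(x, layer.weight, bias)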

vllm/model_executor/layers/quantization/awq.py

Lines changed: 16 additions & 13 deletions

@@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: AWQConfig):
         self.quant_config = quant_config
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -136,19 +137,21 @@ def create_weights(
             "input_dim": 0,
             "output_dim": 1,
         })
-        return {
-            "qweight": qweight,
-            "qzeros": qzeros,
-            "scales": scales,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
-        scales = weights["scales"]
-        qzeros = weights["qzeros"]
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
@@ -163,5 +166,5 @@ def apply_weights(
         out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                            pack_factor)
         if bias is not None:
-            out = out + bias
+            out.add_(bias)
         return out.reshape(out_shape)
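
A side change here: `out = out + bias` allocated a fresh tensor, while `out.add_(bias)` mutates the GEMM output in place, which is safe because `out` is newly produced inside this function and not aliased elsewhere. A tiny standalone illustration (hypothetical shapes):

    import torch

    out = torch.randn(4, 8)      # stands in for a fresh GEMM result
    bias = torch.randn(8)
    storage_before = out.data_ptr()
    out.add_(bias)               # in-place: same storage, one less allocation
    assert out.data_ptr() == storage_before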

vllm/model_executor/layers/quantization/gptq.py

Lines changed: 26 additions & 21 deletions

@@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig):
 
     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
@@ -179,37 +181,40 @@ def create_weights(
             "input_dim": scale_and_zero_input_dim,
             "output_dim": 1,
         })
-        return {
-            "qweight": qweight,
-            "g_idx": g_idx,
-            "qzeros": qzeros,
-            "scales": scales,
-            "exllama_state": exllama_state,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("g_idx", g_idx)
+        set_weight_attrs(g_idx, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+
+        layer.exllama_state = exllama_state
 
     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
+        qweight = layer.qweight
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
-        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
+        if layer.exllama_state == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
-                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
-                    torch.int)
+                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
-                weights["g_idx"] = torch.empty((1, 1), device="meta")
-            weights["exllama_state"] = ExllamaState.READY
-            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                layer.g_idx.data = torch.empty((0, ),
+                                               device=layer.g_idx.device)
+            layer.exllama_state = ExllamaState.READY
+            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                              self.quant_config.weight_bits)
-        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
-                               weights["qzeros"], weights["scales"],
-                               weights["g_idx"],
-                               weights["exllama_state"] == ExllamaState.READY,
+        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
+                               layer.scales, layer.g_idx,
+                               layer.exllama_state == ExllamaState.READY,
                                self.quant_config.weight_bits)
         if bias is not None:
-            output = output + bias
+            output.add_(bias)
         return output.reshape(out_shape)
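
Two details worth noting: `exllama_state` is plain Python state, so it is set as an ordinary attribute rather than registered, and since `g_idx` is now a `Parameter`, the lazy first-forward setup rebinds `layer.g_idx.data` in place; the no-reorder case switches from a meta tensor to an empty tensor on the parameter's device, matching the `numel() == 0` kernel guard in q_gemm.cu above. A hedged, self-contained sketch of that state transition (hypothetical stand-in object, string states instead of `ExllamaState`):

    import types
    import torch

    # Stand-in for a GPTQ layer right after weight loading (hypothetical).
    layer = types.SimpleNamespace(
        g_idx=torch.nn.Parameter(torch.empty(0, dtype=torch.int),
                                 requires_grad=False),
        exllama_state="UNINITIALIZED",
    )

    desc_act = False
    if layer.exllama_state == "UNINITIALIZED":
        if desc_act:
            # Activation reordering: replace g_idx by its inverse permutation.
            layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
        else:
            # No reordering: an empty tensor on the same device means
            # "no permutation" (the kernel passes NULL when numel() == 0).
            layer.g_idx.data = torch.empty((0, ), device=layer.g_idx.device)
        layer.exllama_state = "READY"
    assert layer.exllama_state == "READY" and layer.g_idx.numel() == 0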

vllm/model_executor/layers/quantization/marlin.py

Lines changed: 13 additions & 10 deletions

@@ -91,12 +91,14 @@ def __init__(self, quant_config: MarlinConfig):
 
     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.
 
         if params_dtype != torch.float16:
@@ -187,21 +189,22 @@ def create_weights(
                                           dtype=torch.int),
                               requires_grad=False)
 
-        return {
-            "B": qweight,
-            "s": scales,
-            "workspace": workspace,
-        }
+        layer.register_parameter("B", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("s", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+        layer.register_parameter("workspace", workspace)
+        set_weight_attrs(workspace, extra_weight_attrs)
 
     def apply_weights(
         self,
-        weights: Dict[str, Any],
+        layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        qweight = weights["B"]
-        scales = weights["s"]
-        workspace = weights["workspace"]
+        qweight = layer.B
+        scales = layer.s
+        workspace = layer.workspace
 
         x_2d = x.view(-1, x.shape[-1])
 