Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod"
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
]


Expand Down
102 changes: 49 additions & 53 deletions vllm/model_executor/layers/quantization/gptq_marlin_24.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
Expand Down Expand Up @@ -149,7 +152,7 @@ def create_weights(
**extra_weight_attrs,
):
del output_size # Unused.

weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
Expand Down Expand Up @@ -187,87 +190,80 @@ def create_weights(
"Each permutation group must reside on the same gpu")

# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size // 2,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader)

# Meta
meta = Parameter(
torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
device="cuda",
dtype=torch.int16,
),
requires_grad=False,
)
set_weight_attrs(
meta,
{
"input_dim": 0,
"packed_dim": 1,
"pack_factor": 1,
"output_dim": 1,
"marlin_tile_size": 2,
},
)
meta = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
device="cuda",
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)

# Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)

scales = Parameter(
weight_scale_args = {
"data":
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": None if input_groups == 1 else 0,
"output_dim": 1,
},
)
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)

# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)

workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
weight_loader=weight_loader)

layer.register_parameter("B_24", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("B_meta", meta)
set_weight_attrs(meta, extra_weight_attrs)
layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
layer.s = Parameter(layer.s.data, requires_grad=False)
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

def apply(
self,
Expand Down