14 changes: 14 additions & 0 deletions tests/quantization/test_compressed_tensors.py
@@ -159,5 +159,19 @@ def test_compressed_tensors_fp8(vllm_runner):
def test_compressed_tensors_kv_cache(vllm_runner):
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        output = llm.generate_greedy("Hello world!", max_tokens=20)
        assert output
+
+
+def test_compressed_tensors_actorder_weight(vllm_runner):
+    model_path = "kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-weight-e2e"
+    with vllm_runner(model_path) as llm:
+        output = llm.generate_greedy("Hello world!", max_tokens=20)
+        assert output
+
+
+def test_compressed_tensors_actorder_group(vllm_runner):
+    model_path = "kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-group-e2e"
+    with vllm_runner(model_path) as llm:
+        output = llm.generate_greedy("Hello world!", max_tokens=20)
+        assert output
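
For readers who want to try these checkpoints outside the test harness, here is a minimal sketch using vLLM's public LLM API; greedy decoding is approximated with temperature=0, mirroring generate_greedy above:

    from vllm import LLM, SamplingParams

    # Load one of the activation-ordered checkpoints exercised by the new tests.
    llm = LLM(model="kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-group-e2e")

    # Greedy decoding: temperature=0, 20 new tokens, as in the test.
    params = SamplingParams(temperature=0.0, max_tokens=20)
    outputs = llm.generate(["Hello world!"], params)
    print(outputs[0].outputs[0].text)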
2 changes: 2 additions & 0 deletions tests/weight_loading/models.txt
@@ -15,6 +15,8 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
+compressed-tensors, kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-weight-e2e, main
+compressed-tensors, kylesayrs/TinyLlama-1.1B-Chat-v1.0-actorder-group-e2e, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
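
(Each models.txt entry appears to follow the pattern quant_method, model stub, revision; the two added lines register the new actorder checkpoints with the weight-loading test matrix.)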
@@ -7,8 +7,8 @@
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, replace_tensor, verify_marlin_supported,
-    verify_marlin_supports_shape)
+    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
@@ -119,9 +119,15 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                                                          dtype=torch.int64),
                                         weight_loader=weight_loader)

+        # group index (for activation reordering)
+        weight_g_idx = BasevLLMParameter(data=torch.full(
+            (input_size_per_partition, ), -1, dtype=torch.int32),
+                                         weight_loader=weight_loader)
+
        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)
+        layer.register_parameter("weight_g_idx", weight_g_idx)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
@@ -137,9 +143,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

-        # Act-order not supported in compressed-tensors yet, so set to empty.
-        layer.g_idx = marlin_make_empty_g_idx(device)
-        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+        # Handle sorting for activation reordering if needed.
+        has_g_idx = -1 not in layer.weight_g_idx
+        if has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+            replace_tensor(layer, "weight_g_idx", g_idx)
+        else:
+            layer.weight_g_idx = marlin_make_empty_g_idx(device)
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # No zero-point
        layer.weight_zp = marlin_make_empty_g_idx(device)
@@ -161,7 +173,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Permute scales from compressed-tensors format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.weight_scale,
-            size_k=layer.input_size_per_partition,
+            size_k=(layer.input_size
+                    if has_g_idx else layer.input_size_per_partition),
            size_n=layer.output_size_per_partition,
            group_size=layer.group_size)
        replace_tensor(layer, "weight_scale", marlin_scales)
@@ -174,7 +187,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
            weight=layer.weight_packed,
            weight_scale=layer.weight_scale,
            weight_zp=layer.weight_zp,
-            g_idx=layer.g_idx,
+            g_idx=layer.weight_g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            wtype=self.quant_type,
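
A note on the changes above. marlin_sort_g_idx is expected to return the group indices sorted ascending plus the argsort permutation the kernel uses to shuffle columns; a rough sketch of that contract (illustrative, not the actual vLLM implementation):

    import torch

    def sort_g_idx_sketch(g_idx: torch.Tensor):
        # Argsort yields the column permutation that clusters columns by
        # quantization group; the sorted g_idx is what the kernel consumes.
        sort_indices = torch.argsort(g_idx, stable=True)
        return g_idx[sort_indices], sort_indices.to(torch.int)

    g_idx = torch.tensor([2, 0, 1, 0], dtype=torch.int32)
    sorted_g_idx, perm = sort_g_idx_sketch(g_idx)
    # sorted_g_idx -> tensor([0, 0, 1, 2]), perm -> tensor([1, 3, 2, 0])

Because activation reordering permutes columns across the full input dimension, the scale permutation passes layer.input_size rather than layer.input_size_per_partition when has_g_idx is set.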
@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
TOKEN = "token"


class ActivationOrdering(str, Enum):
"""
Enum storing strategies for activation ordering

Group: reorder groups and weight\n
Weight: only reorder weight, not groups. Slightly lower latency and
accuracy compared to group actorder\n
"""

GROUP = "group"
WEIGHT = "weight"


class QuantizationArgs(BaseModel):
"""
User facing arguments used to define a quantization config
@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
        observed with every sample. Defaults to False for static
        quantization. Note that enabling dynamic quantization
        will change the default observer to a memoryless one
+    :param actorder: whether to apply group quantization in decreasing order of
+        activation. Defaults to None for arbitrary ordering
    """

num_bits: int = 8
@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
    strategy: Optional[QuantizationStrategy] = None
    block_structure: Optional[str] = None
    dynamic: bool = False
+    actorder: Optional[ActivationOrdering] = None
    observer: str = Field(
        default="minmax",
        description=("The class to use to compute the quantization param - "
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -129,16 +129,16 @@ def marlin_make_workspace(output_size_per_partition: int,
                       requires_grad=False)


-def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
-    return (not act_order) or (act_order and not is_row_parallel)
+def marlin_is_k_full(has_g_idx: bool, is_row_parallel: bool) -> bool:
+    return (not has_g_idx) or (not is_row_parallel)


-def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
+def marlin_repeat_scales_on_all_ranks(has_g_idx: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
-    # Need to repeat scales on every rank if act_ordering or
+    # Need to repeat scales on every rank if actorder or
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
-    return act_order or (is_channelwise and is_row_parallel)
+    return has_g_idx or (is_channelwise and is_row_parallel)


def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
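
Worth noting: the rename also simplifies marlin_is_k_full, since (not a) or (a and not b) reduces to (not a) or (not b). A quick pure-Python restatement of both predicates with spot checks (same logic as the diff above):

    def marlin_is_k_full(has_g_idx: bool, is_row_parallel: bool) -> bool:
        return (not has_g_idx) or (not is_row_parallel)

    def marlin_repeat_scales_on_all_ranks(has_g_idx: bool, group_size: int,
                                          is_row_parallel: bool) -> bool:
        is_channelwise = group_size == -1
        return has_g_idx or (is_channelwise and is_row_parallel)

    # K is only partial when act-order meets row-parallel sharding.
    assert marlin_is_k_full(False, True) and marlin_is_k_full(True, False)
    assert not marlin_is_k_full(True, True)

    # Scales repeat on every rank for act-order, or channelwise + row-parallel.
    assert marlin_repeat_scales_on_all_ranks(True, 128, False)
    assert marlin_repeat_scales_on_all_ranks(False, -1, True)
    assert not marlin_repeat_scales_on_all_ranks(False, 128, True)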