
Commit 4f69d04

hotfix: increase precision of GPTQ/AWQ-Marlin
Sync with an upstream change that raises the precision of the 'global_reduce' algorithm from FP16 to FP32, which fixes some reported generation quality issues. Upstream issue/PR: vllm-project/vllm#6795
1 parent 4b49c50 commit 4f69d04
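
For context, a minimal PyTorch sketch (an illustration, not the kernel itself) of why the reduction precision matters: once an FP16 running sum grows large, each small partial falls below half a unit in the last place and rounds away entirely, whereas an FP32 accumulator keeps those low-order bits and casts back to FP16 only once at the end. The values and loop below are hypothetical stand-ins for the per-block partial results the GPU kernel combines.

import torch

# Hypothetical stand-in for the partial results a tiled GEMM reduction
# combines; the real kernel reduces per-block outputs on the GPU.
partials = [torch.tensor(0.01, dtype=torch.float16)] * 16384

# FP16 accumulation: once the running sum is large, each 0.01-sized addend
# is smaller than half an ulp and rounds away, so the sum stalls early.
acc16 = torch.tensor(0.0, dtype=torch.float16)
for p in partials:
    acc16 = acc16 + p

# FP32 accumulation (the behavior the new use_fp32_reduce flag opts into):
# combine in FP32, then cast the final result back to FP16 a single time.
acc32 = torch.tensor(0.0, dtype=torch.float32)
for p in partials:
    acc32 = acc32 + p.to(torch.float32)

print(acc16.item())                    # stalls well below the true ~163.9
print(acc32.to(torch.float16).item())  # ~163.9, close to the exact sum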

4 files changed, +492 −387 lines changed


server/marlin/marlin_kernels/__init__.pyi

Lines changed: 8 additions & 0 deletions
@@ -1,5 +1,11 @@
 import torch
 
+def awq_marlin_repack(
+    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    """Repack AWQ parameters for GPTQ-Marlin."""
+    ...
+
 def gptq_marlin_gemm(
     a: torch.Tensor,
     b_q_weight: torch.Tensor,
@@ -12,6 +18,8 @@ def gptq_marlin_gemm(
     size_n: int,
     size_k: int,
     is_k_full: bool,
+    has_zp: bool,
+    use_fp32_reduce: bool,
 ) -> torch.Tensor:
     """
     Matrix multiplication using Marlin kernels. This is an extension of

server/marlin/marlin_kernels/ext.hh

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                torch::Tensor &g_idx, torch::Tensor &perm,
                                torch::Tensor &workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp);
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce);
 
 torch::Tensor gptq_marlin_24_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
                                   torch::Tensor &b_meta,
