SINQ quantization strategy integration (adapted for Transformers V5) #43112
base: main
@@ -0,0 +1,182 @@

[arXiv](https://arxiv.org/abs/2509.22944) · [License: Apache-2.0](https://opensource.org/licenses/Apache-2.0) · [GitHub stars](https://github.com/huawei-csl/SINQ/stargazers) · [Hugging Face](https://huggingface.co/huawei-csl)

# SINQ
[Sinkhorn-Normalized Quantization (SINQ)](https://github.com/huawei-csl/SINQ/tree/main) is a fast, plug-and-play, model-agnostic quantization technique delivering state-of-the-art performance for Large Language Models without sacrificing accuracy.

### 🔍 What You’ll Find Here

- [1. Quantize (and save) any LLM with SINQ](#1-quantize-any-llm-with-sinq)
- [2. How to Cite This Work](#2-how-to-cite-this-work)
- [3. Current Limitations](#3-current-limitations)
#### 📊 Feature Comparison: SINQ vs HQQ _(calibration-free)_ and A-SINQ vs AWQ _(calibrated)_

| Feature | **SINQ** | **HQQ** | **A-SINQ** | **AWQ** |
|------------|:--------:|:--------:|:----------:|:-------:|
| 🎯 Calibration | Calibration-free | Calibration-free | Calibrated | Calibrated |
| 🧮 Quantization Type | Symmetric & Asymmetric | Asymmetric only | Symmetric & Asymmetric | Symmetric & Asymmetric |
| 📦 NF4 Support | **Yes** | No | **Yes** | No |
| ⚡ Quantization Speed | ~2× **Faster** than HQQ | Slower | ~4× **Faster** than AWQ | Slower |
| 📈 Model Quality | **Higher** | Lower | **Higher** | Lower |
📄 **Want to know more?**
- Read our paper on [**arXiv**](http://arxiv.org/abs/2509.22944)
- Check the official [**SINQ**](https://github.com/huawei-csl/SINQ/tree/main) GitHub repository

---
## 1. Quantize any LLM with SINQ

### Setup & Quick Start

First, install the package. This can be done in two ways:
- From source, using the official GitHub repository [**SINQ**](https://github.com/huawei-csl/SINQ/tree/main) **[Recommended]**
- With the pip package:

```bash
pip install sinq
```
---

### Quantize in a few lines

Quantizing any 🤗 Hugging Face model with SINQ is simple and takes only a few lines of code.
First, create a [`SinqConfig`] and specify the following parameters:

| Parameter | Description | Type | Options | Default |
|------|-------------|---------|---------|----------|
| `nbits` | Bit-width for weight quantization | int | 2, 3, 4, 5, 6, 8 | 4 |
| `tiling_mode` | Weight matrix tiling strategy | str | 1D, 2D | 1D |
| `group_size` | Number of weights per quantization group | int | 64, 128 | 64 |
| `modules_to_not_convert` | List of layers that are NOT quantized | List of str | [lm_head, ...] | [lm_head] |
| `method` | Quantization method | str | sinq, asinq | sinq |
| `device` | Device on which the model is loaded | str | cpu, cuda:0, cuda:1, etc. | cuda:0 |

Then specify the model you want to quantize and pass the `SinqConfig` as the quantization configuration option:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SinqConfig

model_name = "Qwen/Qwen3-1.7B"
device = "cuda:0"

cfg = SinqConfig(
    nbits=4,
    group_size=64,
    tiling_mode="1D",
    method="sinq",
    modules_to_not_convert=["lm_head"],
    device=device,
)

tok = AutoTokenizer.from_pretrained(model_name)
qmodel = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=cfg,
    dtype=torch.bfloat16,
)
```
✅ That’s it. Your model is now quantized with **SINQ** and ready for inference or saving.

> Check our official [**SINQ**](https://github.com/huawei-csl/SINQ/tree/main) GitHub repository to stay updated!
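For a quick sanity check, you can immediately run a short generation with the quantized model. This is a minimal sketch (not part of the PR) that reuses `tok`, `qmodel`, `device`, and `torch` from the snippet above:

```python
# Minimal smoke test for the freshly quantized model.
prompt = "Briefly explain what weight quantization does."
inputs = tok(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    output_ids = qmodel.generate(**inputs, max_new_tokens=64)

print(tok.decode(output_ids[0], skip_special_tokens=True))
```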
---
### Save & reload

If you want to reuse a quantized model later, save it to disk or push it to the Hugging Face Hub, then reload it without needing the base FP weights.
If you installed SINQ from source, call the *patch_hf_pretrained_io* function first:

```python
from sinq.hf_io import patch_hf_pretrained_io
patch_hf_pretrained_io()

# Save a SINQ-quantized model
qmodel.save_pretrained("/path/to/save/qwen3-1.7B-sinq-4bit")
qmodel.push_to_hub("HF_Hub_username/qwen3-1.7B-sinq-4bit")
tok.push_to_hub("HF_Hub_username/qwen3-1.7B-sinq-4bit")

# Reload a SINQ-quantized model
hf_hub_model = "HF_Hub_username/qwen3-1.7B-sinq-4bit"
tok = AutoTokenizer.from_pretrained(hf_hub_model)
qmodel = AutoModelForCausalLM.from_pretrained(hf_hub_model)
```

Otherwise, if you installed SINQ through pip, you can simply use the HF built-in functions:
```python
# --- Save to a folder (sharded safetensors) ---

# 'qmodel' must already be SINQ-quantized
# Save locally
qmodel.save_pretrained("/path/to/save/qwen3-1.7B-sinq-4bit")
# Push to the Hub
qmodel.push_to_hub("HF_Hub_username/qwen3-1.7B-sinq-4bit")
tok.push_to_hub("HF_Hub_username/qwen3-1.7B-sinq-4bit")

# --- Reload later ---

save_dir = "/path/to/save/qwen3-1.7B-sinq-4bit"
hf_hub_model = "HF_Hub_username/qwen3-1.7B-sinq-4bit"

# From a local directory
tok = AutoTokenizer.from_pretrained(save_dir)
qmodel = AutoModelForCausalLM.from_pretrained(save_dir)

# From the HF Hub
tok = AutoTokenizer.from_pretrained(hf_hub_model)
qmodel = AutoModelForCausalLM.from_pretrained(hf_hub_model)
```

✅ Your model is now loaded and ready for inference!

> Note: if the model has been quantized to 4-bit and the `gemlite` library is installed, the faster GemLite kernel is used to run inference.
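As a rough illustration of that note (assuming the optional package is importable under the name `gemlite`), you can check whether the faster kernel path is available before loading:

```python
import importlib.util

# GemLite is optional: without it, 4-bit SINQ inference falls back to a slower kernel.
if importlib.util.find_spec("gemlite") is None:
    print("gemlite not found - 4-bit inference will use the slower fallback path")
```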
---
### Compatible with the [`lm-eval`](https://github.com/EleutherAI/lm-evaluation-harness) evaluation framework

Below is a minimal example showing how to evaluate a SINQ-quantized model on a benchmark dataset:
```python
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM

# Wrap the already-quantized model and tokenizer with HFLM
lm = HFLM(pretrained=qmodel, tokenizer=tok, device=device)

# Evaluate (many tasks are available in lm-eval, such as MMLU and HellaSwag)
results = evaluator.simple_evaluate(
    model=lm,
    tasks=["wikitext"],  # small and fast benchmark
    device=device,
)
```
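`simple_evaluate` returns a nested dictionary of scores. As a small usage sketch (exact metric names depend on your lm-eval version), the per-task metrics can be inspected like this:

```python
# Print the metrics computed for each requested task
# (for "wikitext" these are perplexity-style scores).
for task_name, metrics in results["results"].items():
    print(task_name, metrics)
```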
## 2. How to Cite This Work

If you find **SINQ** useful in your research or applications:
- Support our project by starring ⭐️ the [**SINQ**](https://github.com/huawei-csl/SINQ/tree/main) GitHub repository
- Please cite our <a href="http://arxiv.org/abs/2509.22944" target="_blank"><strong>paper</strong></a>:

```bibtex
@misc{muller2025sinq,
    title={SINQ: Sinkhorn-Normalized Quantization for Calibration-Free Low-Precision LLM Weights},
    author={Lorenz K. Muller and Philippe Bich and Jiawei Zhuang and Ahmet Celik and Luca Benfenati and Lukas Cavigelli},
    year={2025},
    eprint={2509.22944},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={http://arxiv.org/abs/2509.22944}
}
```
## 3. Current Limitations

Currently, the A-SINQ method is not supported in Hugging Face Transformers. Please refer to the official [SINQ repository](https://github.com/huawei-csl/SINQ/tree/main) to quantize a model with this strategy.
At the moment, the SINQ quantization strategy and SINQ-quantized models do not support multi-GPU execution, so if your system has multiple GPUs, please specify which one should be used.
@@ -0,0 +1,120 @@

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional, Dict, Any

from transformers.utils import is_torch_available, logging

from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name


logger = logging.get_logger(__name__)


if is_torch_available():
    import torch
    import torch.nn as nn
class SinqQuantize(ConversionOps):
    """
    Param-level ConversionOp for SINQ (quantizing from FP weights).

    At load time, for each `Linear.weight` that should be quantized:
    - the SINQLinear module already exists (created in `_process_model_before_weight_loading`),
    - we simply call `quantize()` on it with the loaded weight tensor.
    """

    def __init__(self, hf_quantizer: "SinqHfQuantizer"):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: Dict[str, Any],
        model: Optional["torch.nn.Module"] = None,
        full_layer_name: str | None = None,
        missing_keys=None,
        **kwargs,
    ) -> Dict[str, "torch.Tensor"]:
        # The incoming dict maps a checkpoint key to the loaded tensor (possibly wrapped in a list).
        _, values = next(iter(input_dict.items()))
        weight_tensor = values[0] if isinstance(values, list) else values

        module, _ = get_module_from_name(model, full_layer_name)

        # Quantize the FP weight in place inside the pre-created SINQLinear module.
        module.quantize(weight_tensor)

        if missing_keys is not None:
            missing_keys.discard(full_layer_name)

        module._is_hf_initialized = True

        # Nothing to return: the quantized parameters now live inside the module.
        return {}
class SinqDeserialize(ConversionOps):
    """
    ConversionOp for loading *pre-quantized* SINQ checkpoints.

    The checkpoint layout (what `SINQLinear.state_dict` produces) is, per module:
        <prefix>.W_q
        <prefix>.bias
        <prefix>.meta

    The WeightConverter in the quantizer is configured so that:
    - ".W_q", ".meta" and ".bias" are grouped into `input_dict`,
    - they are conceptually treated as belonging to "<prefix>.weight",
    - and this `SinqDeserialize.convert` is called to load the state into the existing SINQLinear.

    The returned dict is {} because we load directly into the module.
    """

    def __init__(self, hf_quantizer: "SinqHfQuantizer"):
        self.hf_quantizer = hf_quantizer

    def convert(
        self,
        input_dict: Dict[str, Any],
        model: Optional["torch.nn.Module"] = None,
        full_layer_name: str | None = None,
        **kwargs,
    ) -> Dict[str, "torch.Tensor"]:
        # Unwrap single-element lists so every value is a plain tensor.
        for k, v in list(input_dict.items()):
            if isinstance(v, list):
                input_dict[k] = v[0]

        W_q = input_dict.get(".W_q", None)
        meta = input_dict.get(".meta", None)
        bias = input_dict.get(".bias", None)

        # If the expected SINQ keys are missing, fall back to passing the tensor through untouched.
        if W_q is None or meta is None:
            v = next(iter(input_dict.values()))
            if isinstance(v, list):
                v = v[0]
            return {full_layer_name: v}

        module, _ = get_module_from_name(model, full_layer_name)

        # Rebuild the quantized module state directly from the checkpoint tensors.
        state = {
            "W_q": W_q,
            "meta": meta,
        }
        if bias is not None:
            state["bias"] = bias

        module.load_state_dict(state)
        module._is_hf_initialized = True

        return {}
Same for this one: the user can just pass `device_map` in `from_pretrained`.
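For illustration, this is the pattern the comment refers to: a minimal sketch (not part of this PR) that relies on the standard `device_map` argument of `from_pretrained` instead of a SINQ-specific `device` option, reusing the placeholder repo name from the docs above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo name taken from the documentation above.
hf_hub_model = "HF_Hub_username/qwen3-1.7B-sinq-4bit"

tok = AutoTokenizer.from_pretrained(hf_hub_model)
# device_map accepts "auto" or an explicit device such as "cuda:0"; since SINQ
# does not yet support multi-GPU, pinning a single device is the safer choice here.
qmodel = AutoModelForCausalLM.from_pretrained(hf_hub_model, device_map="cuda:0")
```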