**src/pruna/algorithms/quantization/backends/ganq/README.md** (129 additions)
#### Guide to using GANQ quantization

**Quantize a model**

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from pruna.algorithms.quantization.ganq import GANQQuantizer
from pruna.config.smash_config import SmashConfig
from pruna.data.pruna_datamodule import PrunaDataModule

# -------------------------------------------------------------------------
# 1. Load model and tokenizer
# -------------------------------------------------------------------------
model_name = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
model.eval()

# -------------------------------------------------------------------------
# 2. Build SmashConfig for Pruna Quantizer
# -------------------------------------------------------------------------
smash_config = SmashConfig(
    batch_size=4,
    device="cuda" if torch.cuda.is_available() else "cpu",
    cache_dir_prefix="./cache_ganq",
)

# Add tokenizer
smash_config.add_tokenizer(tokenizer)

# Use Pruna's built-in WikiText dataset (handles train/val/test splits automatically)
data_module = PrunaDataModule.from_string(
    "WikiText",
    tokenizer=tokenizer,
    collate_fn_args=dict(max_seq_len=256),
)
data_module.limit_datasets(32) # Limit to 32 examples per split for quick testing
smash_config.add_data(data_module)

# -------------------------------------------------------------------------
# 3. Configure quantizer parameters
# -------------------------------------------------------------------------
smash_config.load_dict(
    {
        "quantizer": "ganq",
        "ganq_weight_bits": 4,
        "ganq_max_epoch": 10,
        "ganq_pre_process": True,
    }
)
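
# Hyperparameter meanings (inferred from the backend code; see lut_quant.py):
#   ganq_weight_bits: bit-width k; each row's codebook holds 2^k entries
#   ganq_max_epoch:   number of optimization epochs per layer
#   ganq_pre_process: whether weights are normalized before fitting the codebook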

# -------------------------------------------------------------------------
# 4. Run Quantization
# -------------------------------------------------------------------------
quantizer = GANQQuantizer()

quantized_model = quantizer._apply(model, smash_config)

# -------------------------------------------------------------------------
# 5. Save the quantized model
# -------------------------------------------------------------------------
quantized_model.save_pretrained("./ganq_quantized_smollm")
tokenizer.save_pretrained("./ganq_quantized_smollm")

print("✅ GANQ quantization complete and model saved at ./ganq_quantized_smollm")


def model_size_in_mb(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_all_mb = (param_size + buffer_size) / 1024**2
    return size_all_mb


# Note: the backend writes the quantized values back in the layer's original
# dtype (see ganq.py), so this in-memory comparison may show little or no
# difference; the unique-values check below is the more reliable verification.
original_size = model_size_in_mb(model)
quantized_size = model_size_in_mb(quantized_model)
print(f"Original model size: {original_size:.2f} MB")
print(f"Quantized model size: {quantized_size:.2f} MB")

```


**Verify that quantization worked**

The logic here: GANQ learns a per-row codebook of shape (m, L) for a weight matrix of shape (m, n), where L = 2^k and k is the number of bits. Each row of the quantized weight matrix W may therefore contain only values from the corresponding row of the codebook, with the selection driven by the one-hot matrix S, so the number of unique values in each row of W should be exactly L (at most L, if some codebook entries go unused).
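
A minimal sketch of this one-hot selection (illustrative shapes only, assuming nothing about the backend's actual data layout):

```python
import torch

m, n, k = 4, 8, 2                  # rows, columns, bits (toy sizes)
L = 2 ** k                         # codebook entries per row
T = torch.randn(m, L)              # per-row codebook of shape (m, L)
idx = torch.randint(L, (m, n))     # which codebook entry each weight selects
S = torch.zeros(m, L, n).scatter_(1, idx.unsqueeze(1), 1.0)  # one-hot selectors
W_q = torch.einsum("ml,mln->mn", T, S)  # assemble the quantized weight matrix
# Every row of W_q draws only from its codebook row, so it has at most L values
assert all(torch.unique(W_q[i]).numel() <= L for i in range(m))
```

The actual check on the saved model: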

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
model.eval()

model_q = AutoModelForCausalLM.from_pretrained("ganq_quantized_smollm")


def verify_unique_entries_in_row(layer, row_idx=0):
    Wq = layer.self_attn.q_proj.weight.data
    unique_entries = torch.unique(Wq[row_idx])
    print(f"Number of unique entries in row {row_idx}: {unique_entries.numel()}")


verify_unique_entries_in_row(model_q.model.layers[1], row_idx=1)
verify_unique_entries_in_row(model.model.layers[1], row_idx=1)

# In my experiments, this printed:
# Number of unique entries in row 1: 16   (4-bit quantization, 2^4 = 16)
# Number of unique entries in row 1: 471  (original fp16 weights)
```
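
As an additional sanity check, you can run a short greedy generation with the quantized checkpoint (continuing from the snippet above; the prompt is arbitrary):

```python
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model_q.device)
with torch.no_grad():
    output_ids = model_q.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```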

**src/pruna/algorithms/quantization/backends/ganq/__init__.py** (27 additions)
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ganq import GANQ
from .lut_quant import LUTQuant
from .utils import (
    denormalize,
    find_layers,
    init_t_3bit,
    init_t_4bit,
    norm_params,
    normalize,
)

__all__ = [
    "GANQ",
    "LUTQuant",
    "init_t_3bit",
    "init_t_4bit",
    "normalize",
    "denormalize",
    "norm_params",
    "find_layers",
]
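
A quick smoke test that the package exports resolve (assuming pruna is installed from this branch):

```python
from pruna.algorithms.quantization.backends.ganq import GANQ, LUTQuant, find_layers

print(GANQ, LUTQuant, find_layers)
```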

**src/pruna/algorithms/quantization/backends/ganq/ganq.py** (92 additions)
# ruff: noqa: N806, N803, N802
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import transformers

from pruna.algorithms.quantization.backends.ganq.lut_quant import LUTQuant

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class GANQ:
    """GANQ class for quantizing neural network layers."""

    def __init__(self, layer, model_type):
        self.layer = layer
        self.dev = self.layer.weight.device
        self.model_type = model_type
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.XXt = torch.zeros((self.columns, self.columns), device=self.dev)

    def add_batch(self, inp, out):
        """Accumulate input statistics for quantization."""
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)

        if isinstance(self.layer, (nn.Linear, transformers.Conv1D)):
            # Note: the official implementation uses an `== 3` condition; see
            # https://github.com/Evans-Z/GANQ/blob/176a87701fd0e07aea1ccd4f3faff84871d79f44/ganq.py#L39
            if len(inp.shape) > 2:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()

        inp = inp.float()
        # Accumulate the (columns, columns) Gram matrix of calibration inputs,
        # which LUTQuant consumes as the second-order statistic for the fit.
        self.XXt += inp @ inp.T

    def fasterquant(
        self, sparsity=0.0, bits=4, max_epoch=10, pre_process=True, full_rows=0
    ):
        """Main function to perform GANQ quantization."""
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        quant = LUTQuant(
            bits=bits,
            W=W,
            XXt=self.XXt,
            max_epoch=max_epoch,
            sparsity=sparsity,
            model_type=self.model_type,
            pre_process=pre_process,
            full_rows=full_rows,
        )
        W = quant.quantization()

        torch.cuda.synchronize()

        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

    def free(self):
        """Free up memory."""
        self.XXt = None
        torch.cuda.empty_cache()
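
For orientation, a minimal sketch of how these methods are typically driven for a single linear layer (hypothetical orchestration; in Pruna this loop lives inside the quantizer, and `model`, `layer`, `calibration_batches`, and the `model_type` value are placeholders):

```python
ganq = GANQ(layer, model_type="llama")  # wrap one nn.Linear to quantize

# Collect input statistics by hooking the layer during calibration forwards
handle = layer.register_forward_hook(
    lambda module, inputs, output: ganq.add_batch(inputs[0].data, output.data)
)
with torch.no_grad():
    for batch in calibration_batches:
        model(batch)
handle.remove()

ganq.fasterquant(bits=4, max_epoch=10, pre_process=True)  # quantize in place
ganq.free()  # release the accumulated statistics
```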