Merged
Commits
138 commits
d8884e6
Refactor BatchMatMulEmitter and BatchMatMulSelector for improved read…
LeiWang1999 Jul 5, 2024
fc84173
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
02f64de
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
397eee6
disable failure email for ci
LeiWang1999 Jul 5, 2024
20f6ad1
remove email notifications.
LeiWang1999 Jul 6, 2024
b93c394
move relax pass from testing to mlc_llm
LeiWang1999 Jul 6, 2024
ba6a6df
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
257693a
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
9bb7f49
Lint Fix
LeiWang1999 Jul 6, 2024
39e7614
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
93eb5a5
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
aa66a90
bug fix in test
LeiWang1999 Jul 6, 2024
ae14a53
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 6, 2024
79b08e4
lint fix.
LeiWang1999 Jul 6, 2024
86fd036
test cuda i4 kernel
LeiWang1999 Jul 7, 2024
6b73a21
Refactor copyright notice in i4matmul.hpp
LeiWang1999 Jul 7, 2024
0ba90c1
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 7, 2024
086d208
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 7, 2024
47a3abd
refactor test as version below python 3.9 cannot handle int32 overflow.
LeiWang1999 Jul 8, 2024
024b247
format lint for test
LeiWang1999 Jul 8, 2024
bfedeaa
Refactor test_int4b_fp16_convert.py for improved readability and main…
LeiWang1999 Jul 8, 2024
e672a23
remove unused design file
LeiWang1999 Jul 8, 2024
21e5430
move tile device from package to base
LeiWang1999 Jul 8, 2024
fd11940
dummy impl for codegen
LeiWang1999 Jul 8, 2024
9ccfa85
Refactor file structure for ladder_permutate module
LeiWang1999 Jul 8, 2024
7c7d73e
Refactor backend class and fix typos in comments
LeiWang1999 Jul 8, 2024
47d5fc5
Deep refactor Lib related code.
LeiWang1999 Jul 8, 2024
53dd0dd
remove ci pull.
LeiWang1999 Jul 10, 2024
d58ac43
LintFix
LeiWang1999 Jul 10, 2024
37cb07c
refactor builder for whl build
LeiWang1999 Jul 10, 2024
f5b9999
Refactor TIRWrapper.wrap() method to include an assertion for the opt…
LeiWang1999 Jul 11, 2024
fb78244
Refactor lib_generator to set library and source paths
LeiWang1999 Jul 11, 2024
706e227
lint fix
LeiWang1999 Jul 11, 2024
63f5515
BitNet vllm integration
LeiWang1999 Jul 16, 2024
de91c0d
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 16, 2024
b9655fd
chore: update codespell to version 2.3.0
LeiWang1999 Jul 16, 2024
fff385f
Lintfix
LeiWang1999 Jul 16, 2024
72a98e7
Bump version to 0.0.1.dev13
LeiWang1999 Jul 18, 2024
5646ab5
lint fix
LeiWang1999 Jul 18, 2024
b965863
disable fast decoding [u]int4xint8 by default.
LeiWang1999 Jul 21, 2024
1198fc7
optimize from dict design in Hint
LeiWang1999 Jul 21, 2024
014213c
Implement SplitK
LeiWang1999 Jul 21, 2024
e0ca752
bitnet benchmark generation.
LeiWang1999 Jul 21, 2024
81b9cf0
Add benchmark script for BitNet integration
LeiWang1999 Jul 21, 2024
02edc0b
AtomicAdd Support
LeiWang1999 Jul 21, 2024
1a70c2d
LintFix
LeiWang1999 Jul 21, 2024
28d851c
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 21, 2024
c447a95
ci fix when 3rdparty tvm is initialized.
LeiWang1999 Jul 21, 2024
79a001b
bug fix for setup
LeiWang1999 Jul 21, 2024
31813b2
fix a bug in block reduce
LeiWang1999 Jul 21, 2024
78b6a3d
typo fix
LeiWang1999 Jul 21, 2024
9c55218
BUG Fix for block reduce.
LeiWang1999 Jul 22, 2024
1aa8868
Lint fix
LeiWang1999 Jul 22, 2024
22f70bf
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 22, 2024
5f082a5
Refactor block reduce schedule template
LeiWang1999 Jul 22, 2024
b4fb31e
transform branch from bitblas to bitblas_tl
LeiWang1999 Jul 22, 2024
35eaa00
Fix subproject commit reference in 3rdparty/tvm
LeiWang1999 Jul 22, 2024
254dd74
chore: update submodule branch from bitblas to bitblas_tl
LeiWang1999 Jul 22, 2024
31a44aa
force update config.cmake
LeiWang1999 Jul 22, 2024
427800e
Bug fix
LeiWang1999 Jul 22, 2024
96db111
Fix subproject commit reference in 3rdparty/cutlass
LeiWang1999 Jul 22, 2024
38b251a
chore: Add submodule for cutlass library
LeiWang1999 Jul 22, 2024
87d1c5a
update tl cutlass path
LeiWang1999 Jul 22, 2024
6200b1e
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 22, 2024
0ffe0b5
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 22, 2024
8e08e77
format fix
LeiWang1999 Jul 22, 2024
df05a64
Copy CUTLASS to the package directory
LeiWang1999 Jul 22, 2024
4f529c5
Refactor setup.py to include additional TVM header files
LeiWang1999 Jul 22, 2024
d02bbc7
lint fix
LeiWang1999 Jul 23, 2024
cffe3fd
bug fix
LeiWang1999 Jul 23, 2024
a8bed74
Refactor BitBLASLinear test module for improved readability and maint…
LeiWang1999 Jul 23, 2024
d4eb5fd
Implement Matmul Benchmark Design
LeiWang1999 Jul 23, 2024
4c6c2c1
chore: Update BitBLAS Matmul benchmark script
LeiWang1999 Jul 23, 2024
0acaca1
lint fix
LeiWang1999 Jul 23, 2024
54d2227
Refactor BitBLASMatmulOpsBenchmark for improved readability and maint…
LeiWang1999 Jul 23, 2024
c2edefb
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
e0bc723
lint fix
LeiWang1999 Jul 23, 2024
a4e68d1
Benchmark bot test
LeiWang1999 Jul 23, 2024
df7e9aa
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
1c03365
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
4f319fc
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
a8833d4
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
803f6c6
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
df4572b
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
45ded45
int8 test case
LeiWang1999 Jul 23, 2024
4229676
Refactor compare_benchmark.py to handle missing benchmark results gra…
LeiWang1999 Jul 23, 2024
b883290
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
476ffee
ci fix
LeiWang1999 Jul 23, 2024
9bd34ff
disable ci for test benchmark
LeiWang1999 Jul 23, 2024
e86f4b2
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 23, 2024
75f3dd9
Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark…
LeiWang1999 Jul 23, 2024
79e04aa
remove cli installation
LeiWang1999 Jul 23, 2024
cdd3345
chore: Create virtual environment and install dependencies for benchmark
LeiWang1999 Jul 23, 2024
f099938
Merge branch 'main' into dev
LeiWang1999 Jul 23, 2024
f211ad4
chore: Update benchmark workflow to include comparison step
LeiWang1999 Jul 23, 2024
ddde02a
Lint fix
LeiWang1999 Jul 24, 2024
8045ce9
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 24, 2024
21aee89
Merge branch 'dev' of https://github.com/LeiWang1999/MSBitBLAS into dev
LeiWang1999 Jul 24, 2024
ef1b158
upodate tvm cmmit
LeiWang1999 Jul 25, 2024
a8d8841
Imporve lower warp memory pass
LeiWang1999 Jul 30, 2024
686b929
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 30, 2024
7736c38
Bug fix
LeiWang1999 Jul 30, 2024
199affc
Enhance to support warp schedule.
LeiWang1999 Jul 31, 2024
9d0c25d
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 31, 2024
7c1f52e
Enhance LOP3 Instructions
LeiWang1999 Jul 31, 2024
d1b2bc7
Enhance LOP3 Instructions
LeiWang1999 Jul 31, 2024
2aac6d0
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Jul 31, 2024
802abde
add test for stage3 propagate
LeiWang1999 Jul 31, 2024
d339037
implement propagate func
LeiWang1999 Jul 31, 2024
0f6a033
Stage3 Ladder Permutate integration
LeiWang1999 Jul 31, 2024
00ec916
get_ladder_stage3_propagate
LeiWang1999 Jul 31, 2024
5316577
comments benchmark scirpts as the setting is too big
LeiWang1999 Jul 31, 2024
dd070f9
ci fix for benchmark
LeiWang1999 Jul 31, 2024
6fcc368
lint fix
LeiWang1999 Jul 31, 2024
705580b
chore: Update benchmark workflow to trigger on pull request comments
LeiWang1999 Jul 31, 2024
c5ba940
Add LDMatrix Transform 3
LeiWang1999 Aug 1, 2024
1566990
Support GPTQ Test
LeiWang1999 Aug 1, 2024
c6c70ef
Fuse BlockReduce Schedule
LeiWang1999 Aug 1, 2024
36128f3
Support mma propagate 3
LeiWang1999 Aug 1, 2024
23ff5f4
Support MMA Propagate Stage 3
LeiWang1999 Aug 1, 2024
de3bf08
Lint Fix
LeiWang1999 Aug 1, 2024
d9830ba
Merge block reduce for dequantze config.
LeiWang1999 Aug 1, 2024
e5a4485
fix codeql
LeiWang1999 Aug 2, 2024
a04282b
chore: Update submodule reference to latest commit
LeiWang1999 Aug 4, 2024
314d3e9
chore: Disable common subexpression elimination in TIR passes
LeiWang1999 Aug 4, 2024
f7d33bb
Lint Fix
LeiWang1999 Aug 4, 2024
db633ed
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Aug 4, 2024
201155a
4bit related lop3 updates.
LeiWang1999 Aug 4, 2024
2b73662
lint fix
LeiWang1999 Aug 4, 2024
1a6a0fd
gptq test fix
LeiWang1999 Aug 4, 2024
e84e3ef
Fix for test
LeiWang1999 Aug 4, 2024
f0fbb55
lint fix
LeiWang1999 Aug 4, 2024
bf30688
lint fix
LeiWang1999 Aug 4, 2024
9a360ba
typofix
LeiWang1999 Aug 4, 2024
ee94536
QuantCompress Test
LeiWang1999 Aug 5, 2024
930cd76
chore: Refactor quant_compress_impl.py for readability and maintainab…
LeiWang1999 Aug 5, 2024
8c24776
Enhance docs to update latest works.
LeiWang1999 Aug 5, 2024
c018e3c
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into dev
LeiWang1999 Aug 5, 2024
313 changes: 305 additions & 8 deletions README.md
@@ -15,13 +15,14 @@ Some of the key features of BitBLAS include:

## Latest News

- 04/19/2024 ✨: We are excited to announce that BitBLAS, a high-performance library for mixed-precision DNN model deployment, is now open source and available to the public!
- 04/30/2024 🚀🚀: BitBLAS now supports FP8 TensorCore (E5M2/E4M3 * E4M3/E5M2), providing more combinations beyond the three available in cuBLAS!
- 05/04/2024 🚀🚀: We’ve added integration examples for the 1.58-bit model! Check out the files under integration/BitNet.
- 06/25/2024 🚀🚀: BitBLAS has been integrated into GPTQModel! You can now use BitBLAS as a backend in GPTQ.

## Integration Example of FasterTransformer with BitBLAS
![FasterTransformer Integration](images/gif/FasterTransformer.gif)


## Benchmark Summary

BitBLAS achieves exceptional performance across a variety of computational patterns. Below are selected results showcasing its capabilities:
@@ -74,15 +75,311 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and

We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR.

## Installation Guide

### Prerequisites

- **Operating System**: Linux (Ubuntu 20.04 or later is recommended for installation via wheel or PyPI; for other Linux distributions, see the [Building from Source](#building-from-source) section.)
- **Python Version**: >= 3.7
- **CUDA Version**: >= 10.0
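
As a quick sanity check of these prerequisites, something like the following can be run (a minimal sketch; it assumes `nvcc` is on your `PATH`, and note that the prebuilt wheels described below require Python >= 3.8 and CUDA >= 12.1):

```python
import subprocess
import sys

# Python interpreter version (source builds accept >= 3.7, wheels need >= 3.8)
print(sys.version_info)

# CUDA toolkit version, as reported by nvcc
print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)
```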

### Installing with pip

The easiest way to install BitBLAS is directly from PyPI using pip. To install the latest version, run the following command in your terminal.

**Note**: Currently, the bitblas wheel is supported only on Linux systems. We recommend using Ubuntu 20.04 or a later version, as the wheels are built on this platform. We currently provide wheels only for CUDA >= 12.1 and Python >= 3.8; if you are using a different CUDA version, you may need to build BitBLAS from source.

```bash
pip install bitblas
```

Alternatively, you may choose to install BitBLAS using prebuilt packages available on the Release Page:

```bash
pip install bitblas-0.0.0.dev0+ubuntu.20.4.cu120-py3-none-any.whl
```

After installing BitBLAS, you can verify the installation by running:

```bash
python -c "import bitblas; print(bitblas.__version__)"
```
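
Since BitBLAS targets NVIDIA GPUs, it can also be worth confirming that PyTorch sees a CUDA device (an optional check, assuming PyTorch is installed):

```python
import torch

# Both calls should succeed on a correctly configured machine.
print(torch.cuda.is_available())      # expect: True
print(torch.cuda.get_device_name(0))  # the GPU that BitBLAS kernels will run on
```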

### Building from Source

We recommend using a Docker container with the necessary dependencies to build BitBLAS from source. You can start such a container with the following command:

```bash
docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.01-py3
```

To build and install BitBLAS directly from source, follow the steps below. This process requires certain prerequisites from Apache TVM, which can be installed on Ubuntu/Debian-based systems using the following commands:

```bash
sudo apt-get update
sudo apt-get install -y python3 python3-dev python3-setuptools gcc libtinfo-dev zlib1g-dev build-essential cmake libedit-dev libxml2-dev
```

After installing the prerequisites, you can clone the BitBLAS repository and install it using pip:

```bash
git clone --recursive https://github.com/Microsoft/BitBLAS.git
cd BitBLAS
pip install . # Please be patient, this may take some time.
```

If you want to install BitBLAS in development mode, run the following command:

```bash
pip install -e .
```

## Quick Start

BitBLAS provides two Python APIs to perform mixed-precision matrix multiplication:
- ```bitblas.Matmul``` implements the $W_{wdtype}A_{adtype}$ mixed-precision matrix multiplication of $C_{cdtype}[M, N] = A_{adtype}[M, K] \times W_{wdtype}[N, K]$, where $W_{wdtype}$ denotes the weight with dtype $wdtype$, $A_{adtype}$ the activation with dtype $adtype$, and $C_{cdtype}$ the output with dtype $cdtype$.
- ```bitblas.Linear``` is a PyTorch ```nn.Linear```-like module to support mixed-precision linear layers.
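
For reference, the computation both APIs implement corresponds to this plain-PyTorch sketch (fp16 only, ignoring quantization; the shapes are illustrative):

```python
import torch

# C[M, N] = A[M, K] x W[N, K] in the "nt" layout:
# A is non-transposed; W is stored as (N, K) and transposed back here.
M, N, K = 1, 1024, 1024
A = torch.rand(M, K, dtype=torch.float16, device="cuda")
W = torch.rand(N, K, dtype=torch.float16, device="cuda")
C = A @ W.T  # shape (M, N)
```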

### Example: $W_{INT4}A_{FP16}$ mixed-precision matrix multiplication

Here is an example of a $W_{INT4}A_{FP16}$ mixed-precision matrix multiplication: $out_{FP16}[M, N] = A_{FP16}[M, K] \times W_{INT4}[N, K]$. The example includes the creation of input matrices, quantization of the weight matrix, and execution of the multiplication. The result is then compared against a reference result obtained through conventional methods to ensure accuracy.

```python
import bitblas
import torch

# enabling debug output
bitblas.set_log_level("Debug")

matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=1024,  # N dimension
    K=1024,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="int4",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how to calculate zeros
)

matmul = bitblas.Matmul(config=matmul_config)

# Create input matrices
input_tensor = torch.rand((1, 1024), dtype=torch.float16).cuda()
weight_tensor = torch.randint(0, 7, (1024, 1024), dtype=torch.int8).cuda()

# Transform weight tensor to int4 data type
weight_tensor_int4 = matmul.transform_weight(weight_tensor)

# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int4)

# Reference result using PyTorch matmul for comparison
ref_result = torch.matmul(input_tensor, weight_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance; note that the random int4 weights are a little larger in magnitude than the float16 inputs, so we set atol to 1.0
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
```

The same example can be extended to quantize the weight tensor with scaling factors and zeros. The following code snippet demonstrates how to do so and then execute the mixed-precision matrix multiplication.

```python
import bitblas
import torch

in_features = 1024
out_features = 1024
group_size = 128

matmul_config = bitblas.MatmulConfig(
    M=1,  # M dimension
    N=out_features,  # N dimension
    K=in_features,  # K dimension
    A_dtype="float16",  # activation A dtype
    W_dtype="uint4",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=group_size,  # setting for grouped quantization
    with_scaling=True,  # setting for scaling factor
    with_zeros=True,  # setting for zeros
    zeros_mode="original",  # setting for how to calculate zeros
)
matmul = bitblas.Matmul(config=matmul_config)

# Define shapes for tensors
input_shape = (1, 1024)
weight_shape = (1024, 1024)
scaling_shape = (1024, 1024 // 128)
zeros_shape = (1024, 1024 // 128)
output_shape = (1, 1024)

# Create scaling and zeros tensors for quantization
scaling = torch.rand(scaling_shape, dtype=torch.float16).cuda()
zeros = torch.rand(zeros_shape, dtype=torch.float16).cuda()

# Create input tensor
input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda()

# Create and transform weight tensor
weight_tensor = torch.randint(0, 7, weight_shape, dtype=torch.int8).cuda()
weight_tensor_int4 = matmul.transform_weight(weight_tensor)

# Perform mixed-precision matrix multiplication with quantization
output_tensor = matmul(input_tensor, weight_tensor_int4, scale=scaling, zeros=zeros)

rescaling_tensor = torch.zeros_like(weight_tensor, dtype=torch.float16).cuda()
# Compute reference result with manual scaling and zero-point adjustment
# rescale = (weight - zeros) * scaling
for i in range(in_features // group_size):
    for j in range(group_size):
        rescaling_tensor[:, i * group_size + j] = (
            weight_tensor[:, i * group_size + j].to(torch.float16) - zeros[:, i]
        ) * scaling[:, i]
ref_result = torch.matmul(input_tensor, rescaling_tensor.t().to(torch.float16))
# Assert that the results are close within a specified tolerance
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-2)
```
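
The nested loop above spells out the per-group dequantization one column at a time; a vectorized equivalent (a sketch that reuses the tensors defined above and assumes `in_features` is an exact multiple of `group_size`) is:

```python
# Expand each per-group scale/zero across its group_size columns,
# then dequantize the whole weight matrix in one shot.
scaling_expanded = scaling.repeat_interleave(group_size, dim=1)
zeros_expanded = zeros.repeat_interleave(group_size, dim=1)
rescaling_vectorized = (weight_tensor.to(torch.float16) - zeros_expanded) * scaling_expanded
```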

The init stage of the ```bitblas.Matmul``` class can take a few minutes to finish, as it uses hardware information to perform a one-time kernel library initialization.
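
If you want to see how long this one-time initialization takes on your machine, a minimal timing sketch (reusing the `matmul_config` from the example above) is:

```python
import time

start = time.perf_counter()
matmul = bitblas.Matmul(config=matmul_config)  # one-time kernel tuning and initialization
print(f"Matmul init took {time.perf_counter() - start:.1f} s")
```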

### Example: bitblas.Linear module for PyTorch

BitBLAS also provides a PyTorch ```nn.Linear```-like module, ```bitblas.Linear```, to support mixed-precision linear layers. See the code [implementation](../python/bitblas/module/__init__.py).

Here is an example that defines a ```bitblas.Linear``` of $W_{INT4}A_{FP16}$:

```python
import bitblas
import torch

# enabling debug output
bitblas.set_log_level("Debug")

model = bitblas.Linear(
    in_features=1024,
    out_features=1024,
    bias=False,
    A_dtype="float16",  # activation A dtype
    W_dtype="int4",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how to calculate zeros
    # Target optimization var for dynamic symbolic.
    # For detailed information please checkout docs/PythonAPI.md
    # By default, the optimization var is [1, 16, 32, 64, 128, 256, 512]
    opt_M=[1, 16, 32, 64, 128],
)

# Create an integer weight tensor
intweight = torch.randint(-7, 7, (1024, 1024), dtype=torch.int8)

# Load and transform weights into the BitBLAS linear module
model.load_and_transform_weight(intweight)

# Save the state of the model
torch.save(model.state_dict(), "./model.pth")

# Load the model state
model.load_state_dict(torch.load("./model.pth"))

# Set the model to evaluation mode
model.eval()

# Create a dummy input tensor
dummy_input = torch.randn(1, 1024, dtype=torch.float16)

# Perform inference
output = model(dummy_input)
print("BitBLAS output:", output)
# Please checkout the correctness evaluation code in `testing/python/module/test_bitblas_linear.py`
```

We also provide a repack interface to repack pretrained AutoGPTQ weights into the BitBLAS format. Here is an example:

```python
# !pip install auto-gptq
import bitblas
import torch
from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import (
    QuantLinear as CudaOldQuantLinear,
)

# enabling debug output
bitblas.set_log_level("Debug")

in_features = 1024
out_features = 1024
group_size = 128

original_w, linear, s, qw = bitblas.quantization.gen_quant4(
    in_features, out_features, group_size
)
zeros = torch.full((in_features // group_size, out_features), 7, dtype=torch.int32)

cuda_old_linear = CudaOldQuantLinear(
    bits=4,
    group_size=group_size,
    infeatures=in_features,
    outfeatures=out_features,
    bias=False,
)
cuda_old_linear.pack(linear, s.T, zeros.T, g_idx=None)

bitblas_linear = bitblas.Linear(
    in_features=in_features,
    out_features=out_features,
    bias=False,
    A_dtype="float16",  # activation A dtype
    W_dtype="uint4",  # weight W dtype
    accum_dtype="float16",  # accumulation dtype
    out_dtype="float16",  # output dtype
    # configs for weight only quantization
    group_size=group_size,  # setting for grouped quantization
    with_scaling=True,  # setting for scaling factor
    with_zeros=True,  # setting for zeros
    zeros_mode="quantized",  # setting for how to calculate zeros
)
# Repack weights from CudaOldQuantLinear to BitBLAS linear module
bitblas_linear.repack_from_gptq(cuda_old_linear)

# Prepare input data
m = 1 # Batch size
inp = torch.rand(m, in_features, dtype=torch.float16, device="cuda")

# Move models to CUDA for execution
cuda_old_linear = cuda_old_linear.to("cuda")
bitblas_linear = bitblas_linear.to("cuda")

# Perform inference without gradient calculations
with torch.no_grad():
    res_cuda_old = cuda_old_linear(inp)
    res_bitblas = bitblas_linear(inp)

print("CudaOldQuantLinear output:", res_cuda_old)
print("BitBLAS output:", res_bitblas)

# Verify the outputs are close within specified tolerances
torch.testing.assert_close(res_bitblas, res_cuda_old, rtol=1e-0, atol=1e-1)
```
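
To quantify the agreement rather than only assert it, simple error statistics can also be printed (an optional sketch reusing the outputs above):

```python
# Maximum and mean absolute difference between the two implementations.
abs_diff = (res_bitblas.float() - res_cuda_old.float()).abs()
print("max abs diff:", abs_diff.max().item())
print("mean abs diff:", abs_diff.mean().item())
```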

## Other Documents

- [Python API](https://github.com/microsoft/BitBLAS/blob/main/docs/PythonAPI.md): The Python API doc of BitBLAS.

- [Integration](https://github.com/microsoft/BitBLAS/tree/main/integration): Explore how BitBLAS seamlessly integrates with LLM deployment frameworks through our examples. Discover the ease of integrating BitBLAS with PyTorch, AutoGPTQ, and vLLM in the 3rd-party integration examples.