Skip to content

Commit ec43205

Browse files
authored
Support AWQ format models (#1350)
* Support AWQ format models * Clippy fix
1 parent c116ce4 commit ec43205

File tree

19 files changed

+1786
-494
lines changed

19 files changed

+1786
-494
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis
6363
./mistralrs-server -i --isq 4 plain -m Qwen/Qwen3-8B
6464
```
6565
66+
- Run the **AWQ format** models
67+
Step 1: Convert AWQ model to marlin compatible format
68+
```
69+
python3 scripts/convert_awq_marlin.py --src /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4/ --dst /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4-Marlin/ --bits 4
70+
```
71+
Step 2: Run the converted model
72+
```
73+
./mistralrs-server -i plain -m /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4-Marlin/
74+
```
6675
6776
- 💎💎💎 Run the entire **Gemma 3** Model family (1b, 4b, 12b, 27b) with 128k context length and vision support: [documentation](docs/GEMMA3.md)
6877
@@ -152,6 +161,7 @@ Mistral.rs supports several model categories:
152161
- [Details](docs/QUANTS.md)
153162
- GGML: 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit, with imatrix support
154163
- GPTQ: 2-bit, 3-bit, 4-bit and 8-bit, with [Marlin](https://github.com/IST-DASLab/marlin) kernel support in 4-bit and 8-bit.
164+
- AWQ: 4-bit and 8-bit (convert using [script](scripts/convert_awq_marlin.py))
155165
- AFQ: 🔥 2-bit, 3-bit, 4-bit, 6-bit and 8-bit, designed to be fast on Metal!
156166
- HQQ: 4-bit and 8 bit, with ISQ support
157167
- FP8

mistralrs-core/src/speech_models/dia/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -458,10 +458,10 @@ impl DiaPipeline {
458458
if let Some(eos_countdown) = &mut eos_countdown {
459459
let step_after_eos = max_delay_pattern - *eos_countdown;
460460
for (i, d) in delay_pattern.iter().enumerate() {
461-
if step_after_eos == *d as usize {
462-
pred_c[i] = audio_eos_value;
463-
} else if step_after_eos > *d as usize {
464-
pred_c[i] = audio_pad_value;
461+
match step_after_eos.cmp(&(*d as usize)) {
462+
std::cmp::Ordering::Equal => pred_c[i] = audio_eos_value,
463+
std::cmp::Ordering::Greater => pred_c[i] = audio_pad_value,
464+
std::cmp::Ordering::Less => {}
465465
}
466466
}
467467
*eos_countdown -= 1;

mistralrs-quant/kernels/marlin/marlin/marlin.cuh

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,19 @@ namespace marlin {
1818
// than 1 warp per schedule allows some more latency hiding. At the same time,
1919
// we want relatively few warps to have many registers per warp and small tiles.
2020

21+
static constexpr int default_threads = 256;
2122
static constexpr int repack_threads = 256;
2223
static constexpr int repack_stages = 8;
2324
static constexpr int min_thread_n = 64;
2425
static constexpr int min_thread_k = 64;
25-
26+
static constexpr int max_thread_n = 256;
2627
static constexpr int tile_size = 16;
2728
static constexpr int max_par = 16;
2829
static constexpr int tile_k_size = tile_size;
2930
static constexpr int tile_n_size = tile_k_size * 4;
31+
static constexpr int pipe_stages = 4;
3032

31-
__device__ inline constexpr int ceildiv(int a, int b) {
32-
return (a + b - 1) / b;
33-
}
33+
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
3434

3535
// Predicated asynchronous global->shared copy; used for inputs A where we apply
3636
// predication to handle batchsizes that are not multiples of 16.
@@ -115,4 +115,13 @@ struct Vec {
115115

116116
using I4 = Vec<int, 4>;
117117

118+
enum ScalarTypeID {
119+
//gptq
120+
kU4B8,
121+
kU8B128,
122+
//awq
123+
kU4,
124+
kU8,
125+
};
126+
118127
} // namespace marlin

0 commit comments

Comments
 (0)