
Commit ecee3aa

support qqq(w4a8) for lmdeploy
1 parent c685f77 commit ecee3aa


44 files changed (+2891, −292 lines)

lmdeploy/cli/utils.py

Lines changed: 3 additions & 2 deletions

@@ -106,9 +106,10 @@ def model_format(parser, default: str = None):
             '--model-format',
             type=str,
             default=default,
-            choices=['hf', 'llama', 'awq'],
+            choices=['hf', 'llama', 'awq', 'qqq'],
             help='The format of input model. `hf` meaning `hf_llama`, `llama` '
-            'meaning `meta_llama`, `awq` meaning the quantized model by awq')
+            'meaning `meta_llama`, `awq` meaning the quantized model by awq, '
+            '`qqq` meaning the quantized model by qqq')

     @staticmethod
     def revision(parser, default: str = None):
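
For context, here is a minimal stand-alone sketch of the updated argument, using plain argparse instead of lmdeploy's ArgumentHelper wiring; the command-line values below are hypothetical and only illustrate selecting the new `qqq` format:

```python
import argparse

# Stand-alone mirror of the updated '--model-format' option; this is an
# illustration, not the real CLI plumbing in lmdeploy/cli/utils.py.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model-format',
    type=str,
    default=None,
    choices=['hf', 'llama', 'awq', 'qqq'],
    help='The format of input model. `hf` meaning `hf_llama`, `llama` '
    'meaning `meta_llama`, `awq` meaning the quantized model by awq, '
    '`qqq` meaning the quantized model by qqq')

# Hypothetical invocation selecting a QQQ-quantized checkpoint.
args = parser.parse_args(['--model-format', 'qqq'])
assert args.model_format == 'qqq'
```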

lmdeploy/messages.py

Lines changed: 3 additions & 2 deletions

@@ -115,8 +115,9 @@ class TurbomindEngineConfig:
     """TurboMind Engine config.

     Args:
-        model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq],
-            `hf` meaning huggingface model(.bin, .safetensors), `meta_llama` being meta llama's format(.pth), awq` meaning the quantized model by AWQ.
+        model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq, qqq],
+            `hf` meaning huggingface model(.bin, .safetensors), `meta_llama` being meta llama's format(.pth), `awq` meaning the quantized model by AWQ,
+            `qqq` meaning the quantized model by QQQ.
         tp (int): the number of GPU cards used in tensor parallelism, default to 1
         session_len (int): the max session length of a sequence, default to None
         max_batch_size (int): the max batch size during inference, default to 128
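
With the docstring extended, the new format is selected the same way as `awq`; a minimal sketch, assuming a hypothetical QQQ-quantized checkpoint at `./llama-2-7b-qqq` and lmdeploy's standard pipeline API:

```python
from lmdeploy import TurbomindEngineConfig, pipeline

# './llama-2-7b-qqq' is a hypothetical path to a model quantized with
# https://github.com/HandH1998/QQQ; model_format='qqq' tells the TurboMind
# backend to expect the W4A8 (QQQ) weight layout.
backend_config = TurbomindEngineConfig(model_format='qqq', tp=1)
pipe = pipeline('./llama-2-7b-qqq', backend_config=backend_config)
print(pipe('Hello, world!'))
```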

lmdeploy/turbomind/deploy/converter.py

Lines changed: 25 additions & 11 deletions

@@ -15,7 +15,7 @@
 from .source_model.base import INPUT_MODELS
 from .target_model.base import OUTPUT_MODELS, TurbomindModelConfig

-SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', None]
+SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', 'qqq', None]
 logger = get_logger('lmdeploy')


@@ -26,12 +26,14 @@ def get_input_model_registered_name(model_path: str, model_format: str):
     Args:
         model_path (str): the path of the input model
         model_format (str): the format of the model, which can be one of
-            ['meta_llama', 'hf', 'awq']
+            ['meta_llama', 'hf', 'awq', 'qqq']
     """
     arch = get_model_arch(model_path)[0]
     register_name = SUPPORTED_ARCHS[arch]
     if model_format == 'awq':
         register_name = register_name + '-awq'
+    elif model_format == 'qqq':
+        register_name = register_name + '-qqq'
     return register_name


@@ -92,8 +94,9 @@ def get_output_model_registered_name_and_config(model_path: str,
     Args:
         model_path (str): the path of the input model
         model_format (str): the format of the model, which can be one of
-            ['meta_llama', 'hf', 'awq']
-        group_size (int): the size of group used by awq model
+            ['meta_llama', 'hf', 'awq', 'qqq']
+        group_size (int): the size of group used by quantization methods,
+            including `awq` and `qqq`
     """
     register_name = 'fp16'
     turbomind_model_arch = 'llama'
@@ -113,6 +116,15 @@ def get_output_model_registered_name_and_config(model_path: str,
         register_name = 'plora-w4' \
             if turbomind_model_arch == 'xcomposer2' else 'w4'
         group_size = 128 if group_size == 0 else group_size
+        config.quantization = 'awq'
+    elif model_format == 'qqq':
+        weight_type = 'int4'
+        register_name = 'qqq-w4'
+        from transformers import AutoConfig
+        quant_config = AutoConfig.from_pretrained(
+            model_path).quantization_config
+        group_size = quant_config['group_size']
+        config.quantization = 'qqq'
     else:
         torch_dtype = getattr(model_config, 'torch_dtype', 'float16')
         TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16', torch.float16: 'fp16'}
@@ -212,17 +224,19 @@ def main(model_name: str,
         model_name (str): unused any longer
         model_path (str): the directory path of the model
         model_format (str): the format of the model, should choose from
-            ['meta_llama', 'hf', 'awq', None]. 'meta_llama' stands for META's
-            llama format, 'hf' means huggingface llama format, and 'awq' means
-            llama(hf) model quantized by lmdeploy/lite/quantization/awq.py.
-            The default value is None
-        chat_template (str): the name of the built-in chat template.
+            ['meta_llama', 'hf', 'awq', 'qqq', None]. 'meta_llama' stands for
+            META's llama format, 'hf' means huggingface llama format,
+            'awq' means llama(hf) model quantized by
+            lmdeploy/lite/quantization/awq.py,
+            and 'qqq' means llama(hf) model quantized by the repo
+            https://github.com/HandH1998/QQQ,
+            the default value is None
         tokenizer_path (str): the path of tokenizer model
         dst_path (str): the destination path that saves outputs
         tp (int): the number of GPUs used for tensor parallelism, should be 2^n
         quant_path (str): Path of the quantized model, which can be None.
-        group_size (int): a parameter used in AWQ to quantize fp16 weights
-            to 4 bits
+        group_size (int): a parameter used in AWQ or QQQ to quantize fp16
+            weights to 4 bits
         revision (str): The specific model version to use. It can be a branch
             name, a tag name, or a commit id. If unspecified, will use
             the default version.
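
A sketch of how the converter helpers above resolve a QQQ checkpoint, assuming a hypothetical model path whose `quantization_config` carries at least a `group_size` entry, as read in the diff:

```python
from transformers import AutoConfig

from lmdeploy.turbomind.deploy.converter import (
    get_input_model_registered_name, get_output_model_registered_name_and_config)

model_path = './llama-2-7b-qqq'  # hypothetical QQQ-quantized llama checkpoint

# For a llama-architecture model, '-qqq' is appended to the registered input
# name, selecting the LlamaQQQModel reader added below.
print(get_input_model_registered_name(model_path, model_format='qqq'))  # llama-qqq

# On the output side, 'qqq-w4' is registered and group_size comes from the
# checkpoint's quantization_config rather than the CLI argument.
print(AutoConfig.from_pretrained(model_path).quantization_config['group_size'])
print(get_output_model_registered_name_and_config(model_path,
                                                  model_format='qqq',
                                                  group_size=0))
```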

lmdeploy/turbomind/deploy/source_model/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
 from .internvl import InternVLModel  # noqa: F401
 from .llama import LlamaModel  # noqa: F401
 from .llama_awq import LlamaAwqModel  # noqa: F401
+from .llama_qqq import LlamaQQQModel  # noqa: F401
 from .meta_llama import MetaLlamaModel  # noqa: F401
 from .minicpmv import MiniCPMVModel  # noqa: F401
 from .minicpmv_awq import MiniCPMVAwqModel  # noqa: F401

lmdeploy/turbomind/deploy/source_model/llama_qqq.py

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .base import INPUT_MODELS
+from .llama import LlamaModel, LlamaReader
+
+
+def ensure_dtype(tensors: torch.Tensor, dtype: torch.dtype):
+    """Ensure tensors are in the specified dtype."""
+    result = []
+    for tensor in tensors:
+        if tensor is not None and tensor.numel() > 0:
+            if tensor.dtype in [torch.float16, torch.float32, torch.bfloat16]:
+                result.append(tensor.to(dtype))
+            else:
+                assert tensor.dtype == torch.int32
+                result.append(tensor)
+        else:
+            result.append(None)
+    return (*result, )
+
+
+class LlamaQQQReader(LlamaReader):
+    """LlamaQQQReader."""
+
+    def __init__(self, new_params: dict, unused_params: dict, last_bin: bool,
+                 model_cfg: dict):
+        super().__init__(new_params, unused_params, last_bin, model_cfg)
+
+    def attn(self, i: int):
+        """Get q, k, v, o qweight for layer i."""
+        return ensure_dtype(self._attn(i, 'B'), torch.int32)
+
+    def attn_scale_group(self, i: int):
+        """Get q, k, v, o per-group scales for layer i."""
+        return ensure_dtype(self._attn(i, 's_group'), torch.float16)
+
+    def attn_scale_channel(self, i: int):
+        """Get q, k, v, o per-channel scales for layer i."""
+        return ensure_dtype(self._attn(i, 's_channel'), torch.float32)
+
+    def ffn(self, i: int):
+        """Get ffn qweight for layer i."""
+        return ensure_dtype(self._ffn(i, 'B'), torch.int32)
+
+    def ffn_scale_group(self, i: int):
+        """Get ffn per-group scales for layer i."""
+        return ensure_dtype(self._ffn(i, 's_group'), torch.float16)
+
+    def ffn_scale_channel(self, i: int):
+        """Get ffn per-channel scales for layer i."""
+        return ensure_dtype(self._ffn(i, 's_channel'), torch.float32)
+
+
+@INPUT_MODELS.register_module(name='llama-qqq')
+class LlamaQQQModel(LlamaModel):
+    """Llama QQQ model in hf format."""
+
+    Reader = LlamaQQQReader
+
+    def __init__(self,
+                 model_path: str,
+                 tokenizer_path: str,
+                 ckpt_path: str = None,
+                 **kwargs):
+        super().__init__(model_path,
+                         tokenizer_path,
+                         ckpt_path=ckpt_path,
+                         **kwargs)
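
The `ensure_dtype` helper above leaves packed int32 qweights untouched, casts floating-point scales to the requested dtype, and passes `None`/empty entries through; a small illustration with dummy tensors (shapes are arbitrary):

```python
import torch

from lmdeploy.turbomind.deploy.source_model.llama_qqq import ensure_dtype

qweight = torch.zeros(128, 32, dtype=torch.int32)  # packed QQQ weights
s_group = torch.rand(4, 256)                       # float32 per-group scales
missing = None                                     # e.g. an absent tensor

w, sg, none_out = ensure_dtype((qweight, s_group, missing), torch.float16)
assert w.dtype == torch.int32      # int32 tensors are kept as-is
assert sg.dtype == torch.float16   # float tensors are cast
assert none_out is None            # None entries pass through unchanged
```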

lmdeploy/turbomind/deploy/target_model/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -2,4 +2,5 @@
 from .fp import TurbomindModel  # noqa: F401
 from .plora import TurbomindPloraModel  # noqa: F401
 from .plora_w4 import TurbomindPloraW4Model  # noqa: F401
+from .qqq_w4 import TurbomindQQQW4Model  # noqa: F401
 from .w4 import TurbomindW4Model  # noqa: F401

lmdeploy/turbomind/deploy/target_model/base.py

Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@ class TurbomindModelConfig:
     max_prefill_iters: int = 1
     use_context_fmha: int = 1
     quant_policy: int = 0
+    quantization: str = ''
     max_position_embeddings: int = 0
     original_max_position_embeddings: int = 0
     rope_scaling_type: str = ''
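
A minimal sketch of the new field, assuming `TurbomindModelConfig` can be instantiated from its dataclass defaults (as the defaulted fields above suggest); in practice the converter fills it in with 'awq' or 'qqq':

```python
from lmdeploy.turbomind.deploy.target_model.base import TurbomindModelConfig

# The field defaults to an empty string; the converter sets it so downstream
# code can tell which quantization scheme produced the exported weights.
config = TurbomindModelConfig()
assert config.quantization == ''
config.quantization = 'qqq'  # what the converter does for QQQ checkpoints
```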
