
Commit 9929e7c

Load INC GPTQ checkpoint & rename params #1364
1 parent b87d80e commit 9929e7c

4 files changed: 156 additions & 15 deletions

File tree

- examples/text-generation/README.md
- examples/text-generation/run_generation.py
- examples/text-generation/utils.py
- tests/test_text_generation_example.py

examples/text-generation/README.md

Lines changed: 61 additions & 2 deletions
@@ -486,7 +486,7 @@ python run_generation.py \
 
 ### Loading 4 Bit Checkpoints from Hugging Face
 
-You can load pre-quantized 4bit models with the argument `--load_quantized_model`.
+You can load pre-quantized 4bit models with the argument `--load_quantized_model_with_inc`.
 Currently, uint4 checkpoints and single device are supported.
 More information on enabling 4 bit inference in SynapseAI is available here:
 https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_UINT4.html.
@@ -508,7 +508,35 @@ python run_lm_eval.py \
 --attn_softmax_bf16 \
 --bucket_size=128 \
 --bucket_internal \
---load_quantized_model
+--load_quantized_model_with_inc
+```
+
+### Loading 4 Bit Checkpoints from Neural Compressor (INC)
+
+You can load a pre-quantized 4-bit checkpoint with the argument `--quantized_inc_model_path`, supplied with the original model with the argument `--model_name_or_path`.
+Currently, only uint4 checkpoints and single-device configurations are supported.
+**Note:** In this process, you can load a checkpoint that has been quantized using INC.
+More information on enabling 4-bit inference in SynapseAI is available here:
+https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_INT4.html.
+
+Below is an example of loading a llama7b model with a 4bit checkpoint quantized in INC.
+Please note that the model checkpoint name is denoted as `<local_model_path_from_inc>`.
+Additionally, the following environment variables are used for performance optimizations and are planned to be removed in future versions:
+`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1`
+```bash
+SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=1 \
+python run_lm_eval.py \
+-o acc_load_uint4_model.txt \
+--model_name_or_path meta-llama/Llama-2-7b-hf \
+--use_hpu_graphs \
+--use_kv_cache \
+--trim_logits \
+--batch_size 1 \
+--bf16 \
+--attn_softmax_bf16 \
+--bucket_size=128 \
+--bucket_internal \
+--quantized_inc_model_path <local_model_path_from_inc> \
+```
 ```
 
 ### Using Habana Flash Attention
@@ -539,6 +567,37 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
 
 For more details see [documentation](https://docs.habana.ai/en/latest/PyTorch/Model_Optimization_PyTorch/Optimization_in_PyTorch_Models.html#using-fused-sdpa).
 
+### Running with UINT4 weight quantization using AutoGPTQ
+
+
+Llama2-7b in UINT4 weight only quantization is enabled using [AutoGPTQ Fork](https://github.com/HabanaAI/AutoGPTQ), which provides quantization capabilities in PyTorch.
+Currently, the support is for UINT4 inference of pre-quantized models only.
+
+You can run a *UINT4 weight quantized* model using AutoGPTQ by setting the following environment variables:
+`SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=true` before running the command,
+and by adding the argument `--load_quantized_model_with_autogptq`.
+
+***Note:***
+Setting the above environment variables improves performance. These variables will be removed in future releases.
+
+
+Here is an example to run a quantized model <quantized_gptq_model>:
+```bash
+SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false \
+ENABLE_EXPERIMENTAL_FLAGS=true python run_generation.py \
+--attn_softmax_bf16 \
+--model_name_or_path <quantized_gptq_model> \
+--use_hpu_graphs \
+--limit_hpu_graphs \
+--use_kv_cache \
+--bucket_size 128 \
+--bucket_internal \
+--trim_logits \
+--max_new_tokens 128 \
+--batch_size 1 \
+--bf16 \
+--load_quantized_model_with_autogptq
+```
 
 ## Language Model Evaluation Harness
 
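For reference outside the diff, below is a minimal standalone sketch of what the new `--load_quantized_model_with_autogptq` flag triggers at the `transformers` level, mirroring the loading branch this commit adds to `examples/text-generation/utils.py` (see that diff further down). The checkpoint name is borrowed from the test added in this commit; the bf16 dtype and the move to the HPU device are assumptions about the surrounding runner, not part of the change itself.

```python
# Minimal sketch of the AutoGPTQ loading path; mirrors the new branch in utils.py below.
# Assumes an HPU-enabled PyTorch environment with the HabanaAI AutoGPTQ fork installed.
import torch
from transformers import AutoModelForCausalLM, GPTQConfig

quantization_config = GPTQConfig(bits=4, use_exllama=False)  # pre-quantized UINT4 weights
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7b-Chat-GPTQ",  # GPTQ checkpoint used by the new test in this commit
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
model = model.eval().to("hpu")  # device placement is an assumption, not part of the diff
```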
examples/text-generation/run_generation.py

Lines changed: 32 additions & 10 deletions
@@ -293,21 +293,11 @@ def setup_parser(parser):
         type=str,
         help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.",
     )
-    parser.add_argument(
-        "--disk_offload",
-        action="store_true",
-        help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.",
-    )
     parser.add_argument(
         "--trust_remote_code",
         action="store_true",
         help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.",
     )
-    parser.add_argument(
-        "--load_quantized_model",
-        action="store_true",
-        help="Whether to load model from hugging face checkpoint.",
-    )
     parser.add_argument(
         "--parallel_strategy",
         type=str,
@@ -321,6 +311,35 @@ def setup_parser(parser):
         help="Whether to enable inputs_embeds or not.",
     )
 
+    parser.add_argument(
+        "--run_partial_dataset",
+        action="store_true",
+        help="Run the inference with dataset for specified --n_iterations(default:5)",
+    )
+
+    quant_parser_group = parser.add_mutually_exclusive_group()
+    quant_parser_group.add_argument(
+        "--load_quantized_model_with_autogptq",
+        action="store_true",
+        help="Load an AutoGPTQ quantized checkpoint using AutoGPTQ.",
+    )
+    quant_parser_group.add_argument(
+        "--disk_offload",
+        action="store_true",
+        help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.",
+    )
+    quant_parser_group.add_argument(
+        "--load_quantized_model_with_inc",
+        action="store_true",
+        help="Load a Huggingface quantized checkpoint using INC.",
+    )
+    quant_parser_group.add_argument(
+        "--quantized_inc_model_path",
+        type=str,
+        default=None,
+        help="Path to neural-compressor quantized model, if set, the checkpoint will be loaded.",
+    )
+
     args = parser.parse_args()
 
     if args.torch_compile:
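Because these options are registered on a mutually exclusive group, argparse itself rejects any run that combines them, so the rest of the script never has to disambiguate conflicting quantization modes. A minimal standalone sketch of that behavior (flag set trimmed for illustration):

```python
# Standalone illustration of the mutually exclusive quantization flags added above.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--load_quantized_model_with_autogptq", action="store_true")
group.add_argument("--load_quantized_model_with_inc", action="store_true")
group.add_argument("--quantized_inc_model_path", type=str, default=None)

args = parser.parse_args(["--load_quantized_model_with_inc"])  # accepted
print(args.load_quantized_model_with_inc)  # True

# Passing two of the flags together, e.g.
#   --load_quantized_model_with_inc --load_quantized_model_with_autogptq
# makes parse_args() exit with a "not allowed with argument" usage error.
```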
@@ -333,6 +352,9 @@ def setup_parser(parser):
         args.flash_attention_fast_softmax = True
 
     args.quant_config = os.getenv("QUANT_CONFIG", "")
+    if args.quant_config and args.load_quantized_model_with_autogptq:
+        raise RuntimeError("Setting both quant_config and load_quantized_model_with_autogptq is unsupported. ")
+
     if args.quant_config == "" and args.disk_offload:
         logger.warning(
             "`--disk_offload` was tested only with fp8, it may not work with full precision. If error raises try to remove the --disk_offload flag."

examples/text-generation/utils.py

Lines changed: 26 additions & 3 deletions
@@ -237,10 +237,34 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             torch_dtype=model_dtype,
             **model_kwargs,
         )
-    elif args.load_quantized_model:
+    elif args.load_quantized_model_with_autogptq:
+        from transformers import GPTQConfig
+
+        quantization_config = GPTQConfig(bits=4, use_exllama=False)
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs
+        )
+    elif args.load_quantized_model_with_inc:
         from neural_compressor.torch.quantization import load
 
         model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
+    elif args.quantized_inc_model_path:
+        org_model = AutoModelForCausalLM.from_pretrained(
+            args.model_name_or_path,
+            **model_kwargs,
+        )
+
+        from neural_compressor.torch.quantization import load
+
+        model = load(
+            model_name_or_path=args.quantized_inc_model_path,
+            format="default",
+            device="hpu",
+            original_model=org_model,
+            **model_kwargs,
+        )
+        # TODO: [SW-195965] Remove once load supports other types
+        model = model.to(model_dtype)
     else:
         if args.assistant_model is not None:
             assistant_model = AutoModelForCausalLM.from_pretrained(
@@ -614,8 +638,7 @@ def initialize_model(args, logger):
         "token": args.token,
         "trust_remote_code": args.trust_remote_code,
     }
-
-    if args.load_quantized_model:
+    if args.load_quantized_model_with_inc or args.quantized_inc_model_path:
         model_kwargs["torch_dtype"] = torch.bfloat16
 
     if args.trust_remote_code:
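To read the two INC branches above outside the diff, here is a condensed, hedged sketch of both loading modes: `--load_quantized_model_with_inc` loads a checkpoint that was already quantized and stored in Hugging Face format, while `--quantized_inc_model_path` layers a local INC-produced checkpoint onto the original model. The Hub checkpoint name is a placeholder, `<local_model_path_from_inc>` is the placeholder from the README example, and the extra `model_kwargs` and error handling from `setup_model` are omitted.

```python
# Condensed sketch of the two INC loading modes used in setup_model above.
import torch
from transformers import AutoModelForCausalLM
from neural_compressor.torch.quantization import load

# 1) --load_quantized_model_with_inc: checkpoint already quantized, stored in Hugging Face format.
model = load(model_name_or_path="<hub_uint4_checkpoint>", format="huggingface", device="hpu")

# 2) --quantized_inc_model_path: local INC checkpoint applied on top of the original model.
org_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = load(
    model_name_or_path="<local_model_path_from_inc>",  # placeholder from the README example
    format="default",
    device="hpu",
    original_model=org_model,
)
model = model.to(torch.bfloat16)  # temporary cast, per the TODO noted in the diff above
```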

tests/test_text_generation_example.py

Lines changed: 37 additions & 0 deletions
@@ -65,6 +65,9 @@
         ("mistralai/Mixtral-8x7B-v0.1", 2, 48, True, 2048, 2048, 1147.50),
         ("microsoft/phi-2", 1, 1, True, 128, 128, 254.08932787178165),
     ],
+    "load_quantized_model_with_autogptq": [
+        ("TheBloke/Llama-2-7b-Chat-GPTQ", 1, 10, False, 128, 2048, 456.7),
+    ],
     "deepspeed": [
         ("bigscience/bloomz", 8, 1, 36.77314954096159),
         ("meta-llama/Llama-2-70b-hf", 8, 1, 64.10514998902435),
@@ -108,6 +111,7 @@
         ("state-spaces/mamba-130m-hf", 224, False, 794.542),
     ],
     "fp8": [],
+    "load_quantized_model_with_autogptq": [],
     "deepspeed": [
         ("bigscience/bloomz-7b1", 8, 1, 31.994268212011505),
     ],
@@ -130,6 +134,7 @@ def _test_text_generation(
     world_size: int = 8,
     torch_compile: bool = False,
     fp8: bool = False,
+    load_quantized_model_with_autogptq: bool = False,
     max_input_tokens: int = 0,
     max_output_tokens: int = 100,
     parallel_strategy: str = None,
@@ -241,6 +246,8 @@ def _test_text_generation(
             f"--max_input_tokens {max_input_tokens}",
             "--limit_hpu_graphs",
         ]
+    if load_quantized_model_with_autogptq:
+        command += ["--load_quantized_model_with_autogptq"]
     if parallel_strategy is not None:
         command += [
             f"--parallel_strategy={parallel_strategy}",
@@ -334,6 +341,36 @@ def test_text_generation_fp8(
     )
 
 
+@pytest.mark.parametrize(
+    "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline",
+    MODELS_TO_TEST["load_quantized_model_with_autogptq"],
+)
+def test_text_generation_gptq(
+    model_name: str,
+    baseline: float,
+    world_size: int,
+    batch_size: int,
+    reuse_cache: bool,
+    input_len: int,
+    output_len: int,
+    token: str,
+):
+    deepspeed = True if world_size > 1 else False
+    _test_text_generation(
+        model_name,
+        baseline,
+        token,
+        deepspeed=deepspeed,
+        world_size=world_size,
+        fp8=False,
+        load_quantized_model_with_autogptq=True,
+        batch_size=batch_size,
+        reuse_cache=reuse_cache,
+        max_input_tokens=input_len,
+        max_output_tokens=output_len,
+    )
+
+
 @pytest.mark.parametrize("model_name, world_size, batch_size, baseline", MODELS_TO_TEST["deepspeed"])
 def test_text_generation_deepspeed(model_name: str, baseline: float, world_size: int, batch_size: int, token: str):
     _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, batch_size=batch_size)
