
Commit a767f46

janekl authored and nasretdinovr committed
PTQ model support, quant_cfg, and documentation updates (NVIDIA-NeMo#13519)
* Enable changing quant_cfg via command line in the ptq.py script
  Signed-off-by: Jan Lasek <[email protected]>
* Add config to HFDeepSeekExporter
  Signed-off-by: Jan Lasek <[email protected]>
* Use MODEL_NAME_TO_TYPE directly from model_utils
  Signed-off-by: Jan Lasek <[email protected]>
* Utility to detect HF DeepSeek version
  Signed-off-by: Jan Lasek <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: janekl <[email protected]>
* Update broken links
  Signed-off-by: Jan Lasek <[email protected]>
* Rename temp var
  Signed-off-by: Jan Lasek <[email protected]>
* Give up quant_cfg_overrides param
  Signed-off-by: Jan Lasek <[email protected]>
* Get model_type for HFAutoModelForCausalLM with mte
  Signed-off-by: Jan Lasek <[email protected]>
* Utils to load quant_cfg from JSON file
  Signed-off-by: Jan Lasek <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: janekl <[email protected]>
* Update quantization.rst links
  Signed-off-by: Jan Lasek <[email protected]>
* Improve load_quant_cfg plus unit test
  Signed-off-by: Jan Lasek <[email protected]>
* Enable loading quant_cfg from a JSON file
  Signed-off-by: Jan Lasek <[email protected]>
* Bugfix for AWQ methods
  Signed-off-by: Jan Lasek <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: janekl <[email protected]>
* Add copyright headers
  Signed-off-by: Jan Lasek <[email protected]>
* Format error msg
  Signed-off-by: Jan Lasek <[email protected]>

---------

Signed-off-by: Jan Lasek <[email protected]>
Signed-off-by: janekl <[email protected]>
Co-authored-by: janekl <[email protected]>
1 parent: d8f1f6d · commit: a767f46

File tree: 8 files changed (+215, -81 lines)


docs/links_needing_review.json

Lines changed: 0 additions & 16 deletions
@@ -78,14 +78,6 @@
     "uri": "https://cocodataset.org/#download",
     "info": "Anchor 'download' not found"
   }
-  {
-    "filename": "nlp/quantization.rst",
-    "lineno": 155,
-    "status": "broken",
-    "code": 0,
-    "uri": "https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/llama2sft.html",
-    "info": "404 Client Error: Not Found for url: https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/llama2sft.html"
-  }
   {
     "filename": "multimodal/mllm/intro.rst",
     "lineno": 4,
@@ -126,14 +118,6 @@
     "uri": "https://github.com/NVIDIA/NeMo#installation",
     "info": "Anchor 'installation' not found"
   }
-  {
-    "filename": "nlp/quantization.rst",
-    "lineno": 118,
-    "status": "broken",
-    "code": 0,
-    "uri": "https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization",
-    "info": "Anchor 'fp8-post-training-quantization' not found"
-  }
   {
     "filename": "multimodal/text2img/insp2p.rst",
     "lineno": 16,

docs/source/nlp/quantization.rst

Lines changed: 2 additions & 2 deletions
@@ -115,7 +115,7 @@ The TensorRT-LLM engine can be conveniently built and run using ``TensorRTLLM``
     )
     trt_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])

-Alternatively, it can also be built directly using ``trtllm-build`` command, see `TensorRT-LLM documentation <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization>`_:
+Alternatively, it can also be built directly using ``trtllm-build`` command, see `TensorRT-LLM documentation <https://nvidia.github.io/TensorRT-LLM/latest/architecture/checkpoint.html#build-checkpoint-into-tensorrt-engine>`_:

 .. code-block:: bash

@@ -152,7 +152,7 @@ The example below shows how to perform PTQ and QAT on a Supervised Finetuned Lla
 The script is tested using tensor parallelism of 8 on 8x RTX 6000 Ada 48GB GPUs. Alternatively, a single DGX A100 node with 8x 40GB GPUs can be used for the same purpose.
 For bigger models like Llama2 70B, you may need to use one or more DGX H100 nodes with 8x 80GB GPUs each.

-The example is a modified version of the `SFT with Llama 2 playbook <https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/llama2sft.html>`_.
+The example is a modified version of the `SFT with Llama 2 playbook <https://docs.nvidia.com/nemo-framework/user-guide/24.07/playbooks/llama2sft.html>`_.
 Please refer to the playbook for more details on setting up a BF16 NeMo model and the ``databricks-dolly-15k`` instruction dataset.

 First we will run the SFT example command from the playbook as-is to train a Llama2 7B SFT model for 100 steps.

nemo/collections/llm/gpt/model/deepseek.py

Lines changed: 80 additions & 12 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import json
 import re
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from functools import cached_property, partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -45,6 +45,7 @@
 if TYPE_CHECKING:
     from megatron.core.transformer import ModuleSpec
     from transformers import AutoModelForCausalLM
+    from transformers import DeepseekV3Config as HFDeepseekV3Config

 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -472,6 +473,28 @@ def init(self, dtype=torch.bfloat16, model_name="deepseek-ai/DeepSeek-V3") -> "A
         type(hf_model).register_for_auto_class("AutoModelForCausalLM")
         return hf_model

+    def _detect_hf_deepseek_version(self, source_config: Dict[str, Any]) -> str:
+        """
+        Detect the HF DeepSeek version based on the source NeMo config.
+
+        Args:
+            source_config (Dict[str, Any]): The source NeMo model config.
+
+        Returns:
+            str: The DeepSeek version in the Hugging Face Hub convention.
+        """
+        if source_config['moe_router_enable_expert_bias']:
+            target_model_name = "deepseek-ai/DeepSeek-V3"
+        elif source_config['q_lora_rank'] is not None:
+            target_model_name = "deepseek-ai/DeepSeek-V2"
+        else:
+            target_model_name = "deepseek-ai/DeepSeek-V2-Lite"
+        logging.info(
+            f"Your model is determined to be {target_model_name} based on the config. If this is not correct, "
+            f"please pass in a local HF checkpoint."
+        )
+        return target_model_name
+
     def ckpt_load(self, path: Path) -> Tuple[Dict, Dict]:
         """
         This function loads the state dict directly from a distributed checkpoint, and modify the state dict
@@ -511,21 +534,12 @@ def apply(self, output_path: Path, target_model_name=None) -> Path:
         logging.info("DeepSeek NeMo checkpoint loaded.")
         if target_model_name is None:
             # Before DeepSeek is fully supported by HF, it is necessary to pass in a local HF checkpoint that
-            # is used to initialize the HF model. The following
+            # is used to initialize the HF model.
             logging.warning(
                 "Before DeepSeek is officially supported in HF, you should pass in a local HF "
                 "checkpoint using llm.export_ckpt(..., target_model_name=<local hf path>)"
             )
-            if source_config['moe_router_enable_expert_bias']:
-                target_model_name = "deepseek-ai/DeepSeek-V3"
-            elif source_config['q_lora_rank'] is not None:
-                target_model_name = "deepseek-ai/DeepSeek-V2"
-            else:
-                target_model_name = "deepseek-ai/DeepSeek-V2-Lite"
-            logging.info(
-                f"Your model is determined to be {target_model_name} based on the config. If this is not correct, "
-                f"please pass in a local HF checkpoint."
-            )
+            target_model_name = self._detect_hf_deepseek_version(source_config)

         target = self.init(torch_dtype_from_dict_config(source_config), model_name=target_model_name)
         target = self.convert_state(source, target, source_config)
@@ -639,6 +653,60 @@ def _modify_source_state(self, source: Dict[str, Any], source_config: Dict[str,
     def tokenizer(self) -> 'AutoTokenizer':
         return io.load_context(self, subpath="model").tokenizer

+    @property
+    def config(self) -> "HFDeepseekV3Config":
+        """Create a HF DeepseekV3Config from the NeMo model config.
+
+        Translates the NeMo configuration parameters to the equivalent HF
+        configuration.
+
+        Currently only supports DeepseekV3Config based on availability
+        in the Transformers library.
+
+        Returns:
+            HFDeepseekV3Config: HF configuration for DeepSeekV3 models
+        """
+        # TODO: Get config for all DeepSeek model variants once available in transformers
+
+        from transformers import DeepseekV3Config as HFDeepseekV3Config
+
+        source: DeepSeekV3Config = io.load_context(str(self)).model.config
+
+        target_model_name = self._detect_hf_deepseek_version(asdict(source))
+        if target_model_name != "deepseek-ai/DeepSeek-V3":
+            raise ValueError(f"Getting config for model other than {target_model_name} is not supported.")
+
+        # Figure out the number of zeros in the prefix of moe_layer_freq array
+        # for the HF first_k_dense_replace parameter and validate the reminder:
+        k = 0
+        while k < len(source.moe_layer_freq) and source.moe_layer_freq[k] == 0:
+            k += 1
+        assert all(x == 1 for x in source.moe_layer_freq[k:])
+
+        return HFDeepseekV3Config(
+            architectures=["DeepseekV3ForCausalLM"],
+            num_hidden_layers=source.num_layers,
+            hidden_size=source.hidden_size,
+            intermediate_size=source.ffn_hidden_size,
+            num_attention_heads=source.num_attention_heads,
+            q_lora_rank=source.q_lora_rank,
+            qk_nope_head_dim=source.qk_head_dim,
+            qk_rope_head_dim=source.qk_pos_emb_head_dim,
+            v_head_dim=source.v_head_dim,
+            kv_lora_rank=source.kv_lora_rank,
+            num_key_value_heads=source.kv_channels,
+            n_routed_experts=source.num_moe_experts,
+            moe_intermediate_size=source.moe_ffn_hidden_size,
+            first_k_dense_replace=k,
+            num_experts_per_tok=source.moe_router_topk,
+            n_group=source.moe_router_num_groups,
+            topk_group=source.moe_router_group_topk,
+            routed_scaling_factor=source.moe_router_topk_scaling_factor,
+            aux_loss_alpha=source.moe_aux_loss_coeff,
+            max_position_embeddings=source.max_position_embeddings,
+            vocab_size=self.tokenizer.vocab_size,
+        )
+

 __all__ = [
     "DeepSeekConfig",

nemo/collections/llm/modelopt/quantization/quantizer.py

Lines changed: 15 additions & 50 deletions
@@ -14,6 +14,7 @@

 import copy
 import os
+import pprint
 import shutil
 import tempfile
 from dataclasses import dataclass
@@ -29,6 +30,7 @@
 from nemo.collections import llm
 from nemo.collections.llm.inference import MCoreTokenizerWrappper, generate
 from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import get_quant_cfg_choices
+from nemo.collections.llm.modelopt.quantization.utils import load_quant_cfg
 from nemo.collections.llm.utils import barrier, torch_dtype_from_precision
 from nemo.lightning import io
 from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
@@ -121,11 +123,8 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor

         self.quantization_config = quantization_config
         self.export_config = export_config
-
-        algorithm = quantization_config.algorithm
         dtype = export_config.dtype
         # Export and Quantization config sanity checks
-        assert algorithm is None or algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization algorithm: {algorithm}"
         if quantization_config.enable_kv_cache:
             assert (
                 quantization_config.kv_cache_qformat in KV_QUANT_CFG_CHOICES
@@ -235,12 +234,15 @@ def huggingface_forward_loop(model):

     def _get_quant_cfg(self, model):
         decoder_type = self._get_decoder_type(model, optional=True)
-        assert (
-            self.quantization_config.algorithm in QUANT_CFG_CHOICES
-        ), f"Unsupported quantization format: {self.quantization_config.algorithm}"
+        algorithm = self.quantization_config.algorithm
+
+        if os.path.isfile(algorithm):
+            return load_quant_cfg(algorithm)

-        quant_cfg = QUANT_CFG_CHOICES[self.quantization_config.algorithm]
-        if "awq" in self.quantization_config.algorithm:
+        assert algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization format: {algorithm}"
+
+        quant_cfg = QUANT_CFG_CHOICES[algorithm]
+        if "awq" in algorithm:
             quant_cfg = copy.deepcopy(quant_cfg)
             weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
             if isinstance(weight_quantizer, list):
@@ -250,11 +252,11 @@ def _get_quant_cfg(self, model):
                 weight_quantizer["block_sizes"][-1] = self.quantization_config.awq_block_size

         # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
-        if "w4a8_awq" == self.quantization_config.algorithm and decoder_type in ["gemma", "mpt"]:
+        if "w4a8_awq" == algorithm and decoder_type in ["gemma", "mpt"]:
             quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

         if self.quantization_config.enable_kv_cache is None:
-            enable_quant_kv_cache = "int8" not in self.quantization_config.algorithm and decoder_type != "gpt"
+            enable_quant_kv_cache = "int8" not in algorithm and decoder_type != "gpt"
         else:
             enable_quant_kv_cache = self.quantization_config.enable_kv_cache
         if self.quantization_config.enable_kv_cache is None and enable_quant_kv_cache:
@@ -276,7 +278,7 @@ def _get_quant_cfg(self, model):
             quant_cfg["algorithm"] = "max"

         # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
-        if decoder_type == "gemma" and "int8_sq" in self.quantization_config.algorithm:
+        if decoder_type == "gemma" and "int8_sq" in algorithm:
             quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

         return quant_cfg
@@ -299,6 +301,7 @@ def quantize(self, model: "MegatronParallel", forward_loop=None):
         self._setup(model)
         decoder_type = self._get_decoder_type(model, optional=True)
         quant_cfg = self._get_quant_cfg(model)
+        logging.info(f"Using quant_cfg:\n{pprint.pformat(quant_cfg)}")
         unwrapped_model = mtq.quantize(unwrap_for_modelopt_operations(model), quant_cfg, forward_loop)
         if decoder_type == "gpt":
             # We found squared_relu may have an under-calibration problem.
@@ -505,42 +508,6 @@ def _get_iterator():
     return _get_iterator


-huggingface_model_type_pattern_match = {
-    "GPT2": "gpt",
-    "Mllama": "mllama",
-    "Llama": "llama",
-    "Mistral": "llama",
-    "GPTJ": "gptj",
-    "FalconForCausalLM": "falcon",
-    "RWForCausalLM": "falcon",
-    "baichuan": "baichuan",
-    "MPT": "mpt",
-    "Bloom": "bloom",
-    "ChatGLM": "chatglm",
-    "QWen": "qwen",
-    "RecurrentGemma": "recurrentgemma",
-    "Gemma2": "gemma2",
-    "Gemma3": "gemma3",
-    "Gemma": "gemma",
-    "phi3small": "phi3small",
-    "phi3": "phi3",
-    "PhiMoEForCausalLM": "phi3",
-    "phi": "phi",
-    "TLGv4ForCausalLM": "phi",
-    "MixtralForCausalLM": "llama",
-    "ArcticForCausalLM": "llama",
-    "StarCoder": "gpt",
-    "Dbrx": "dbrx",
-    "T5": "t5",
-    "Bart": "bart",
-    "GLM": "glm",
-    "InternLM2ForCausalLM": "internlm",
-    "ExaoneForCausalLM": "exaone",
-    "Nemotron": "gpt",
-    "Deepseek": "deepseek",
-    "Whisper": "whisper",
-}
-
 gpt_model_type = [
     (llm.Baichuan2Model, "baichuan"),
     (llm.ChatGLMModel, "chatglm"),
@@ -576,9 +543,7 @@ def get_modelopt_decoder_type(model: Union[llm.GPTModel, llm.HFAutoModelForCausa
         Optional[str]: The inferred decoder type or None if no match is found.
     """
     if isinstance(model, llm.HFAutoModelForCausalLM):
-        for k, v in huggingface_model_type_pattern_match.items():
-            if k.lower() in type(model.model).__name__.lower():
-                return v
+        return mte.model_utils.get_model_type(model.model)
     else:
         for config_class, decoder_type in gpt_model_type:
             if isinstance(model, config_class):
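
With this change, the quantization algorithm can name a built-in choice or point at a quant_cfg JSON file on disk. The sketch below mirrors the new dispatch in _get_quant_cfg; it assumes get_quant_cfg_choices() returns the name-to-config dictionary used as QUANT_CFG_CHOICES, which this diff does not show:

    # Sketch of the dispatch added to _get_quant_cfg (assumptions noted in comments).
    import os

    from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import get_quant_cfg_choices
    from nemo.collections.llm.modelopt.quantization.utils import load_quant_cfg

    # Assumption: get_quant_cfg_choices() yields the dict referred to as QUANT_CFG_CHOICES.
    QUANT_CFG_CHOICES = get_quant_cfg_choices()


    def resolve_quant_cfg(algorithm: str) -> dict:
        """Return a modelopt quant_cfg for a built-in name or a path to a JSON file."""
        if os.path.isfile(algorithm):
            # Custom configuration supplied as a JSON file path.
            return load_quant_cfg(algorithm)
        assert algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization format: {algorithm}"
        return QUANT_CFG_CHOICES[algorithm]


    quant_cfg = resolve_quant_cfg("int8_sq")              # built-in choice
    # quant_cfg = resolve_quant_cfg("my_quant_cfg.json")  # or a custom JSON file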
nemo/collections/llm/modelopt/quantization/utils.py (new file)

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from pathlib import Path
+from typing import Any, Dict, Union
+
+from nemo.utils.cast_utils import maybe_cast_to_type
+
+
+def standardize_json_config(quant_cfg: Dict[str, Any]):
+    """Standardize the quantization configuration loaded from a JSON file to
+    ensure compatibility with modelopt. Modifiy the input dictionary in place.
+
+    Args:
+        quant_cfg (Dict[str, Any]): The quantization config dictionary to standardize.
+    """
+    for key, value in quant_cfg.items():
+        if key == "block_sizes":
+            value = {maybe_cast_to_type(k, int): v for k, v in value.items()}
+            quant_cfg[key] = value
+        elif key in {"num_bits", "scale_bits"} and isinstance(value, list):
+            quant_cfg[key] = tuple(value)
+            continue  # No further processing needed
+        if isinstance(value, dict):
+            standardize_json_config(value)
+        elif isinstance(value, list):
+            for x in value:
+                if isinstance(x, dict):
+                    standardize_json_config(x)
+
+
+def load_quant_cfg(cfg_path: Union[str, Path]) -> Dict[str, Any]:
+    """Load quantization configuration from a JSON file and adjust for
+    modelopt standards if necessary.
+
+    Args:
+        cfg_path (str): Path to the quantization config JSON file.
+
+    Returns:
+        dict: The loaded quantization configuration.
+    """
+    with open(cfg_path, "r") as f:
+        quant_cfg = json.load(f)
+
+    standardize_json_config(quant_cfg)
+    return quant_cfg
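
JSON cannot express tuples or integer dictionary keys, both of which appear in modelopt quant_cfg dictionaries, and that is what standardize_json_config repairs after loading. A small illustration follows; the config fragment is invented for the example, and the import path assumes the new utils module shown above:

    # Illustration of the JSON standardization step; the config fragment is made up.
    import json

    from nemo.collections.llm.modelopt.quantization.utils import standardize_json_config

    raw = json.loads(
        '{"quant_cfg": {"*weight_quantizer": {"num_bits": [4, 3], "block_sizes": {"-1": 128}}}, "algorithm": "max"}'
    )
    standardize_json_config(raw)

    wq = raw["quant_cfg"]["*weight_quantizer"]
    assert wq["num_bits"] == (4, 3)        # JSON list restored to a tuple
    assert wq["block_sizes"] == {-1: 128}  # string key "-1" restored to int -1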

nemo/utils/cast_utils.py

Lines changed: 17 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 from contextlib import contextmanager, nullcontext
+from typing import Any

 import torch

@@ -100,3 +101,19 @@ def monkeypatched(object, name, patch):
     setattr(object, name, patch)
     yield object
     setattr(object, name, pre_patched_value)
+
+
+def maybe_cast_to_type(x: Any, type_: type) -> Any:
+    """Try to cast a value to int, if it fails, return the original value.
+
+    Args:
+        x (Any): The value to be casted.
+        type_ (type): The type to cast to, must be a callable.
+
+    Returns:
+        Any: The casted value or the original value if casting fails.
+    """
+    try:
+        return type_(x)
+    except Exception:
+        return x
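
A quick usage sketch of the new helper (example values only):

    from nemo.utils.cast_utils import maybe_cast_to_type

    maybe_cast_to_type("-1", int)     # -> -1 (JSON string key cast back to int)
    maybe_cast_to_type("*", int)      # -> "*" (cast fails, original value is returned)
    maybe_cast_to_type("0.5", float)  # -> 0.5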
