diff --git a/docs/links_needing_review.json b/docs/links_needing_review.json index 9927badf1130..47f20feaf91e 100644 --- a/docs/links_needing_review.json +++ b/docs/links_needing_review.json @@ -78,14 +78,6 @@ "uri": "https://cocodataset.org/#download", "info": "Anchor 'download' not found" } -{ - "filename": "nlp/quantization.rst", - "lineno": 155, - "status": "broken", - "code": 0, - "uri": "https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/llama2sft.html", - "info": "404 Client Error: Not Found for url: https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/llama2sft.html" -} { "filename": "multimodal/mllm/intro.rst", "lineno": 4, @@ -126,14 +118,6 @@ "uri": "https://github.com/NVIDIA/NeMo#installation", "info": "Anchor 'installation' not found" } -{ - "filename": "nlp/quantization.rst", - "lineno": 118, - "status": "broken", - "code": 0, - "uri": "https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization", - "info": "Anchor 'fp8-post-training-quantization' not found" -} { "filename": "multimodal/text2img/insp2p.rst", "lineno": 16, diff --git a/docs/source/nlp/quantization.rst b/docs/source/nlp/quantization.rst index 51da96240fe9..85838f4e0188 100644 --- a/docs/source/nlp/quantization.rst +++ b/docs/source/nlp/quantization.rst @@ -115,7 +115,7 @@ The TensorRT-LLM engine can be conveniently built and run using ``TensorRTLLM`` ) trt_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"]) -Alternatively, it can also be built directly using ``trtllm-build`` command, see `TensorRT-LLM documentation `_: +Alternatively, it can also be built directly using ``trtllm-build`` command, see `TensorRT-LLM documentation `_: .. code-block:: bash @@ -152,7 +152,7 @@ The example below shows how to perform PTQ and QAT on a Supervised Finetuned Lla The script is tested using tensor parallelism of 8 on 8x RTX 6000 Ada 48GB GPUs. Alternatively, a single DGX A100 node with 8x 40GB GPUs can be used for the same purpose. For bigger models like Llama2 70B, you may need to use one or more DGX H100 nodes with 8x 80GB GPUs each. -The example is a modified version of the `SFT with Llama 2 playbook `_. +The example is a modified version of the `SFT with Llama 2 playbook `_. Please refer to the playbook for more details on setting up a BF16 NeMo model and the ``databricks-dolly-15k`` instruction dataset. First we will run the SFT example command from the playbook as-is to train a Llama2 7B SFT model for 100 steps. diff --git a/nemo/collections/llm/gpt/model/deepseek.py b/nemo/collections/llm/gpt/model/deepseek.py index 0f12add24da9..003535d0fc30 100644 --- a/nemo/collections/llm/gpt/model/deepseek.py +++ b/nemo/collections/llm/gpt/model/deepseek.py @@ -13,7 +13,7 @@ # limitations under the License. 
import json import re -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from functools import cached_property, partial from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union @@ -45,6 +45,7 @@ if TYPE_CHECKING: from megatron.core.transformer import ModuleSpec from transformers import AutoModelForCausalLM + from transformers import DeepseekV3Config as HFDeepseekV3Config from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @@ -472,6 +473,28 @@ def init(self, dtype=torch.bfloat16, model_name="deepseek-ai/DeepSeek-V3") -> "A type(hf_model).register_for_auto_class("AutoModelForCausalLM") return hf_model + def _detect_hf_deepseek_version(self, source_config: Dict[str, Any]) -> str: + """ + Detect the HF DeepSeek version based on the source NeMo config. + + Args: + source_config (Dict[str, Any]): The source NeMo model config. + + Returns: + str: The DeepSeek version in the Hugging Face Hub convention. + """ + if source_config['moe_router_enable_expert_bias']: + target_model_name = "deepseek-ai/DeepSeek-V3" + elif source_config['q_lora_rank'] is not None: + target_model_name = "deepseek-ai/DeepSeek-V2" + else: + target_model_name = "deepseek-ai/DeepSeek-V2-Lite" + logging.info( + f"Your model is determined to be {target_model_name} based on the config. If this is not correct, " + f"please pass in a local HF checkpoint." + ) + return target_model_name + def ckpt_load(self, path: Path) -> Tuple[Dict, Dict]: """ This function loads the state dict directly from a distributed checkpoint, and modify the state dict @@ -511,21 +534,12 @@ def apply(self, output_path: Path, target_model_name=None) -> Path: logging.info("DeepSeek NeMo checkpoint loaded.") if target_model_name is None: # Before DeepSeek is fully supported by HF, it is necessary to pass in a local HF checkpoint that - # is used to initialize the HF model. The following + # is used to initialize the HF model. logging.warning( "Before DeepSeek is officially supported in HF, you should pass in a local HF " "checkpoint using llm.export_ckpt(..., target_model_name=)" ) - if source_config['moe_router_enable_expert_bias']: - target_model_name = "deepseek-ai/DeepSeek-V3" - elif source_config['q_lora_rank'] is not None: - target_model_name = "deepseek-ai/DeepSeek-V2" - else: - target_model_name = "deepseek-ai/DeepSeek-V2-Lite" - logging.info( - f"Your model is determined to be {target_model_name} based on the config. If this is not correct, " - f"please pass in a local HF checkpoint." - ) + target_model_name = self._detect_hf_deepseek_version(source_config) target = self.init(torch_dtype_from_dict_config(source_config), model_name=target_model_name) target = self.convert_state(source, target, source_config) @@ -639,6 +653,60 @@ def _modify_source_state(self, source: Dict[str, Any], source_config: Dict[str, def tokenizer(self) -> 'AutoTokenizer': return io.load_context(self, subpath="model").tokenizer + @property + def config(self) -> "HFDeepseekV3Config": + """Create a HF DeepseekV3Config from the NeMo model config. + + Translates the NeMo configuration parameters to the equivalent HF + configuration. + + Currently only supports DeepseekV3Config based on availability + in the Transformers library. 
+
+        Returns:
+            HFDeepseekV3Config: HF configuration for DeepSeekV3 models
+        """
+        # TODO: Get config for all DeepSeek model variants once available in transformers
+
+        from transformers import DeepseekV3Config as HFDeepseekV3Config
+
+        source: DeepSeekV3Config = io.load_context(str(self)).model.config
+
+        target_model_name = self._detect_hf_deepseek_version(asdict(source))
+        if target_model_name != "deepseek-ai/DeepSeek-V3":
+            raise ValueError(f"Getting config for model other than {target_model_name} is not supported.")
+
+        # Figure out the number of zeros in the prefix of the moe_layer_freq array
+        # for the HF first_k_dense_replace parameter and validate the remainder:
+        k = 0
+        while k < len(source.moe_layer_freq) and source.moe_layer_freq[k] == 0:
+            k += 1
+        assert all(x == 1 for x in source.moe_layer_freq[k:])
+
+        return HFDeepseekV3Config(
+            architectures=["DeepseekV3ForCausalLM"],
+            num_hidden_layers=source.num_layers,
+            hidden_size=source.hidden_size,
+            intermediate_size=source.ffn_hidden_size,
+            num_attention_heads=source.num_attention_heads,
+            q_lora_rank=source.q_lora_rank,
+            qk_nope_head_dim=source.qk_head_dim,
+            qk_rope_head_dim=source.qk_pos_emb_head_dim,
+            v_head_dim=source.v_head_dim,
+            kv_lora_rank=source.kv_lora_rank,
+            num_key_value_heads=source.kv_channels,
+            n_routed_experts=source.num_moe_experts,
+            moe_intermediate_size=source.moe_ffn_hidden_size,
+            first_k_dense_replace=k,
+            num_experts_per_tok=source.moe_router_topk,
+            n_group=source.moe_router_num_groups,
+            topk_group=source.moe_router_group_topk,
+            routed_scaling_factor=source.moe_router_topk_scaling_factor,
+            aux_loss_alpha=source.moe_aux_loss_coeff,
+            max_position_embeddings=source.max_position_embeddings,
+            vocab_size=self.tokenizer.vocab_size,
+        )
+

 __all__ = [
     "DeepSeekConfig",
diff --git a/nemo/collections/llm/modelopt/quantization/quantizer.py b/nemo/collections/llm/modelopt/quantization/quantizer.py
index 031cf2a92720..10eff9a6339e 100644
--- a/nemo/collections/llm/modelopt/quantization/quantizer.py
+++ b/nemo/collections/llm/modelopt/quantization/quantizer.py
@@ -14,6 +14,7 @@
 
 import copy
 import os
+import pprint
 import shutil
 import tempfile
 from dataclasses import dataclass
@@ -29,6 +30,7 @@
 from nemo.collections import llm
 from nemo.collections.llm.inference import MCoreTokenizerWrappper, generate
 from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import get_quant_cfg_choices
+from nemo.collections.llm.modelopt.quantization.utils import load_quant_cfg
 from nemo.collections.llm.utils import barrier, torch_dtype_from_precision
 from nemo.lightning import io
 from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
@@ -121,11 +123,8 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor
         self.quantization_config = quantization_config
         self.export_config = export_config
-
-        algorithm = quantization_config.algorithm
         dtype = export_config.dtype
 
         # Export and Quantization config sanity checks
-        assert algorithm is None or algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization algorithm: {algorithm}"
         if quantization_config.enable_kv_cache:
             assert (
                 quantization_config.kv_cache_qformat in KV_QUANT_CFG_CHOICES
@@ -235,12 +234,15 @@ def huggingface_forward_loop(model):
 
     def _get_quant_cfg(self, model):
         decoder_type = self._get_decoder_type(model, optional=True)
-        assert (
-            self.quantization_config.algorithm in QUANT_CFG_CHOICES
-        ), f"Unsupported quantization format: {self.quantization_config.algorithm}"
+        algorithm = self.quantization_config.algorithm
+
+        if 
os.path.isfile(algorithm): + return load_quant_cfg(algorithm) - quant_cfg = QUANT_CFG_CHOICES[self.quantization_config.algorithm] - if "awq" in self.quantization_config.algorithm: + assert algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization format: {algorithm}" + + quant_cfg = QUANT_CFG_CHOICES[algorithm] + if "awq" in algorithm: quant_cfg = copy.deepcopy(quant_cfg) weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] if isinstance(weight_quantizer, list): @@ -250,11 +252,11 @@ def _get_quant_cfg(self, model): weight_quantizer["block_sizes"][-1] = self.quantization_config.awq_block_size # Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models - if "w4a8_awq" == self.quantization_config.algorithm and decoder_type in ["gemma", "mpt"]: + if "w4a8_awq" == algorithm and decoder_type in ["gemma", "mpt"]: quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1} if self.quantization_config.enable_kv_cache is None: - enable_quant_kv_cache = "int8" not in self.quantization_config.algorithm and decoder_type != "gpt" + enable_quant_kv_cache = "int8" not in algorithm and decoder_type != "gpt" else: enable_quant_kv_cache = self.quantization_config.enable_kv_cache if self.quantization_config.enable_kv_cache is None and enable_quant_kv_cache: @@ -276,7 +278,7 @@ def _get_quant_cfg(self, model): quant_cfg["algorithm"] = "max" # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. - if decoder_type == "gemma" and "int8_sq" in self.quantization_config.algorithm: + if decoder_type == "gemma" and "int8_sq" in algorithm: quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} return quant_cfg @@ -299,6 +301,7 @@ def quantize(self, model: "MegatronParallel", forward_loop=None): self._setup(model) decoder_type = self._get_decoder_type(model, optional=True) quant_cfg = self._get_quant_cfg(model) + logging.info(f"Using quant_cfg:\n{pprint.pformat(quant_cfg)}") unwrapped_model = mtq.quantize(unwrap_for_modelopt_operations(model), quant_cfg, forward_loop) if decoder_type == "gpt": # We found squared_relu may have an under-calibration problem. @@ -505,42 +508,6 @@ def _get_iterator(): return _get_iterator -huggingface_model_type_pattern_match = { - "GPT2": "gpt", - "Mllama": "mllama", - "Llama": "llama", - "Mistral": "llama", - "GPTJ": "gptj", - "FalconForCausalLM": "falcon", - "RWForCausalLM": "falcon", - "baichuan": "baichuan", - "MPT": "mpt", - "Bloom": "bloom", - "ChatGLM": "chatglm", - "QWen": "qwen", - "RecurrentGemma": "recurrentgemma", - "Gemma2": "gemma2", - "Gemma3": "gemma3", - "Gemma": "gemma", - "phi3small": "phi3small", - "phi3": "phi3", - "PhiMoEForCausalLM": "phi3", - "phi": "phi", - "TLGv4ForCausalLM": "phi", - "MixtralForCausalLM": "llama", - "ArcticForCausalLM": "llama", - "StarCoder": "gpt", - "Dbrx": "dbrx", - "T5": "t5", - "Bart": "bart", - "GLM": "glm", - "InternLM2ForCausalLM": "internlm", - "ExaoneForCausalLM": "exaone", - "Nemotron": "gpt", - "Deepseek": "deepseek", - "Whisper": "whisper", -} - gpt_model_type = [ (llm.Baichuan2Model, "baichuan"), (llm.ChatGLMModel, "chatglm"), @@ -576,9 +543,7 @@ def get_modelopt_decoder_type(model: Union[llm.GPTModel, llm.HFAutoModelForCausa Optional[str]: The inferred decoder type or None if no match is found. 
""" if isinstance(model, llm.HFAutoModelForCausalLM): - for k, v in huggingface_model_type_pattern_match.items(): - if k.lower() in type(model.model).__name__.lower(): - return v + return mte.model_utils.get_model_type(model.model) else: for config_class, decoder_type in gpt_model_type: if isinstance(model, config_class): diff --git a/nemo/collections/llm/modelopt/quantization/utils.py b/nemo/collections/llm/modelopt/quantization/utils.py new file mode 100644 index 000000000000..b8d283b5e2d3 --- /dev/null +++ b/nemo/collections/llm/modelopt/quantization/utils.py @@ -0,0 +1,58 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path +from typing import Any, Dict, Union + +from nemo.utils.cast_utils import maybe_cast_to_type + + +def standardize_json_config(quant_cfg: Dict[str, Any]): + """Standardize the quantization configuration loaded from a JSON file to + ensure compatibility with modelopt. Modifiy the input dictionary in place. + + Args: + quant_cfg (Dict[str, Any]): The quantization config dictionary to standardize. + """ + for key, value in quant_cfg.items(): + if key == "block_sizes": + value = {maybe_cast_to_type(k, int): v for k, v in value.items()} + quant_cfg[key] = value + elif key in {"num_bits", "scale_bits"} and isinstance(value, list): + quant_cfg[key] = tuple(value) + continue # No further processing needed + if isinstance(value, dict): + standardize_json_config(value) + elif isinstance(value, list): + for x in value: + if isinstance(x, dict): + standardize_json_config(x) + + +def load_quant_cfg(cfg_path: Union[str, Path]) -> Dict[str, Any]: + """Load quantization configuration from a JSON file and adjust for + modelopt standards if necessary. + + Args: + cfg_path (str): Path to the quantization config JSON file. + + Returns: + dict: The loaded quantization configuration. + """ + with open(cfg_path, "r") as f: + quant_cfg = json.load(f) + + standardize_json_config(quant_cfg) + return quant_cfg diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index 72d6c5c496d9..472bd6664e6d 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -13,6 +13,7 @@ # limitations under the License. from contextlib import contextmanager, nullcontext +from typing import Any import torch @@ -100,3 +101,19 @@ def monkeypatched(object, name, patch): setattr(object, name, patch) yield object setattr(object, name, pre_patched_value) + + +def maybe_cast_to_type(x: Any, type_: type) -> Any: + """Try to cast a value to int, if it fails, return the original value. + + Args: + x (Any): The value to be casted. + type_ (type): The type to cast to, must be a callable. + + Returns: + Any: The casted value or the original value if casting fails. 
+ """ + try: + return type_(x) + except Exception: + return x diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py index 336d72c4f958..8fc15dfc9dbf 100644 --- a/scripts/llm/ptq.py +++ b/scripts/llm/ptq.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse +import os from nemo.collections import llm from nemo.collections.llm.modelopt import ExportConfig, QuantizationConfig @@ -72,7 +73,6 @@ def get_args(): "--algorithm", type=str, default="fp8", - choices=QUANT_CFG_CHOICES_LIST, help="TensorRT-Model-Optimizer quantization algorithm", ) parser.add_argument( @@ -113,6 +113,12 @@ def get_args(): parser.add_argument("--legacy_ckpt", help="Load ckpt saved with TE < 1.14", action="store_true") args = parser.parse_args() + if args.algorithm not in QUANT_CFG_CHOICES_LIST and not os.path.isfile(args.algorithm): + raise ValueError( + f"Quantization algorithm {args.algorithm} is not supported: choose one of {QUANT_CFG_CHOICES_LIST} " + "or provide a path to a JSON file with a quantization configuration." + ) + if args.export_path is None: if args.export_format == "trtllm": args.export_path = f"./qnemo_{args.algorithm}_tp{args.inference_tp}_pp{args.inference_pp}" diff --git a/tests/collections/llm/modelopt/quantization/test_load_quant_cfg.py b/tests/collections/llm/modelopt/quantization/test_load_quant_cfg.py new file mode 100644 index 000000000000..0028ddd41406 --- /dev/null +++ b/tests/collections/llm/modelopt/quantization/test_load_quant_cfg.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import tempfile + +import pytest + +from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import get_quant_cfg_choices +from nemo.collections.llm.modelopt.quantization.utils import load_quant_cfg + +QUANT_CFG_CHOICES = get_quant_cfg_choices() + + +@pytest.mark.parametrize("cfg_name", ["nvfp4", "fp8"]) +def test_load_quant_cfg(cfg_name): + """Test loading a quantization config from a JSON file.""" + + quant_cfg_org = QUANT_CFG_CHOICES[cfg_name] + + with tempfile.NamedTemporaryFile(mode="w") as temp_file: + json.dump(quant_cfg_org, temp_file) + temp_file.flush() + quant_cfg_loaded = load_quant_cfg(temp_file.name) + assert quant_cfg_loaded == quant_cfg_org