16 changes: 0 additions & 16 deletions docs/links_needing_review.json
@@ -78,14 +78,6 @@
"uri": "https://cocodataset.org/#download",
"info": "Anchor 'download' not found"
}
{
"filename": "nlp/quantization.rst",
"lineno": 155,
"status": "broken",
"code": 0,
"uri": "https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/llama2sft.html",
"info": "404 Client Error: Not Found for url: https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/llama2sft.html"
}
{
"filename": "multimodal/mllm/intro.rst",
"lineno": 4,
@@ -126,14 +118,6 @@
"uri": "https://github.com/NVIDIA/NeMo#installation",
"info": "Anchor 'installation' not found"
}
{
"filename": "nlp/quantization.rst",
"lineno": 118,
"status": "broken",
"code": 0,
"uri": "https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization",
"info": "Anchor 'fp8-post-training-quantization' not found"
}
{
"filename": "multimodal/text2img/insp2p.rst",
"lineno": 16,
4 changes: 2 additions & 2 deletions docs/source/nlp/quantization.rst
@@ -115,7 +115,7 @@ The TensorRT-LLM engine can be conveniently built and run using ``TensorRTLLM``
)
trt_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])

Alternatively, it can also be built directly using ``trtllm-build`` command, see `TensorRT-LLM documentation <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization>`_:
Alternatively, it can also be built directly using ``trtllm-build`` command, see `TensorRT-LLM documentation <https://nvidia.github.io/TensorRT-LLM/latest/architecture/checkpoint.html#build-checkpoint-into-tensorrt-engine>`_:

.. code-block:: bash

@@ -152,7 +152,7 @@ The example below shows how to perform PTQ and QAT on a Supervised Finetuned Lla
The script is tested using tensor parallelism of 8 on 8x RTX 6000 Ada 48GB GPUs. Alternatively, a single DGX A100 node with 8x 40GB GPUs can be used for the same purpose.
For bigger models like Llama2 70B, you may need to use one or more DGX H100 nodes with 8x 80GB GPUs each.

The example is a modified version of the `SFT with Llama 2 playbook <https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/llama2sft.html>`_.
The example is a modified version of the `SFT with Llama 2 playbook <https://docs.nvidia.com/nemo-framework/user-guide/24.07/playbooks/llama2sft.html>`_.
Please refer to the playbook for more details on setting up a BF16 NeMo model and the ``databricks-dolly-15k`` instruction dataset.

First we will run the SFT example command from the playbook as-is to train a Llama2 7B SFT model for 100 steps.
92 changes: 80 additions & 12 deletions nemo/collections/llm/gpt/model/deepseek.py
@@ -13,7 +13,7 @@
# limitations under the License.
import json
import re
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from functools import cached_property, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -45,6 +45,7 @@
if TYPE_CHECKING:
from megatron.core.transformer import ModuleSpec
from transformers import AutoModelForCausalLM
from transformers import DeepseekV3Config as HFDeepseekV3Config

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -472,6 +473,28 @@ def init(self, dtype=torch.bfloat16, model_name="deepseek-ai/DeepSeek-V3") -> "A
type(hf_model).register_for_auto_class("AutoModelForCausalLM")
return hf_model

def _detect_hf_deepseek_version(self, source_config: Dict[str, Any]) -> str:
"""
Detect the HF DeepSeek version based on the source NeMo config.

Args:
source_config (Dict[str, Any]): The source NeMo model config.

Returns:
str: The DeepSeek version in the Hugging Face Hub convention.
"""
if source_config['moe_router_enable_expert_bias']:
target_model_name = "deepseek-ai/DeepSeek-V3"
elif source_config['q_lora_rank'] is not None:
target_model_name = "deepseek-ai/DeepSeek-V2"
else:
target_model_name = "deepseek-ai/DeepSeek-V2-Lite"
logging.info(
f"Your model is determined to be {target_model_name} based on the config. If this is not correct, "
f"please pass in a local HF checkpoint."
)
return target_model_name

def ckpt_load(self, path: Path) -> Tuple[Dict, Dict]:
"""
This function loads the state dict directly from a distributed checkpoint, and modifies the state dict
@@ -511,21 +534,12 @@ def apply(self, output_path: Path, target_model_name=None) -> Path:
logging.info("DeepSeek NeMo checkpoint loaded.")
if target_model_name is None:
# Before DeepSeek is fully supported by HF, it is necessary to pass in a local HF checkpoint that
# is used to initialize the HF model. The following
# is used to initialize the HF model.
logging.warning(
"Before DeepSeek is officially supported in HF, you should pass in a local HF "
"checkpoint using llm.export_ckpt(..., target_model_name=<local hf path>)"
)
if source_config['moe_router_enable_expert_bias']:
target_model_name = "deepseek-ai/DeepSeek-V3"
elif source_config['q_lora_rank'] is not None:
target_model_name = "deepseek-ai/DeepSeek-V2"
else:
target_model_name = "deepseek-ai/DeepSeek-V2-Lite"
logging.info(
f"Your model is determined to be {target_model_name} based on the config. If this is not correct, "
f"please pass in a local HF checkpoint."
)
target_model_name = self._detect_hf_deepseek_version(source_config)

target = self.init(torch_dtype_from_dict_config(source_config), model_name=target_model_name)
target = self.convert_state(source, target, source_config)
@@ -639,6 +653,60 @@ def _modify_source_state(self, source: Dict[str, Any], source_config: Dict[str,
def tokenizer(self) -> 'AutoTokenizer':
return io.load_context(self, subpath="model").tokenizer

@property
def config(self) -> "HFDeepseekV3Config":
"""Create a HF DeepseekV3Config from the NeMo model config.

Collaborator: Why does HFDeepSeekExporter only export the V3 model?

Collaborator (Author): Today only DeepseekV3Config is importable from transformers==4.52.3, so adding other models requires some extra work. I suggest we add them later, once they can be imported from transformers, to avoid workarounds; this can then be easily extended.


Translates the NeMo configuration parameters to the equivalent HF
configuration.

Currently only supports DeepseekV3Config based on availability
in the Transformers library.

Returns:
HFDeepseekV3Config: HF configuration for DeepSeekV3 models
"""
# TODO: Get config for all DeepSeek model variants once available in transformers

from transformers import DeepseekV3Config as HFDeepseekV3Config

source: DeepSeekV3Config = io.load_context(str(self)).model.config

target_model_name = self._detect_hf_deepseek_version(asdict(source))
if target_model_name != "deepseek-ai/DeepSeek-V3":
raise ValueError(f"Getting config for {target_model_name} is not supported; only deepseek-ai/DeepSeek-V3 is currently supported.")

# Figure out the number of zeros in the prefix of moe_layer_freq array
# for the HF first_k_dense_replace parameter and validate the remainder:
k = 0
while k < len(source.moe_layer_freq) and source.moe_layer_freq[k] == 0:
k += 1
assert all(x == 1 for x in source.moe_layer_freq[k:])

return HFDeepseekV3Config(
architectures=["DeepseekV3ForCausalLM"],
num_hidden_layers=source.num_layers,
hidden_size=source.hidden_size,
intermediate_size=source.ffn_hidden_size,
num_attention_heads=source.num_attention_heads,
q_lora_rank=source.q_lora_rank,
qk_nope_head_dim=source.qk_head_dim,
qk_rope_head_dim=source.qk_pos_emb_head_dim,
v_head_dim=source.v_head_dim,
kv_lora_rank=source.kv_lora_rank,
num_key_value_heads=source.kv_channels,
n_routed_experts=source.num_moe_experts,
moe_intermediate_size=source.moe_ffn_hidden_size,
first_k_dense_replace=k,
num_experts_per_tok=source.moe_router_topk,
n_group=source.moe_router_num_groups,
topk_group=source.moe_router_group_topk,
routed_scaling_factor=source.moe_router_topk_scaling_factor,
aux_loss_alpha=source.moe_aux_loss_coeff,
max_position_embeddings=source.max_position_embeddings,
vocab_size=self.tokenizer.vocab_size,
)


__all__ = [
"DeepSeekConfig",
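The two pieces of new logic in deepseek.py are easy to sanity-check in isolation. The sketch below mirrors the version-detection heuristic and the first_k_dense_replace loop from the config property; the config dictionary and the moe_layer_freq pattern are hypothetical stand-ins, not values read from a real checkpoint.

# Standalone sketch of the two heuristics added in this file, using hypothetical config values.

def detect_hf_deepseek_version(source_config: dict) -> str:
    # Mirrors _detect_hf_deepseek_version: V3 enables expert bias in the MoE router,
    # V2 uses q_lora_rank, and V2-Lite has neither.
    if source_config["moe_router_enable_expert_bias"]:
        return "deepseek-ai/DeepSeek-V3"
    if source_config["q_lora_rank"] is not None:
        return "deepseek-ai/DeepSeek-V2"
    return "deepseek-ai/DeepSeek-V2-Lite"

def first_k_dense_replace(moe_layer_freq: list) -> int:
    # Mirrors the loop in the config property: count the leading dense (0) layers,
    # then check that every remaining entry is an MoE (1) layer.
    k = 0
    while k < len(moe_layer_freq) and moe_layer_freq[k] == 0:
        k += 1
    assert all(x == 1 for x in moe_layer_freq[k:])
    return k

# Hypothetical values: the first three layers are dense, the rest are MoE layers.
assert first_k_dense_replace([0, 0, 0, 1, 1, 1, 1, 1]) == 3
assert detect_hf_deepseek_version({"moe_router_enable_expert_bias": True, "q_lora_rank": 1536}) == "deepseek-ai/DeepSeek-V3"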
65 changes: 15 additions & 50 deletions nemo/collections/llm/modelopt/quantization/quantizer.py
@@ -14,6 +14,7 @@

import copy
import os
import pprint
import shutil
import tempfile
from dataclasses import dataclass
Expand All @@ -29,6 +30,7 @@
from nemo.collections import llm
from nemo.collections.llm.inference import MCoreTokenizerWrappper, generate
from nemo.collections.llm.modelopt.quantization.quant_cfg_choices import get_quant_cfg_choices
from nemo.collections.llm.modelopt.quantization.utils import load_quant_cfg
from nemo.collections.llm.utils import barrier, torch_dtype_from_precision
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
@@ -121,11 +123,8 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor

self.quantization_config = quantization_config
self.export_config = export_config

algorithm = quantization_config.algorithm
dtype = export_config.dtype
# Export and Quantization config sanity checks
assert algorithm is None or algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization algorithm: {algorithm}"
if quantization_config.enable_kv_cache:
assert (
quantization_config.kv_cache_qformat in KV_QUANT_CFG_CHOICES
@@ -235,12 +234,15 @@ def huggingface_forward_loop(model):

def _get_quant_cfg(self, model):
decoder_type = self._get_decoder_type(model, optional=True)
assert (
self.quantization_config.algorithm in QUANT_CFG_CHOICES
), f"Unsupported quantization format: {self.quantization_config.algorithm}"
algorithm = self.quantization_config.algorithm

if os.path.isfile(algorithm):
return load_quant_cfg(algorithm)

quant_cfg = QUANT_CFG_CHOICES[self.quantization_config.algorithm]
if "awq" in self.quantization_config.algorithm:
assert algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization format: {algorithm}"

quant_cfg = QUANT_CFG_CHOICES[algorithm]
if "awq" in algorithm:
quant_cfg = copy.deepcopy(quant_cfg)
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
if isinstance(weight_quantizer, list):
@@ -250,11 +252,11 @@ def _get_quant_cfg(self, model):
weight_quantizer["block_sizes"][-1] = self.quantization_config.awq_block_size

# Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
if "w4a8_awq" == self.quantization_config.algorithm and decoder_type in ["gemma", "mpt"]:
if "w4a8_awq" == algorithm and decoder_type in ["gemma", "mpt"]:
quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}

if self.quantization_config.enable_kv_cache is None:
enable_quant_kv_cache = "int8" not in self.quantization_config.algorithm and decoder_type != "gpt"
enable_quant_kv_cache = "int8" not in algorithm and decoder_type != "gpt"
else:
enable_quant_kv_cache = self.quantization_config.enable_kv_cache
if self.quantization_config.enable_kv_cache is None and enable_quant_kv_cache:
@@ -276,7 +278,7 @@ def _get_quant_cfg(self, model):
quant_cfg["algorithm"] = "max"

# Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
if decoder_type == "gemma" and "int8_sq" in self.quantization_config.algorithm:
if decoder_type == "gemma" and "int8_sq" in algorithm:
quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

return quant_cfg
@@ -299,6 +301,7 @@ def quantize(self, model: "MegatronParallel", forward_loop=None):
self._setup(model)
decoder_type = self._get_decoder_type(model, optional=True)
quant_cfg = self._get_quant_cfg(model)
logging.info(f"Using quant_cfg:\n{pprint.pformat(quant_cfg)}")
unwrapped_model = mtq.quantize(unwrap_for_modelopt_operations(model), quant_cfg, forward_loop)
if decoder_type == "gpt":
# We found squared_relu may have an under-calibration problem.
@@ -505,42 +508,6 @@ def _get_iterator():
return _get_iterator


huggingface_model_type_pattern_match = {
"GPT2": "gpt",
"Mllama": "mllama",
"Llama": "llama",
"Mistral": "llama",
"GPTJ": "gptj",
"FalconForCausalLM": "falcon",
"RWForCausalLM": "falcon",
"baichuan": "baichuan",
"MPT": "mpt",
"Bloom": "bloom",
"ChatGLM": "chatglm",
"QWen": "qwen",
"RecurrentGemma": "recurrentgemma",
"Gemma2": "gemma2",
"Gemma3": "gemma3",
"Gemma": "gemma",
"phi3small": "phi3small",
"phi3": "phi3",
"PhiMoEForCausalLM": "phi3",
"phi": "phi",
"TLGv4ForCausalLM": "phi",
"MixtralForCausalLM": "llama",
"ArcticForCausalLM": "llama",
"StarCoder": "gpt",
"Dbrx": "dbrx",
"T5": "t5",
"Bart": "bart",
"GLM": "glm",
"InternLM2ForCausalLM": "internlm",
"ExaoneForCausalLM": "exaone",
"Nemotron": "gpt",
"Deepseek": "deepseek",
"Whisper": "whisper",
}

gpt_model_type = [
(llm.Baichuan2Model, "baichuan"),
(llm.ChatGLMModel, "chatglm"),
@@ -576,9 +543,7 @@ def get_modelopt_decoder_type(model: Union[llm.GPTModel, llm.HFAutoModelForCausa
Optional[str]: The inferred decoder type or None if no match is found.
"""
if isinstance(model, llm.HFAutoModelForCausalLM):
for k, v in huggingface_model_type_pattern_match.items():
if k.lower() in type(model.model).__name__.lower():
return v
return mte.model_utils.get_model_type(model.model)
else:
for config_class, decoder_type in gpt_model_type:
if isinstance(model, config_class):
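The user-visible change in _get_quant_cfg is that algorithm can now be either the name of a built-in preset or a path to a quantization-config JSON file on disk. A minimal sketch of that dispatch, assuming an illustrative stand-in for QUANT_CFG_CHOICES (the real table is assembled from modelopt presets by get_quant_cfg_choices):

import os
from typing import Any, Dict

from nemo.collections.llm.modelopt.quantization.utils import load_quant_cfg

# Illustrative stand-in; the real QUANT_CFG_CHOICES is built from modelopt preset configs.
QUANT_CFG_CHOICES: Dict[str, Dict[str, Any]] = {"fp8": {"algorithm": "max"}}

def resolve_quant_cfg(algorithm: str) -> Dict[str, Any]:
    # A file path on disk wins: load and standardize the custom JSON config.
    if os.path.isfile(algorithm):
        return load_quant_cfg(algorithm)
    # Otherwise the algorithm must name a known preset.
    assert algorithm in QUANT_CFG_CHOICES, f"Unsupported quantization format: {algorithm}"
    return QUANT_CFG_CHOICES[algorithm]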
58 changes: 58 additions & 0 deletions nemo/collections/llm/modelopt/quantization/utils.py
@@ -0,0 +1,58 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from pathlib import Path
from typing import Any, Dict, Union

from nemo.utils.cast_utils import maybe_cast_to_type


def standardize_json_config(quant_cfg: Dict[str, Any]):
"""Standardize the quantization configuration loaded from a JSON file to
ensure compatibility with modelopt. Modifies the input dictionary in place.
Args:
quant_cfg (Dict[str, Any]): The quantization config dictionary to standardize.
"""
for key, value in quant_cfg.items():
if key == "block_sizes":
value = {maybe_cast_to_type(k, int): v for k, v in value.items()}
quant_cfg[key] = value
elif key in {"num_bits", "scale_bits"} and isinstance(value, list):
quant_cfg[key] = tuple(value)
continue # No further processing needed
if isinstance(value, dict):
standardize_json_config(value)
elif isinstance(value, list):
for x in value:
if isinstance(x, dict):
standardize_json_config(x)


def load_quant_cfg(cfg_path: Union[str, Path]) -> Dict[str, Any]:
"""Load quantization configuration from a JSON file and adjust for
modelopt standards if necessary.
Args:
cfg_path (Union[str, Path]): Path to the quantization config JSON file.
Returns:
Dict[str, Any]: The loaded quantization configuration.
"""
with open(cfg_path, "r") as f:
quant_cfg = json.load(f)

standardize_json_config(quant_cfg)
return quant_cfg
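The normalization these helpers perform is easiest to see end to end. Below is a minimal round-trip sketch; the JSON payload is an illustrative FP8-style config written only to exercise the two conversions (num_bits lists become tuples, string block_sizes keys become ints), not an official modelopt preset.

import json
import tempfile

from nemo.collections.llm.modelopt.quantization.utils import load_quant_cfg

# As stored in a JSON file: num_bits as lists, block_sizes keys as strings.
example_cfg = {
    "quant_cfg": {
        "*weight_quantizer": {"num_bits": [4, 3], "block_sizes": {"-1": 128}},
        "*input_quantizer": {"num_bits": [4, 3], "axis": None},
    },
    "algorithm": "max",
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(example_cfg, f)
    cfg_path = f.name

quant_cfg = load_quant_cfg(cfg_path)
# After standardization the values match what modelopt expects.
assert quant_cfg["quant_cfg"]["*weight_quantizer"]["num_bits"] == (4, 3)
assert quant_cfg["quant_cfg"]["*weight_quantizer"]["block_sizes"] == {-1: 128}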
17 changes: 17 additions & 0 deletions nemo/utils/cast_utils.py
@@ -13,6 +13,7 @@
# limitations under the License.

from contextlib import contextmanager, nullcontext
from typing import Any

import torch

@@ -100,3 +101,19 @@ def monkeypatched(object, name, patch):
setattr(object, name, patch)
yield object
setattr(object, name, pre_patched_value)


def maybe_cast_to_type(x: Any, type_: type) -> Any:
"""Try to cast a value to the given type; if it fails, return the original value.

Args:
x (Any): The value to be cast.
type_ (type): The type to cast to; must be callable.

Returns:
Any: The casted value or the original value if casting fails.
"""
try:
return type_(x)
except Exception:
return x
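A couple of illustrative calls showing the fallback behavior:

from nemo.utils.cast_utils import maybe_cast_to_type

maybe_cast_to_type("-1", int)       # -1: the string parses, so the cast value is returned
maybe_cast_to_type("dynamic", int)  # "dynamic": int() raises, so the original string comes back
maybe_cast_to_type("3.5", float)    # 3.5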