39 changes: 39 additions & 0 deletions nemo/collections/llm/gpt/model/deepseek.py
@@ -44,6 +44,7 @@
if TYPE_CHECKING:
    from megatron.core.transformer import ModuleSpec
    from transformers import AutoModelForCausalLM
    from transformers import DeepseekV3Config as HFDeepseekV3Config

from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -621,6 +622,44 @@ def _modify_source_state(self, source: Dict[str, Any], source_config: Dict[str,
    def tokenizer(self) -> 'AutoTokenizer':
        return io.load_context(self, subpath="model").tokenizer

    @property
    def config(self) -> "HFDeepseekV3Config":
        from transformers import DeepseekV3Config as HFDeepseekV3Config

        # TODO: Generalize this to different DeepSeek model variants
        source: DeepSeekV3Config = io.load_context(str(self)).model.config

        # Count the zeros prefixing the moe_layer_freq array to obtain the HF
        # first_k_dense_replace parameter, and validate that the remainder is all ones:
        i = 0
        while i < len(source.moe_layer_freq) and source.moe_layer_freq[i] == 0:
            i += 1
        assert all(x == 1 for x in source.moe_layer_freq[i:])

        return HFDeepseekV3Config(
            architectures=["DeepseekV3ForCausalLM"],
            num_hidden_layers=source.num_layers,
            hidden_size=source.hidden_size,
            intermediate_size=source.ffn_hidden_size,
            num_attention_heads=source.num_attention_heads,
            q_lora_rank=source.q_lora_rank,
            qk_nope_head_dim=source.qk_head_dim,
            qk_rope_head_dim=source.qk_pos_emb_head_dim,
            v_head_dim=source.v_head_dim,
            kv_lora_rank=source.kv_lora_rank,
            num_key_value_heads=source.kv_channels,
            n_routed_experts=source.num_moe_experts,
            moe_intermediate_size=source.moe_ffn_hidden_size,
            first_k_dense_replace=i,
            num_experts_per_tok=source.moe_router_topk,
            n_group=source.moe_router_num_groups,
            topk_group=source.moe_router_group_topk,
            routed_scaling_factor=source.moe_router_topk_scaling_factor,
            aux_loss_alpha=source.moe_aux_loss_coeff,
            max_position_embeddings=source.max_position_embeddings,
            vocab_size=self.tokenizer.vocab_size,
        )


__all__ = [
    "DeepSeekConfig",
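The first_k_dense_replace derivation above counts the leading dense (0) entries in NeMo's moe_layer_freq mask and requires every remaining layer to be MoE (1), since HF's DeepseekV3Config can only express a dense prefix. A minimal standalone sketch of that loop, with an illustrative mask:

# Example (not part of the diff); the mask below is illustrative, DeepSeek-V3-like.
moe_layer_freq = [0, 0, 0] + [1] * 58  # 3 dense layers, then MoE everywhere

i = 0
while i < len(moe_layer_freq) and moe_layer_freq[i] == 0:
    i += 1
assert all(x == 1 for x in moe_layer_freq[i:]), "only a dense prefix is representable"

print(i)  # 3 -> becomes HF's first_k_dense_replace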
48 changes: 11 additions & 37 deletions nemo/collections/llm/modelopt/quantization/quantizer.py
@@ -14,11 +14,12 @@

import copy
import os
import pprint
import shutil
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import torch
from datasets import load_dataset
@@ -37,6 +38,7 @@
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.import_utils import safe_import
from nemo.utils.model_utils import unwrap_model
from nemo.utils.python.dict_utils import update_nested_dict

if TYPE_CHECKING:
    from nemo.lightning import Trainer
@@ -71,6 +73,7 @@ class QuantizationConfig:
    sq_alpha: float = 0.5
    enable_kv_cache: Optional[bool] = None
    kv_cache_qformat: str = "fp8"
    quant_cfg_overrides: Optional[List[str]] = None

    calibration_dataset: str = "cnn_dailymail"
    calibration_dataset_size: int = 512
@@ -279,6 +282,12 @@ def _get_quant_cfg(self, model):
        if decoder_type == "gemma" and "int8_sq" in self.quantization_config.algorithm:
            quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}

        # Apply any other updates given as a list of strings like "key1.key2=value"
        if self.quantization_config.quant_cfg_overrides:
            for update in self.quantization_config.quant_cfg_overrides:
                update_nested_dict(quant_cfg, update)

        logging.info(f"Using quant_cfg:\n{pprint.pformat(quant_cfg)}")
        return quant_cfg

    def quantize(self, model: "MegatronParallel", forward_loop=None):
@@ -505,41 +514,6 @@ def _get_iterator():
return _get_iterator


huggingface_model_type_pattern_match = {
    "GPT2": "gpt",
    "Mllama": "mllama",
    "Llama": "llama",
    "Mistral": "llama",
    "GPTJ": "gptj",
    "FalconForCausalLM": "falcon",
    "RWForCausalLM": "falcon",
    "baichuan": "baichuan",
    "MPT": "mpt",
    "Bloom": "bloom",
    "ChatGLM": "chatglm",
    "QWen": "qwen",
    "RecurrentGemma": "recurrentgemma",
    "Gemma2": "gemma2",
    "Gemma": "gemma",
    "phi3small": "phi3small",
    "phi3": "phi3",
    "PhiMoEForCausalLM": "phi3",
    "phi": "phi",
    "TLGv4ForCausalLM": "phi",
    "MixtralForCausalLM": "llama",
    "ArcticForCausalLM": "llama",
    "StarCoder": "gpt",
    "Dbrx": "dbrx",
    "T5": "t5",
    "Bart": "bart",
    "GLM": "glm",
    "InternLM2ForCausalLM": "internlm",
    "ExaoneForCausalLM": "exaone",
    "Nemotron": "gpt",
    "Deepseek": "deepseek",
    "Whisper": "whisper",
}

gpt_model_type = [
    (llm.Baichuan2Model, "baichuan"),
    (llm.ChatGLMModel, "chatglm"),
@@ -574,7 +548,7 @@ def get_modelopt_decoder_type(model: Union[llm.GPTModel, llm.HFAutoModelForCausa
        Optional[str]: The inferred decoder type or None if no match is found.
    """
    if isinstance(model, llm.HFAutoModelForCausalLM):
        for k, v in huggingface_model_type_pattern_match.items():
        for k, v in mte.model_utils.MODEL_NAME_TO_TYPE.items():
            if k.lower() in type(model.model).__name__.lower():
                return v
    else:
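For context, the quant_cfg_overrides hook added in _get_quant_cfg reduces to the loop below. The quant_cfg dictionary here is a hypothetical stand-in for a modelopt quantization config, not the real structure:

# Example (not part of the diff); quant_cfg shape and values are hypothetical.
from nemo.utils.python.dict_utils import update_nested_dict

quant_cfg = {"quant_cfg": {"*weight_quantizer": {"num_bits": 8}}, "algorithm": "max"}

for update in ["quant_cfg.*weight_quantizer.num_bits=4", "algorithm='awq_lite'"]:
    update_nested_dict(quant_cfg, update)

print(quant_cfg)
# {'quant_cfg': {'*weight_quantizer': {'num_bits': 4}}, 'algorithm': 'awq_lite'}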
Empty file added nemo/utils/python/__init__.py
27 changes: 27 additions & 0 deletions nemo/utils/python/dict_utils.py
@@ -0,0 +1,27 @@
import ast
from typing import Any, Dict


def update_nested_dict(dict_: Dict[str, Any], update: str, separator: str = '.') -> None:
    """
    Update a nested dictionary in-place given a key path and a new value. The update string must be in
    the format 'key1.key2.key3=new_value' (shown for the default dot separator, which consequently must
    not appear in any key). The new value is converted to its appropriate type using ast.literal_eval.

    Args:
        dict_ (Dict[str, Any]): The dictionary to update.
        update (str): The update string in the format 'key1.key2.key3=new_value' (for separator=".").
        separator (str): The separator used to split the key path. Default is '.'.
    """
    # Split the update string into key path and new value
    assert update.count("=") == 1, "Update string must contain exactly one '=' to separate key path and value."
    key_path, value = update.split("=")
    keys = key_path.split(separator)

    # Traverse the nested dictionary & update the final key with new value
    current_dict = dict_
    for key in keys[:-1]:
        current_dict = current_dict[key]

    last_key = keys[-1]
    current_dict[last_key] = ast.literal_eval(value)
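A quick usage sketch (keys and values here are illustrative, not from the diff); note how ast.literal_eval recovers Python types from the string payload:

# Example (not part of the diff).
cfg = {"a": {"b": 1}}
update_nested_dict(cfg, "a.b=2")        # int literal
update_nested_dict(cfg, "a.c=[1, 2]")   # containers parse too
update_nested_dict(cfg, "a.d='text'")   # strings must be quoted for literal_eval
print(cfg)  # {'a': {'b': 2, 'c': [1, 2], 'd': 'text'}}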
4 changes: 4 additions & 0 deletions scripts/llm/ptq.py
@@ -111,6 +111,9 @@ def get_args():
"--trust_remote_code", help="Trust remote code when loading HuggingFace models", action="store_true"
)
parser.add_argument("--legacy_ckpt", help="Load ckpt saved with TE < 1.14", action="store_true")
parser.add_argument(
"--quant_cfg_overrides", nargs="*", help="List of 'key1.key2=value' overrides to apply for quant_cfg."
)
args = parser.parse_args()

if args.export_path is None:
@@ -137,6 +140,7 @@ def main():
        sq_alpha=args.sq_alpha,
        enable_kv_cache=args.enable_kv_cache,
        kv_cache_qformat=args.kv_cache_qformat,
        quant_cfg_overrides=args.quant_cfg_overrides,
        calibration_dataset=args.calibration_dataset,
        calibration_dataset_size=args.calibration_dataset_size,
        calibration_batch_size=args.batch_size,
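Taken together, the new flag accepts any number of override strings, each later consumed by update_nested_dict. A sketch of the argparse behavior (override values are illustrative):

# Example (not part of the diff).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--quant_cfg_overrides", nargs="*")
args = parser.parse_args(
    ["--quant_cfg_overrides", "quant_cfg.*weight_quantizer.num_bits=4", "algorithm=None"]
)
print(args.quant_cfg_overrides)
# ['quant_cfg.*weight_quantizer.num_bits=4', 'algorithm=None']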