|
13 | 13 | # limitations under the License. |
14 | 14 | import json |
15 | 15 | import re |
16 | | -from dataclasses import dataclass, field |
| 16 | +from dataclasses import asdict, dataclass, field |
17 | 17 | from functools import cached_property, partial |
18 | 18 | from pathlib import Path |
19 | 19 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union |
|
45 | 45 | if TYPE_CHECKING: |
46 | 46 | from megatron.core.transformer import ModuleSpec |
47 | 47 | from transformers import AutoModelForCausalLM |
| 48 | + from transformers import DeepseekV3Config as HFDeepseekV3Config |
48 | 49 |
|
49 | 50 | from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer |
50 | 51 | from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec |
@@ -472,6 +473,28 @@ def init(self, dtype=torch.bfloat16, model_name="deepseek-ai/DeepSeek-V3") -> "A |
472 | 473 | type(hf_model).register_for_auto_class("AutoModelForCausalLM") |
473 | 474 | return hf_model |
474 | 475 |
|
| 476 | + def _detect_hf_deepseek_version(self, source_config: Dict[str, Any]) -> str: |
| 477 | + """ |
| 478 | + Detect the HF DeepSeek version based on the source NeMo config. |
| 479 | +
|
| 480 | + Args: |
| 481 | + source_config (Dict[str, Any]): The source NeMo model config. |
| 482 | +
|
| 483 | + Returns: |
| 484 | +            str: The matching DeepSeek model name, following the Hugging Face Hub convention.
| 485 | + """ |
| 486 | + if source_config['moe_router_enable_expert_bias']: |
| 487 | + target_model_name = "deepseek-ai/DeepSeek-V3" |
| 488 | + elif source_config['q_lora_rank'] is not None: |
| 489 | + target_model_name = "deepseek-ai/DeepSeek-V2" |
| 490 | + else: |
| 491 | + target_model_name = "deepseek-ai/DeepSeek-V2-Lite" |
| 492 | + logging.info( |
| 493 | + f"Your model is determined to be {target_model_name} based on the config. If this is not correct, " |
| 494 | + f"please pass in a local HF checkpoint." |
| 495 | + ) |
| 496 | + return target_model_name |
| 497 | + |
475 | 498 | def ckpt_load(self, path: Path) -> Tuple[Dict, Dict]: |
476 | 499 | """ |
477 | 500 |         This function loads the state dict directly from a distributed checkpoint and modifies the state dict
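
For reference, a minimal standalone sketch of the version-detection rule added in _detect_hf_deepseek_version above. The config dicts here are hypothetical minimal examples, not real NeMo checkpoint configs:

# Standalone sketch of the detection rule added above; the config dicts are
# hypothetical minimal examples, not real NeMo checkpoint configs.
def detect_hf_deepseek_version(source_config: dict) -> str:
    if source_config['moe_router_enable_expert_bias']:
        return "deepseek-ai/DeepSeek-V3"       # expert bias in the MoE router => V3
    elif source_config['q_lora_rank'] is not None:
        return "deepseek-ai/DeepSeek-V2"       # low-rank q projection configured => V2
    return "deepseek-ai/DeepSeek-V2-Lite"      # otherwise the Lite variant

assert detect_hf_deepseek_version(
    {'moe_router_enable_expert_bias': True, 'q_lora_rank': 1536}
) == "deepseek-ai/DeepSeek-V3"
assert detect_hf_deepseek_version(
    {'moe_router_enable_expert_bias': False, 'q_lora_rank': None}
) == "deepseek-ai/DeepSeek-V2-Lite"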
@@ -511,21 +534,12 @@ def apply(self, output_path: Path, target_model_name=None) -> Path: |
511 | 534 | logging.info("DeepSeek NeMo checkpoint loaded.") |
512 | 535 | if target_model_name is None: |
513 | 536 | # Before DeepSeek is fully supported by HF, it is necessary to pass in a local HF checkpoint that |
514 | | - # is used to initialize the HF model. The following |
| 537 | + # is used to initialize the HF model. |
515 | 538 | logging.warning( |
516 | 539 | "Before DeepSeek is officially supported in HF, you should pass in a local HF " |
517 | 540 | "checkpoint using llm.export_ckpt(..., target_model_name=<local hf path>)" |
518 | 541 | ) |
519 | | - if source_config['moe_router_enable_expert_bias']: |
520 | | - target_model_name = "deepseek-ai/DeepSeek-V3" |
521 | | - elif source_config['q_lora_rank'] is not None: |
522 | | - target_model_name = "deepseek-ai/DeepSeek-V2" |
523 | | - else: |
524 | | - target_model_name = "deepseek-ai/DeepSeek-V2-Lite" |
525 | | - logging.info( |
526 | | - f"Your model is determined to be {target_model_name} based on the config. If this is not correct, " |
527 | | - f"please pass in a local HF checkpoint." |
528 | | - ) |
| 542 | + target_model_name = self._detect_hf_deepseek_version(source_config) |
529 | 543 |
|
530 | 544 | target = self.init(torch_dtype_from_dict_config(source_config), model_name=target_model_name) |
531 | 545 | target = self.convert_state(source, target, source_config) |
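
For context, a hedged usage sketch of the export path this hunk touches. The paths are placeholders, and the keyword arguments other than target_model_name are assumptions to verify against your NeMo version; the warning above only states that target_model_name accepts a local HF path:

# Hypothetical invocation of the exporter; paths are placeholders and the
# surrounding keyword arguments are assumptions. Only target_model_name is
# taken from the warning message in this hunk.
from pathlib import Path
from nemo.collections import llm

llm.export_ckpt(
    path=Path("/results/deepseek_v3_nemo_ckpt"),    # NeMo distributed checkpoint (placeholder)
    target="hf",                                    # export to Hugging Face format
    output_path=Path("/results/deepseek_v3_hf"),    # destination for the HF checkpoint (placeholder)
    target_model_name="/local/hf/DeepSeek-V3",      # local HF checkpoint, as the warning suggests
)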
@@ -639,6 +653,60 @@ def _modify_source_state(self, source: Dict[str, Any], source_config: Dict[str, |
639 | 653 | def tokenizer(self) -> 'AutoTokenizer': |
640 | 654 | return io.load_context(self, subpath="model").tokenizer |
641 | 655 |
|
| 656 | + @property |
| 657 | + def config(self) -> "HFDeepseekV3Config": |
| 658 | + """Create a HF DeepseekV3Config from the NeMo model config. |
| 659 | +
|
| 660 | + Translates the NeMo configuration parameters to the equivalent HF |
| 661 | + configuration. |
| 662 | +
|
| 663 | + Currently only supports DeepseekV3Config based on availability |
| 664 | + in the Transformers library. |
| 665 | +
|
| 666 | + Returns: |
| 667 | + HFDeepseekV3Config: HF configuration for DeepSeekV3 models |
| 668 | + """ |
| 669 | + # TODO: Get config for all DeepSeek model variants once available in transformers |
| 670 | + |
| 671 | + from transformers import DeepseekV3Config as HFDeepseekV3Config |
| 672 | + |
| 673 | + source: DeepSeekV3Config = io.load_context(str(self)).model.config |
| 674 | + |
| 675 | + target_model_name = self._detect_hf_deepseek_version(asdict(source)) |
| 676 | + if target_model_name != "deepseek-ai/DeepSeek-V3": |
| 677 | +            raise ValueError(f"Getting config for {target_model_name} is not supported. Only deepseek-ai/DeepSeek-V3 is supported.")
| 678 | + |
| 679 | +        # Figure out the number of leading zeros in the moe_layer_freq array
| 680 | +        # for the HF first_k_dense_replace parameter and validate the remainder:
| 681 | + k = 0 |
| 682 | + while k < len(source.moe_layer_freq) and source.moe_layer_freq[k] == 0: |
| 683 | + k += 1 |
| 684 | + assert all(x == 1 for x in source.moe_layer_freq[k:]) |
| 685 | + |
| 686 | + return HFDeepseekV3Config( |
| 687 | + architectures=["DeepseekV3ForCausalLM"], |
| 688 | + num_hidden_layers=source.num_layers, |
| 689 | + hidden_size=source.hidden_size, |
| 690 | + intermediate_size=source.ffn_hidden_size, |
| 691 | + num_attention_heads=source.num_attention_heads, |
| 692 | + q_lora_rank=source.q_lora_rank, |
| 693 | + qk_nope_head_dim=source.qk_head_dim, |
| 694 | + qk_rope_head_dim=source.qk_pos_emb_head_dim, |
| 695 | + v_head_dim=source.v_head_dim, |
| 696 | + kv_lora_rank=source.kv_lora_rank, |
| 697 | + num_key_value_heads=source.kv_channels, |
| 698 | + n_routed_experts=source.num_moe_experts, |
| 699 | + moe_intermediate_size=source.moe_ffn_hidden_size, |
| 700 | + first_k_dense_replace=k, |
| 701 | + num_experts_per_tok=source.moe_router_topk, |
| 702 | + n_group=source.moe_router_num_groups, |
| 703 | + topk_group=source.moe_router_group_topk, |
| 704 | + routed_scaling_factor=source.moe_router_topk_scaling_factor, |
| 705 | + aux_loss_alpha=source.moe_aux_loss_coeff, |
| 706 | + max_position_embeddings=source.max_position_embeddings, |
| 707 | + vocab_size=self.tokenizer.vocab_size, |
| 708 | + ) |
| 709 | + |
642 | 710 |
|
643 | 711 | __all__ = [ |
644 | 712 | "DeepSeekConfig", |
|
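Finally, a small sketch of the first_k_dense_replace derivation in the new config property: it counts the leading zeros (dense layers) of moe_layer_freq and requires every remaining entry to be an MoE layer. The layer pattern below is a hypothetical example:

# moe_layer_freq pattern: 3 dense layers followed by MoE layers (hypothetical example).
moe_layer_freq = [0, 0, 0, 1, 1, 1, 1, 1]

k = 0
while k < len(moe_layer_freq) and moe_layer_freq[k] == 0:
    k += 1
# Every layer after the dense prefix must be an MoE layer, mirroring the assert above.
assert all(x == 1 for x in moe_layer_freq[k:])

print(k)  # -> 3, exported as first_k_dense_replace=3 in the HF config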