1515from dataclasses import dataclass
1616from functools import partial
1717
18+ import modelopt .torch .prune as mtp
1819import pytorch_lightning as pl
1920from megatron .core import dist_checkpointing
2021
2425from nemo .lightning .io .pl import TrainerContext , ckpt_to_weights_subdir
2526from nemo .utils import logging
2627from nemo .utils .get_rank import is_global_rank_zero
27- from nemo .utils .import_utils import safe_import
28-
29- mtp , HAVE_MODELOPT = safe_import ("modelopt.torch.prune" )
3028
# Hyperparameter names that the Minitron pruning flow accepts as targets.
# Grouped by what they shrink: per-layer width vs. total depth.
SUPPORTED_PRUNING_HPARAMS = {
    # Width pruning (per-layer sizes)
    "ffn_hidden_size",
    "hidden_size",
    "num_attention_heads",
    "num_query_groups",
    "mamba_num_heads",
    "mamba_head_dim",
    # Depth pruning (layer count)
    "num_layers",
}
@@ -50,6 +50,8 @@ class PruningConfig:
5050 Required if `target_num_query_groups` is provided.
5151 target_num_query_groups (int, optional): Target number of query groups for grouped-query attention.
5252 Required if `target_num_attention_heads` is provided.
53+ target_mamba_num_heads (int, optional): Target number of Mamba attention heads.
54+ target_mamba_head_dim (int, optional): Target dimension of Mamba attention heads.
5355 target_num_layers (int, optional): Target number of transformer layers using importance metric.
5456 drop_layers (list[int], optional): List of specific layer indices (1-indexed) to drop from the model.
5557 Cannot be used with other pruning parameters.
@@ -59,6 +61,8 @@ class PruningConfig:
5961 target_hidden_size : int | None = None
6062 target_num_attention_heads : int | None = None
6163 target_num_query_groups : int | None = None
64+ target_mamba_num_heads : int | None = None
65+ target_mamba_head_dim : int | None = None
6266 target_num_layers : int | None = None
6367 drop_layers : list [int ] | None = None
6468
@@ -69,19 +73,21 @@ def __post_init__(self):
6973 self .target_hidden_size ,
7074 self .target_num_attention_heads ,
7175 self .target_num_query_groups ,
76+ self .target_mamba_num_heads ,
77+ self .target_mamba_head_dim ,
7278 self .target_num_layers ,
7379 ]
7480 if any (p is not None for p in other_params ):
7581 raise ValueError ("drop_layers cannot be used with other pruning parameters" )
7682
7783
78- def prune_gpt_model (
84+ def prune_language_model (
7985 model : llm .GPTModel ,
8086 pruning_config : PruningConfig ,
8187 data_module : pl .LightningDataModule | None = None ,
8288 trainer : nl .Trainer | None = None ,
8389) -> llm .GPTModel :
84- """Prune a GPT model in-place based on the provided pruning configuration.
90+ """Prune a GPT / Mamba (sub-class of GPT) model in-place based on the provided pruning configuration.
8591
8692 Args:
8793 model (llm.GPTModel): The model to prune.
@@ -94,9 +100,8 @@ def prune_gpt_model(
94100 Returns:
95101 llm.GPTModel: The pruned model.
96102 """
97- assert HAVE_MODELOPT , "nvidia-modelopt is required to prune the model."
98103 if pruning_config .drop_layers :
99- mtp .plugins .drop_mcore_gpt_layers (model , layers_to_drop = pruning_config .drop_layers )
104+ mtp .plugins .drop_mcore_language_model_layers (model , layers_to_drop = pruning_config .drop_layers )
100105 else :
101106 assert data_module is not None , "data_module is required to prune the model."
102107 assert trainer is not None , "trainer is required to prune the model."
@@ -111,7 +116,7 @@ def prune_gpt_model(
111116 }
112117 mtp .prune (
113118 model ,
114- mode = "mcore_gpt_minitron " ,
119+ mode = "mcore_minitron " ,
115120 constraints = {"export_config" : export_config },
116121 dummy_input = None , # Not used
117122 config = {"forward_loop" : partial (llm .validate , data = data_module , trainer = trainer , tokenizer = "model" )},
0 commit comments