Skip to content

Commit 251fe60

Browse files
kevalmorabia97, chtruong814, and AAnoosheh
committed
Bump modelopt to 0.35.0 and remove safe_import("modelopt") in llm collection (#14656)
* Bump modelopt to 0.35.0 and remove safe_import in llm collection Signed-off-by: Keval Morabia <[email protected]> * Update eagle architecture spec setting Signed-off-by: Asha Anoosheh <[email protected]> * Reduce specdec memory usage Signed-off-by: Asha Anoosheh <[email protected]> --------- Signed-off-by: Keval Morabia <[email protected]> Signed-off-by: Asha Anoosheh <[email protected]> Co-authored-by: Charlie Truong <[email protected]> Co-authored-by: Asha Anoosheh <[email protected]> Signed-off-by: Charlie Truong <[email protected]>
1 parent 1b06710 commit 251fe60

13 files changed

Lines changed: 87 additions & 65 deletions

File tree

docker/common/install_dep.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ extra() {
304304
"llama-index==0.10.43" # incompatible with nvidia-pytriton
305305
"ctc_segmentation==1.7.1 ; (platform_machine == 'x86_64' and platform_system != 'Darwin')" # requires numpy<2.0.0 to be installed before
306306
"nemo_run"
307-
"nvidia-modelopt[torch]==0.33.0 ; platform_system != 'Darwin'" # We want a specific version of nvidia-modelopt
307+
"nvidia-modelopt==0.35.0" # We want a specific version of nvidia-modelopt
308308
)
309309
if [[ "${NVIDIA_PYTORCH_VERSION}" != "" ]]; then
310310
DEPS+=(

nemo/collections/llm/api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
PruningConfig,
4343
QuantizationConfig,
4444
Quantizer,
45-
prune_gpt_model,
45+
prune_language_model,
4646
save_pruned_model,
4747
set_modelopt_spec_if_exists_in_ckpt,
4848
setup_trainer_and_restore_model_with_modelopt_spec,
@@ -310,6 +310,8 @@ def prune(
310310
num_nodes: int = 1,
311311
tp_size: int = 1,
312312
pp_size: int = 1,
313+
num_layers_in_first_pipeline_stage: int | None = None,
314+
num_layers_in_last_pipeline_stage: int | None = None,
313315
num_train_samples: int = 1024,
314316
data: pl.LightningDataModule | None = None,
315317
tokenizer_path: str | None = None,
@@ -327,6 +329,8 @@ def prune(
327329
tp_size (int): The tensor parallel size.
328330
pp_size (int): The pipeline parallel size.
329331
num_train_samples (int): Number of training samples for importance estimation using forward pass.
332+
num_layers_in_first_pipeline_stage (int): The number of layers in the first pipeline stage.
333+
num_layers_in_last_pipeline_stage (int): The number of layers in the last pipeline stage.
330334
data (pl.LightningDataModule): The data module for forward pass.
331335
Required if not dropping layers.
332336
tokenizer_path (str): Path to the tokenizer if not using model's tokenizer.
@@ -362,6 +366,8 @@ def prune(
362366
model_path=nemo_checkpoint,
363367
tensor_model_parallel_size=tp_size,
364368
pipeline_model_parallel_size=pp_size,
369+
num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
370+
num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage,
365371
devices=devices,
366372
num_nodes=num_nodes,
367373
inference_only=True,
@@ -371,7 +377,7 @@ def prune(
371377
trainer_kwargs={"max_steps": steps, "limit_val_batches": steps, "val_check_interval": steps},
372378
model_config_overrides={"sequence_parallel": False},
373379
)
374-
prune_gpt_model(model, pruning_config, data, trainer)
380+
prune_language_model(model, pruning_config, data, trainer)
375381
save_pruned_model(trainer, save_path)
376382

377383
console = Console()

nemo/collections/llm/modelopt/distill/model.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
2424
from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
2525
from nemo.utils import logging
26-
from nemo.utils.import_utils import safe_import
2726
from nemo.utils.model_utils import unwrap_model
2827

2928
from .utils import adjust_distillation_model_for_mcore, load_distillation_config, teacher_provider
@@ -32,7 +31,7 @@
3231
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
3332
from nemo.lightning.pytorch.optim import OptimizerModule
3433

35-
mtd, HAVE_MODELOPT = safe_import("modelopt.torch.distill")
34+
import modelopt.torch.distill as mtd
3635

3736

3837
class _DistillationLossReduction(MaskedTokenLossReduction):
@@ -134,8 +133,6 @@ def __init__(
134133
tokenizer: Tokenizer.
135134
model_transform: Transform to apply to model during setup.
136135
"""
137-
if not HAVE_MODELOPT:
138-
raise RuntimeError("nvidia-modelopt is needed to use DistillationGPTModel")
139136
super().__init__(config, optim, tokenizer, model_transform)
140137
self._teacher_config = teacher_config
141138
self._teacher_ckpt_path = teacher_ckpt_path

nemo/collections/llm/modelopt/distill/utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from nemo import lightning as nl
2929
from nemo.collections import llm
3030
from nemo.utils import logging
31-
from nemo.utils.import_utils import safe_import, safe_import_from
3231

3332
from .loss import HiddenStateCosineLoss, LogitsAndIntermediatesLossBalancer, LogitsKLLoss, ProjectionLayer
3433

@@ -39,9 +38,8 @@
3938

4039
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
4140

42-
mto, HAVE_MODELOPT = safe_import("modelopt.torch.opt")
43-
DistillationModel, _ = safe_import_from("modelopt.torch.distill", "DistillationModel", alt=object)
44-
DistillationLossBalancer, _ = safe_import_from("modelopt.torch.distill", "DistillationLossBalancer", alt=object)
41+
import modelopt.torch.opt as mto
42+
from modelopt.torch.distill import DistillationLossBalancer, DistillationModel
4543

4644

4745
@dataclass
@@ -242,8 +240,6 @@ def get_tensor_shapes_adjust_fn_for_distillation(
242240
Currently only used during non-interleaved pipelining for Distillation.
243241
Concatenates sizes of student and teacher output tensors for inter-process communication.
244242
"""
245-
if not HAVE_MODELOPT:
246-
return None
247243
if (
248244
forward_only
249245
or parallel_state.get_pipeline_model_parallel_world_size() == 1

nemo/collections/llm/modelopt/model_utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import TYPE_CHECKING, Callable, Optional, Union
1919

2020
import lightning.pytorch as L
21+
import modelopt.torch.opt as mto
2122
import torch
2223
import torch.nn as nn
2324
from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
@@ -32,8 +33,6 @@
3233
from nemo.utils.import_utils import safe_import
3334
from nemo.utils.model_utils import unwrap_model
3435

35-
mto, HAVE_MODELOPT = safe_import("modelopt.torch.opt")
36-
3736
_, HAVE_TE = safe_import("transformer_engine")
3837
if HAVE_TE:
3938
# These custom modelopt specs are a mix of local MCORE and TE specs.
@@ -214,8 +213,6 @@ def restore_modelopt_state(
214213
path (str): The path to the checkpoint.
215214
trainer (pl.Trainer): The trainer object, in case path not provided.
216215
"""
217-
if not HAVE_MODELOPT:
218-
return
219216
if not path:
220217
if trainer is None:
221218
return
@@ -254,9 +251,6 @@ def save_modelopt_state(model: "MegatronParallel", path: str, checkpoint_io: "Ch
254251
path (str): The path to the checkpoint.
255252
checkpoint_io (CheckpointIO): The checkpoint IO object from MegatronStrategy.
256253
"""
257-
if not HAVE_MODELOPT:
258-
return
259-
260254
# Save ModelOpt state too, if it exists.
261255
core_model = unwrap_model(model)
262256
if not mto.ModeloptStateManager.is_converted(core_model):

nemo/collections/llm/modelopt/prune/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@
1414

1515
"""Prune utilities for using TensorRT Model Optimizer."""
1616

17-
from .pruner import PruningConfig, prune_gpt_model, save_pruned_model
17+
from .pruner import PruningConfig, prune_language_model, save_pruned_model
1818

19-
__all__ = ["PruningConfig", "prune_gpt_model", "save_pruned_model"]
19+
__all__ = ["PruningConfig", "prune_language_model", "save_pruned_model"]

nemo/collections/llm/modelopt/prune/pruner.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dataclasses import dataclass
1616
from functools import partial
1717

18+
import modelopt.torch.prune as mtp
1819
import pytorch_lightning as pl
1920
from megatron.core import dist_checkpointing
2021

@@ -24,16 +25,15 @@
2425
from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir
2526
from nemo.utils import logging
2627
from nemo.utils.get_rank import is_global_rank_zero
27-
from nemo.utils.import_utils import safe_import
28-
29-
mtp, HAVE_MODELOPT = safe_import("modelopt.torch.prune")
3028

3129
SUPPORTED_PRUNING_HPARAMS = {
3230
# Width pruning
3331
"ffn_hidden_size",
3432
"hidden_size",
3533
"num_attention_heads",
3634
"num_query_groups",
35+
"mamba_num_heads",
36+
"mamba_head_dim",
3737
# Depth pruning
3838
"num_layers",
3939
}
@@ -50,6 +50,8 @@ class PruningConfig:
5050
Required if `target_num_query_groups` is provided.
5151
target_num_query_groups (int, optional): Target number of query groups for grouped-query attention.
5252
Required if `target_num_attention_heads` is provided.
53+
target_mamba_num_heads (int, optional): Target number of Mamba attention heads.
54+
target_mamba_head_dim (int, optional): Target dimension of Mamba attention heads.
5355
target_num_layers (int, optional): Target number of transformer layers using importance metric.
5456
drop_layers (list[int], optional): List of specific layer indices (1-indexed) to drop from the model.
5557
Cannot be used with other pruning parameters.
@@ -59,6 +61,8 @@ class PruningConfig:
5961
target_hidden_size: int | None = None
6062
target_num_attention_heads: int | None = None
6163
target_num_query_groups: int | None = None
64+
target_mamba_num_heads: int | None = None
65+
target_mamba_head_dim: int | None = None
6266
target_num_layers: int | None = None
6367
drop_layers: list[int] | None = None
6468

@@ -69,19 +73,21 @@ def __post_init__(self):
6973
self.target_hidden_size,
7074
self.target_num_attention_heads,
7175
self.target_num_query_groups,
76+
self.target_mamba_num_heads,
77+
self.target_mamba_head_dim,
7278
self.target_num_layers,
7379
]
7480
if any(p is not None for p in other_params):
7581
raise ValueError("drop_layers cannot be used with other pruning parameters")
7682

7783

78-
def prune_gpt_model(
84+
def prune_language_model(
7985
model: llm.GPTModel,
8086
pruning_config: PruningConfig,
8187
data_module: pl.LightningDataModule | None = None,
8288
trainer: nl.Trainer | None = None,
8389
) -> llm.GPTModel:
84-
"""Prune a GPT model in-place based on the provided pruning configuration.
90+
"""Prune a GPT / Mamba (sub-class of GPT) model in-place based on the provided pruning configuration.
8591
8692
Args:
8793
model (llm.GPTModel): The model to prune.
@@ -94,9 +100,8 @@ def prune_gpt_model(
94100
Returns:
95101
llm.GPTModel: The pruned model.
96102
"""
97-
assert HAVE_MODELOPT, "nvidia-modelopt is required to prune the model."
98103
if pruning_config.drop_layers:
99-
mtp.plugins.drop_mcore_gpt_layers(model, layers_to_drop=pruning_config.drop_layers)
104+
mtp.plugins.drop_mcore_language_model_layers(model, layers_to_drop=pruning_config.drop_layers)
100105
else:
101106
assert data_module is not None, "data_module is required to prune the model."
102107
assert trainer is not None, "trainer is required to prune the model."
@@ -111,7 +116,7 @@ def prune_gpt_model(
111116
}
112117
mtp.prune(
113118
model,
114-
mode="mcore_gpt_minitron",
119+
mode="mcore_minitron",
115120
constraints={"export_config": export_config},
116121
dummy_input=None, # Not used
117122
config={"forward_loop": partial(llm.validate, data=data_module, trainer=trainer, tokenizer="model")},

nemo/collections/llm/modelopt/quantization/quant_cfg_choices.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414

1515
from typing import Any, Dict
1616

17-
from nemo.utils.import_utils import safe_import
18-
19-
mtq, HAVE_MODELOPT = safe_import("modelopt.torch.quantization")
17+
import modelopt.torch.quantization as mtq
2018

2119

2220
def get_quant_cfg_choices() -> Dict[str, Dict[str, Any]]:
@@ -32,9 +30,6 @@ def get_quant_cfg_choices() -> Dict[str, Dict[str, Any]]:
3230
dict: A dictionary where keys are short names (e.g., "fp8") and values are the
3331
corresponding modelopt quantization configuration objects.
3432
"""
35-
if not HAVE_MODELOPT:
36-
return {}
37-
3833
QUANT_CFG_NAMES = [
3934
("int8", "INT8_DEFAULT_CFG"),
4035
("int8_sq", "INT8_SMOOTHQUANT_CFG"),

nemo/collections/llm/modelopt/quantization/quantizer.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from pathlib import Path
2222
from typing import TYPE_CHECKING, Optional, Union
2323

24+
import modelopt.torch.export as mte
25+
import modelopt.torch.opt as mto
26+
import modelopt.torch.quantization as mtq
2427
import torch
2528
from datasets import load_dataset
2629
from megatron.core.inference.common_inference_params import CommonInferenceParams
@@ -37,7 +40,6 @@
3740
from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir
3841
from nemo.utils import logging
3942
from nemo.utils.get_rank import is_global_rank_zero
40-
from nemo.utils.import_utils import safe_import
4143
from nemo.utils.model_utils import unwrap_model
4244

4345
if TYPE_CHECKING:
@@ -46,10 +48,6 @@
4648
from nemo.lightning import Trainer
4749
from nemo.lightning.megatron_parallel import MegatronParallel
4850

49-
mte, HAVE_MODELOPT_MTE = safe_import("modelopt.torch.export")
50-
mtq, HAVE_MODELOPT_MTQ = safe_import("modelopt.torch.quantization")
51-
mto, HAVE_MODELOPT_MTO = safe_import("modelopt.torch.opt")
52-
HAVE_MODELOPT = HAVE_MODELOPT_MTQ and HAVE_MODELOPT_MTE and HAVE_MODELOPT_MTO
5351

5452
QUANT_CFG_CHOICES = get_quant_cfg_choices()
5553
SUPPORTED_DTYPE = [16, "16", "bf16"] # Default precision for non-quantized layers
@@ -121,8 +119,6 @@ class Quantizer:
121119

122120
def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig):
123121
"""Initialize Quantizer with quantization and export configurations."""
124-
if not HAVE_MODELOPT:
125-
raise RuntimeError("nvidia-modelopt is needed to use Quantizer")
126122
if not torch.cuda.is_available():
127123
raise EnvironmentError("GPU is required for the quantization.")
128124

nemo/collections/llm/modelopt/speculative/model_transform.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import modelopt.torch.opt as mto
16+
import modelopt.torch.speculative as mtsp
1517
import torch.nn as nn
1618

1719
from nemo.collections.llm import GPTModel
1820
from nemo.utils import logging
19-
from nemo.utils.import_utils import UnavailableError, safe_import
2021
from nemo.utils.model_utils import unwrap_model
2122

22-
mto, HAVE_MODELOPT = safe_import("modelopt.torch.opt")
23-
mtsp, _ = safe_import("modelopt.torch.speculative")
24-
25-
try:
26-
ALGORITHMS = {
27-
"eagle3": mtsp.EAGLE3_DEFAULT_CFG,
28-
# more TBD
29-
}
30-
except UnavailableError:
31-
ALGORITHMS = {}
23+
ALGORITHMS = {
24+
"eagle3": mtsp.EAGLE3_DEFAULT_CFG,
25+
# more TBD
26+
}
3227

3328

3429
def apply_speculative_decoding(model: nn.Module, algorithm: str = "eagle3") -> nn.Module:
35-
"""
36-
Transform a model to enable Speculative Decoding using Model Optimizer.
30+
"""Transform a model to enable Speculative Decoding using Model Optimizer.
3731
3832
Args:
3933
model: The model to transform.
@@ -43,9 +37,6 @@ def apply_speculative_decoding(model: nn.Module, algorithm: str = "eagle3") -> n
4337
Returns:
4438
The transformed model.
4539
"""
46-
if not HAVE_MODELOPT:
47-
raise ImportError("nvidia-modelopt is required to use Speculative Decoding")
48-
4940
assert algorithm in ALGORITHMS, f"Invalid algorithm: {algorithm}. Choices: {ALGORITHMS.keys()}"
5041
mode_cfg = ALGORITHMS[algorithm]
5142
mode, cfg = mode_cfg["algorithm"], mode_cfg["config"]
@@ -63,16 +54,26 @@ def apply_speculative_decoding(model: nn.Module, algorithm: str = "eagle3") -> n
6354
if unwrapped_model.config.virtual_pipeline_model_parallel_size is not None:
6455
raise ValueError("Speculative decoding is incompatible with virtual pipeline parallelism.")
6556

66-
logging.info(f"Converting to Speculative Decoding model with mode: {mode} and config:\n{cfg}")
57+
# Adjust decoder head architecture
58+
if "eagle_architecture_config" in cfg:
59+
# These ones are necessary
60+
cfg["eagle_architecture_config"]["hidden_size"] = unwrapped_model.config.hidden_size
61+
cfg["eagle_architecture_config"]["vocab_size"] = unwrapped_model.vocab_size
62+
cfg["eagle_architecture_config"]["draft_vocab_size"] = unwrapped_model.vocab_size
63+
# These ones are optional but we copy base model's to scale memory usage reasonably
64+
cfg["eagle_architecture_config"]["intermediate_size"] = unwrapped_model.config.ffn_hidden_size
65+
cfg["eagle_architecture_config"]["num_attention_heads"] = unwrapped_model.config.num_attention_heads
66+
cfg["eagle_architecture_config"]["num_key_value_heads"] = unwrapped_model.config.num_query_groups
67+
68+
# Convert
69+
logging.info(f"Converting to Speculative Decoding model with mode: '{mode}' and config:\n{cfg}")
6770
mtsp.convert(unwrapped_model, [(mode, cfg)]) # assumes in-place
6871

6972
return model
7073

7174

7275
def _has_same_speculative_decoding_state(model: nn.Module, mode: str) -> bool:
73-
"""
74-
Check if the model has the same Speculative Decoding state as the incoming algorithm mode.
75-
"""
76+
"""Check if the model has the same Speculative Decoding state as the incoming algorithm mode."""
7677
from modelopt.torch.opt.mode import _ModeRegistryCls
7778

7879
mode_registry = _ModeRegistryCls.get_registry_by_name("speculative")

0 commit comments

Comments (0)