Commit c1b0c4c

chtruong814 and nasretdinovr authored and committed
Revert "[MoE] Partial Cudagraph support for MoE (NVIDIA-NeMo#14362)" (NVIDIA-NeMo#14402)
This reverts commit 7d361dc.

Signed-off-by: Charlie Truong <[email protected]>
1 parent 6ae979c commit c1b0c4c

File tree

2 files changed: +1 -48 lines changed


nemo/collections/llm/gpt/model/base.py

Lines changed: 1 addition & 1 deletion
@@ -326,7 +326,7 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None, vp_sta
         Returns:
             MCoreGPTModel: Configured Megatron Core GPT model instance
         """
-        if self.enable_cuda_graph or self.external_cuda_graph:
+        if self.enable_cuda_graph:
             assert HAVE_TE, "Transformer Engine is required for cudagraphs."
             assert getattr(self, "use_te_rng_tracker", False), (
                 "Transformer engine's RNG tracker is required for cudagraphs, it can be "

nemo/lightning/pytorch/strategies/megatron_strategy.py

Lines changed: 0 additions & 47 deletions
@@ -14,7 +14,6 @@

 import atexit
 import functools
-import gc
 import inspect
 import logging as _logging
 import os
@@ -56,7 +55,6 @@
 from megatron.core.dist_checkpointing.validation import StrictHandling
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig
-from megatron.training.training import cuda_graph_capture, cuda_graph_set_manual_hooks
 from torch import nn
 from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
 from torch.distributed.checkpoint.utils import CheckpointException
@@ -574,10 +572,6 @@ def setup_distributed(self) -> None:
         """Setups dist env"""
         setup_parallel_ranks(self)

-        # Capture the external cudagraph on a side stream
-        if hasattr(self.model, 'config') and getattr(self.model.config, 'external_cuda_graph', False):
-            torch.cuda.set_stream(torch.cuda.Stream())
-
         # Implementation from superclass copied below in order to pass the store to the process group init
         reset_seed()
         self.set_world_ranks()
@@ -726,35 +720,6 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
         assert self.lightning_module is not None
         assert isinstance(self.model, MegatronParallel)

-        # Capture the external cuda graph for the first step
-        if (
-            self.trainer.global_step == 0
-            and hasattr(self.model, 'config')
-            and getattr(self.model.config, 'external_cuda_graph', False)
-        ):
-            # disable the prehook
-            if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
-                self.model.disable_forward_pre_hook()
-                param_sync_func = self.model.config.param_sync_func
-                self.model.config.param_sync_func = None
-            import argparse
-
-            partial_cg_args = argparse.Namespace()
-            partial_cg_args.position_embedding_type = self.model.config.position_embedding_type
-            partial_cg_args.seq_length = self.trainer.datamodule.seq_length
-            partial_cg_args.micro_batch_size = self.trainer.datamodule.micro_batch_size
-            cuda_graph_capture(self.model, self.model.config, partial_cg_args)
-
-            # Set grad to zero.
-            for model_chunk in self.model:
-                model_chunk.zero_grad_buffer()
-            for opt in self.optimizers:
-                opt.zero_grad()
-
-            # Collect garbage and empty unused memory.
-            gc.collect()
-            torch.cuda.empty_cache()
-
         with self.precision_plugin.train_step_context():  # TODO: Do we need this?
             # Set grad to zero.
             for model_chunk in self.model:
@@ -775,18 +740,6 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP

             reduced_train_loss = out["loss"]

-            if (
-                self.trainer.global_step == 0
-                and hasattr(self.model, 'config')
-                and getattr(self.model.config, 'external_cuda_graph', False)
-            ):
-                if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
-                    # enable the prehook
-                    self.model.enable_forward_pre_hook()
-                    self.model.config.param_sync_func = param_sync_func
-                    param_sync_func = None
-                cuda_graph_set_manual_hooks(self.model)
-
             self.lightning_module.log(
                 "global_step",
                 self.trainer.global_step,
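The blocks removed above implemented first-step capture of an external ("partial") CUDA graph on a side stream, followed by manual hook setup. As a rough illustration of that pattern in plain PyTorch (a self-contained sketch, not the reverted Megatron/NeMo code; the toy model and shapes are assumptions):

import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.randn(8, 16, device="cuda")

    # Warm up on a side stream so capture is isolated from pending work on the
    # default stream (same idea as the removed
    # torch.cuda.set_stream(torch.cuda.Stream()) call in setup_distributed).
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            model(static_input)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture the forward pass once (analogous to the removed first-step
    # cuda_graph_capture call in training_step).
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # Later steps copy fresh data into the static input buffer and replay the graph;
    # static_output is updated in place.
    static_input.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    print(static_output.shape)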
