1414
1515import atexit
1616import functools
17- import gc
1817import inspect
1918import logging as _logging
2019import os
5655from megatron .core .dist_checkpointing .validation import StrictHandling
5756from megatron .core .distributed import DistributedDataParallelConfig
5857from megatron .core .optimizer import OptimizerConfig
59- from megatron .training .training import cuda_graph_capture , cuda_graph_set_manual_hooks
6058from torch import nn
6159from torch .distributed .algorithms .ddp_comm_hooks .debugging_hooks import noop_hook
6260from torch .distributed .checkpoint .utils import CheckpointException
@@ -574,10 +572,6 @@ def setup_distributed(self) -> None:
574572 """Setups dist env"""
575573 setup_parallel_ranks (self )
576574
577- # Capture the external cudagraph on a side stream
578- if hasattr (self .model , 'config' ) and getattr (self .model .config , 'external_cuda_graph' , False ):
579- torch .cuda .set_stream (torch .cuda .Stream ())
580-
581575 # Implementation from superclass copied below in order to pass the store to the process group init
582576 reset_seed ()
583577 self .set_world_ranks ()
@@ -726,35 +720,6 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
726720 assert self .lightning_module is not None
727721 assert isinstance (self .model , MegatronParallel )
728722
729- # Capture the external cuda graph for the first step
730- if (
731- self .trainer .global_step == 0
732- and hasattr (self .model , 'config' )
733- and getattr (self .model .config , 'external_cuda_graph' , False )
734- ):
735- # disable the prehook
736- if self .ddp_config .use_distributed_optimizer and self .ddp_config .overlap_param_gather :
737- self .model .disable_forward_pre_hook ()
738- param_sync_func = self .model .config .param_sync_func
739- self .model .config .param_sync_func = None
740- import argparse
741-
742- partial_cg_args = argparse .Namespace ()
743- partial_cg_args .position_embedding_type = self .model .config .position_embedding_type
744- partial_cg_args .seq_length = self .trainer .datamodule .seq_length
745- partial_cg_args .micro_batch_size = self .trainer .datamodule .micro_batch_size
746- cuda_graph_capture (self .model , self .model .config , partial_cg_args )
747-
748- # Set grad to zero.
749- for model_chunk in self .model :
750- model_chunk .zero_grad_buffer ()
751- for opt in self .optimizers :
752- opt .zero_grad ()
753-
754- # Collect garbage and empty unused memory.
755- gc .collect ()
756- torch .cuda .empty_cache ()
757-
758723 with self .precision_plugin .train_step_context (): # TODO: Do we need this?
759724 # Set grad to zero.
760725 for model_chunk in self .model :
@@ -775,18 +740,6 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
775740
776741 reduced_train_loss = out ["loss" ]
777742
778- if (
779- self .trainer .global_step == 0
780- and hasattr (self .model , 'config' )
781- and getattr (self .model .config , 'external_cuda_graph' , False )
782- ):
783- if self .ddp_config .use_distributed_optimizer and self .ddp_config .overlap_param_gather :
784- # enable the prehook
785- self .model .enable_forward_pre_hook ()
786- self .model .config .param_sync_func = param_sync_func
787- param_sync_func = None
788- cuda_graph_set_manual_hooks (self .model )
789-
790743 self .lightning_module .log (
791744 "global_step" ,
792745 self .trainer .global_step ,
0 commit comments