66
77import functools
88import time
9- from typing import Any , Dict , List , Optional , Type
9+ from typing import Any , Dict
1010
1111# Centralized OneLogger import - this is the only place where nv_one_logger should be imported
1212try :
2121from lightning .pytorch import Trainer
2222from lightning .pytorch .callbacks import Callback
2323from lightning .pytorch .core import LightningModule
24- from lightning .pytorch .utilities import rank_zero_only
2524from lightning .pytorch .utilities .types import STEP_OUTPUT
2625
2726# Export OneLogger availability flag
@@ -100,80 +99,63 @@ def __init__(self):
10099 self .track_event ('on_app_start' )
101100
def track_event(self, event_type: str, current_time_ms: float = None):
    """Record an event, logging it immediately or queuing it for later.

    Args:
        event_type: Identifier of the event being recorded.
        current_time_ms: Timestamp in milliseconds. When omitted, the value
            returned by ``get_current_time_msec()`` is used.
    """
    timestamp = get_current_time_msec() if current_time_ms is None else current_time_ms
    record = {'event_type': event_type, 'timestamp': timestamp}

    if not self._one_logger_available:
        # Keep the event in the pending queue while OneLogger is unavailable.
        self._pending_events.append(record)
        return

    self._log_event(record)
120120
def set_one_logger_available(self, available: bool = True):
    """Store the OneLogger availability flag and replay queued events.

    Args:
        available: New availability state. When true and events are
            waiting in the pending queue, each one is logged in order and
            the queue is then emptied.
    """
    self._one_logger_available = available

    # Nothing to replay when OneLogger is unavailable or the queue is empty.
    if not (available and self._pending_events):
        return

    for queued in self._pending_events:
        self._log_event(queued)
    self._pending_events.clear()
133132
@classmethod
def mark_one_logger_available(cls):
    """Flag OneLogger as available on the instance returned by ``get_instance()``."""
    tracker = cls.get_instance()
    tracker.set_one_logger_available(True)
139137
def _log_event(self, event: Dict[str, Any]):
    """Forward an event to the OneLogger callbacks when the package is present.

    Args:
        event: Event payload passed to the ``"log_event"`` callback.
    """
    if not HAVE_ONELOGGER:
        # Without the OneLogger package there is nothing to forward to.
        return
    get_onelogger_callbacks("log_event", event)
160146
161147
162148class OneLoggerNeMoCallback (Callback ):
163- """
164- NeMo callback that integrates with OneLogger v2 for tracking metrics.
165-
166- This callback implements NeMo's callback group API and internally
167- uses OneLogger's training telemetry functionality to track metrics.
168- """
149+ """Callback for integrating OneLogger telemetry with NeMo training."""
169150
def __init__(self):
    """Create the callback with its training and validation flags cleared."""
    super().__init__()
    # Both flags start out False.
    self._train_active = False
    self._validation_batch_exists = False
174156
175157 def __getattr__ (self , name : str ) -> Any :
176- """Automatically forward any undefined method calls to the OneLogger v2 callbacks mainly for non-trainer methods .
158+ """Automatically forward any undefined method calls to the OneLogger v2 callbacks.
177159
178160 This eliminates the need for manually writing pass-through methods for each OneLogger API.
179161 Only methods that need custom logic (like those interacting with the trainer) need to be
@@ -195,10 +177,13 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
195177 max_steps = trainer .max_steps if hasattr (trainer , 'max_steps' ) else 0
196178
197179 get_onelogger_callbacks (
198- "on_train_start" , train_iterations_start = current_step , train_iterations_target_or_fn = max_steps
180+ "on_train_start" ,
181+ train_iterations_start = current_step ,
182+ train_iterations_target_or_fn = max_steps
199183 )
200184
def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
    """Invoke the ``on_training_single_iteration_start`` OneLogger callback.

    The trainer, module, batch and batch index arguments are accepted but
    not used by this method.
    """
    get_onelogger_callbacks("on_training_single_iteration_start")
203188
204189 def on_train_batch_end (
@@ -209,12 +194,15 @@ def on_train_batch_end(
209194 batch : Any ,
210195 batch_idx : int ,
211196 ) -> None :
197+ """Called at the end of each training batch."""
212198 get_onelogger_callbacks ("on_training_single_iteration_end" )
213199
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Invoke the ``on_validation_start`` OneLogger callback.

    The trainer and module arguments are accepted but not used by this method.
    """
    get_onelogger_callbacks("on_validation_start")
216203
217204 def on_validation_end (self , trainer : Trainer , pl_module : LightningModule ) -> None :
205+ """Called when validation ends."""
218206 if self ._validation_batch_exists :
219207 get_onelogger_callbacks ("on_validation_single_iteration_end" )
220208 self ._validation_batch_exists = False
@@ -228,6 +216,7 @@ def on_validation_batch_start(
228216 batch_idx : int ,
229217 dataloader_idx : int = 0 ,
230218 ) -> None :
219+ """Called at the beginning of each validation batch."""
231220 if self ._validation_batch_exists :
232221 get_onelogger_callbacks ("on_validation_single_iteration_end" )
233222 self ._validation_batch_exists = True
@@ -242,6 +231,7 @@ def on_validation_batch_end(
242231 batch_idx : int ,
243232 dataloader_idx : int = 0 ,
244233 ) -> None :
234+ """Called at the end of each validation batch."""
245235 get_onelogger_callbacks ("on_validation_single_iteration_end" )
246236
247237 def on_train_end (self , trainer : Trainer , pl_module : LightningModule ) -> None :
@@ -268,23 +258,25 @@ def hook_class_init_with_callbacks(cls, start_callback: str, end_callback: str)
268258 if getattr (original_init , '_one_logger_wrapped' , False ):
269259 return
270260
271- tracker = OneLoggerTimingTracker .get_instance ()
272-
@functools.wraps(original_init)
def wrapped_init(self, *args, **kwargs):
    """Run the original ``__init__`` bracketed by the start and end callbacks.

    Only the outermost wrapped ``__init__`` call for a given instance fires
    the callbacks. Nested calls (for example a hooked parent class reached
    through ``super().__init__()``) delegate straight to the original.

    Args:
        *args: Positional arguments forwarded to the original ``__init__``.
        **kwargs: Keyword arguments forwarded to the original ``__init__``.

    Returns:
        Whatever the original ``__init__`` returns.
    """
    # A marker on the instance means an outer wrapped __init__ is already
    # handling the callbacks for this object, so just delegate.
    if hasattr(self, '_one_logger_initialized'):
        return original_init(self, *args, **kwargs)

    # Set the marker BEFORE delegating. If it were set only after
    # original_init returned, a hooked parent __init__ invoked during
    # original_init would not see it and would fire the callbacks again.
    self._one_logger_initialized = True

    get_onelogger_callbacks(start_callback)
    result = original_init(self, *args, **kwargs)
    get_onelogger_callbacks(end_callback)

    return result
289281
290282 # Mark as wrapped to prevent double wrapping
@@ -293,36 +285,21 @@ def wrapped_init(self, *args, **kwargs):
293285
294286
def init_one_logger(v1_config: Dict[str, Any], trainer: Trainer = None, enable_onelogger: bool = True):
    """Mark OneLogger as available and attach the NeMo callback to a trainer.

    Args:
        v1_config: V1-style configuration dictionary.
            NOTE(review): this function does not read ``v1_config`` itself;
            confirm where the OneLogger provider is actually configured.
        trainer: Optional PyTorch Lightning trainer. When given, an
            ``OneLoggerNeMoCallback`` is appended to ``trainer.callbacks``
            unless one is already present.
        enable_onelogger: When false, the function returns without doing
            anything.
    """
    if not enable_onelogger or not HAVE_ONELOGGER:
        return

    # Mark OneLogger as available on the timing tracker.
    OneLoggerTimingTracker.mark_one_logger_available()

    if trainer is None:
        return

    # Stay idempotent: calling this twice, or passing a trainer that already
    # carries the callback, must not register a second OneLoggerNeMoCallback.
    if any(isinstance(existing, OneLoggerNeMoCallback) for existing in trainer.callbacks):
        return

    trainer.callbacks.append(OneLoggerNeMoCallback())
0 commit comments