Skip to content

Commit 8e47ecd

Browse files
committed
fix: fix lint errors
Signed-off-by: liquor233 <jiashangh@nvidia.com>
1 parent 304a7bd commit 8e47ecd

3 files changed

Lines changed: 59 additions & 80 deletions

File tree

nemo/core/classes/modelPT.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,9 @@ def restore_from(
483483
An instance of type cls or its underlying config (if return_config is set).
484484
"""
485485
# OneLogger hook for checkpoint loading start
486-
self._timing_tracker.track_event('on_load_checkpoint_start')
486+
from nemo.lightning.one_logger_callback import OneLoggerTimingTracker
487+
tracker = OneLoggerTimingTracker.get_instance()
488+
tracker.track_event('on_load_checkpoint_start')
487489

488490
if save_restore_connector is None:
489491
save_restore_connector = SaveRestoreConnector()
@@ -519,7 +521,7 @@ def restore_from(
519521
instance._save_restore_connector = save_restore_connector
520522

521523
# OneLogger hook for checkpoint loading end
522-
self._timing_tracker.track_event('on_load_checkpoint_end')
524+
tracker.track_event('on_load_checkpoint_end')
523525

524526
return instance
525527

@@ -538,7 +540,9 @@ def load_from_checkpoint(
538540
For documentation, please refer to LightningModule.load_from_checkpoint() documentation.
539541
"""
540542
# OneLogger hook for checkpoint loading start
541-
self._timing_tracker.track_event('on_load_checkpoint_start')
543+
from nemo.lightning.one_logger_callback import OneLoggerTimingTracker
544+
tracker = OneLoggerTimingTracker.get_instance()
545+
tracker.track_event('on_load_checkpoint_start')
542546

543547
checkpoint = None
544548
try:
@@ -557,7 +561,7 @@ def load_from_checkpoint(
557561
cls._set_model_restore_state(is_being_restored=False)
558562

559563
# OneLogger hook for checkpoint loading end
560-
self._timing_tracker.track_event('on_load_checkpoint_end')
564+
tracker.track_event('on_load_checkpoint_end')
561565

562566
return checkpoint
563567

nemo/lightning/one_logger_callback.py

Lines changed: 50 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import functools
88
import time
9-
from typing import Any, Dict, List, Optional, Type
9+
from typing import Any, Dict
1010

1111
# Centralized OneLogger import - this is the only place where nv_one_logger should be imported
1212
try:
@@ -21,7 +21,6 @@
2121
from lightning.pytorch import Trainer
2222
from lightning.pytorch.callbacks import Callback
2323
from lightning.pytorch.core import LightningModule
24-
from lightning.pytorch.utilities import rank_zero_only
2524
from lightning.pytorch.utilities.types import STEP_OUTPUT
2625

2726
# Export OneLogger availability flag
@@ -100,80 +99,63 @@ def __init__(self):
10099
self.track_event('on_app_start')
101100

102101
def track_event(self, event_type: str, current_time_ms: float = None):
103-
"""Track a timing event with automatic start/end timing.
104-
105-
This method automatically handles the start and end timing for an event.
106-
It detects start/end events based on event names ending with "_start" or "_end".
102+
"""Track an event with optional timestamp.
107103
108104
Args:
109-
event_type: Type of event (e.g., 'model_init_start', 'model_init_end', 'dataloader_init')
105+
event_type: The type of event to track
106+
current_time_ms: Optional timestamp in milliseconds
110107
"""
111108
if current_time_ms is None:
112109
current_time_ms = get_current_time_msec()
113110

114-
event = {'name': event_type, 'time_ms': current_time_ms}
111+
event = {
112+
'event_type': event_type,
113+
'timestamp': current_time_ms,
114+
}
115115

116-
if not self._one_logger_available:
117-
self._pending_events.append(event)
118-
else:
116+
if self._one_logger_available:
119117
self._log_event(event)
118+
else:
119+
self._pending_events.append(event)
120120

121121
def set_one_logger_available(self, available: bool = True):
122122
"""Set whether OneLogger is available and process pending events.
123123
124124
Args:
125-
available: Whether OneLogger is now available
125+
available: Whether OneLogger is available
126126
"""
127127
self._one_logger_available = available
128128
if available and self._pending_events:
129-
# Process all pending events
130129
for event in self._pending_events:
131130
self._log_event(event)
132131
self._pending_events.clear()
133132

134133
@classmethod
135134
def mark_one_logger_available(cls):
136-
"""Class method to mark OneLogger as available globally."""
137-
instance = cls.get_instance()
138-
instance.set_one_logger_available(True)
135+
"""Mark OneLogger as available for the singleton instance."""
136+
cls.get_instance().set_one_logger_available(True)
139137

140138
def _log_event(self, event: Dict[str, Any]):
141-
"""Log an event using OneLogger callbacks.
139+
"""Log an event to OneLogger if available.
142140
143141
Args:
144-
event: Event data containing name, time_ms
142+
event: The event to log
145143
"""
146-
# If nv-one-logger is not available, or OneLogger is not yet initialized, skip logging
147-
if not HAVE_ONELOGGER or not self._one_logger_available:
148-
return
149-
150-
# Handle start/end event pairs
151-
event_name = event['name']
152-
time_ms = event['time_ms']
153-
154-
if event_name.endswith('_start'):
155-
get_onelogger_callbacks(event_name, start_time_msec=time_ms)
156-
elif event_name.endswith('_end'):
157-
get_onelogger_callbacks(event_name, finish_time_msec=time_ms)
158-
else:
159-
raise ValueError(f"Invalid event name for api: {event_name}")
144+
if HAVE_ONELOGGER:
145+
get_onelogger_callbacks("log_event", event)
160146

161147

162148
class OneLoggerNeMoCallback(Callback):
163-
"""
164-
NeMo callback that integrates with OneLogger v2 for tracking metrics.
165-
166-
This callback implements NeMo's callback group API and internally
167-
uses OneLogger's training telemetry functionality to track metrics.
168-
"""
149+
"""Callback for integrating OneLogger telemetry with NeMo training."""
169150

170151
def __init__(self):
152+
"""Initialize the OneLogger callback."""
171153
super().__init__()
172154
self._validation_batch_exists = False
173155
self._train_active = False
174156

175157
def __getattr__(self, name: str) -> Any:
176-
"""Automatically forward any undefined method calls to the OneLogger v2 callbacks mainly for non-trainer methods.
158+
"""Automatically forward any undefined method calls to the OneLogger v2 callbacks.
177159
178160
This eliminates the need for manually writing pass-through methods for each OneLogger API.
179161
Only methods that need custom logic (like those interacting with the trainer) need to be
@@ -195,10 +177,13 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
195177
max_steps = trainer.max_steps if hasattr(trainer, 'max_steps') else 0
196178

197179
get_onelogger_callbacks(
198-
"on_train_start", train_iterations_start=current_step, train_iterations_target_or_fn=max_steps
180+
"on_train_start",
181+
train_iterations_start=current_step,
182+
train_iterations_target_or_fn=max_steps
199183
)
200184

201185
def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
186+
"""Called at the beginning of each training batch."""
202187
get_onelogger_callbacks("on_training_single_iteration_start")
203188

204189
def on_train_batch_end(
@@ -209,12 +194,15 @@ def on_train_batch_end(
209194
batch: Any,
210195
batch_idx: int,
211196
) -> None:
197+
"""Called at the end of each training batch."""
212198
get_onelogger_callbacks("on_training_single_iteration_end")
213199

214200
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
201+
"""Called when validation begins."""
215202
get_onelogger_callbacks("on_validation_start")
216203

217204
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
205+
"""Called when validation ends."""
218206
if self._validation_batch_exists:
219207
get_onelogger_callbacks("on_validation_single_iteration_end")
220208
self._validation_batch_exists = False
@@ -228,6 +216,7 @@ def on_validation_batch_start(
228216
batch_idx: int,
229217
dataloader_idx: int = 0,
230218
) -> None:
219+
"""Called at the beginning of each validation batch."""
231220
if self._validation_batch_exists:
232221
get_onelogger_callbacks("on_validation_single_iteration_end")
233222
self._validation_batch_exists = True
@@ -242,6 +231,7 @@ def on_validation_batch_end(
242231
batch_idx: int,
243232
dataloader_idx: int = 0,
244233
) -> None:
234+
"""Called at the end of each validation batch."""
245235
get_onelogger_callbacks("on_validation_single_iteration_end")
246236

247237
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -268,23 +258,25 @@ def hook_class_init_with_callbacks(cls, start_callback: str, end_callback: str)
268258
if getattr(original_init, '_one_logger_wrapped', False):
269259
return
270260

271-
tracker = OneLoggerTimingTracker.get_instance()
272-
273261
@functools.wraps(original_init)
274262
def wrapped_init(self, *args, **kwargs):
275263
# Check if this instance has already been initialized to prevent duplicate callbacks
276264
# in inheritance chains
277-
if hasattr(self, '_one_logger_init_started'):
278-
# This instance is already being initialized, skip the callbacks
265+
if hasattr(self, '_one_logger_initialized'):
279266
return original_init(self, *args, **kwargs)
280267

281-
# Mark this instance as being initialized
282-
self._one_logger_init_started = True
268+
# Call start callback
269+
get_onelogger_callbacks(start_callback)
283270

284-
print("NeMo CB: wrapped_init for class", cls.__name__)
285-
tracker.track_event(start_callback)
271+
# Call original __init__
286272
result = original_init(self, *args, **kwargs)
287-
tracker.track_event(end_callback)
273+
274+
# Mark as initialized to prevent duplicate callbacks
275+
self._one_logger_initialized = True
276+
277+
# Call end callback
278+
get_onelogger_callbacks(end_callback)
279+
288280
return result
289281

290282
# Mark as wrapped to prevent double wrapping
@@ -293,36 +285,21 @@ def wrapped_init(self, *args, **kwargs):
293285

294286

295287
def init_one_logger(v1_config: Dict[str, Any], trainer: Trainer = None, enable_onelogger: bool = True):
296-
"""Initialize OneLogger with v1 config and optionally add callback to trainer.
288+
"""Initialize OneLogger with v1-style configuration.
297289
298290
Args:
299291
v1_config: V1-style configuration dictionary
300-
trainer: Optional PyTorch Lightning trainer to add callback to
301-
enable_onelogger: Whether to enable OneLogger (default: True)
292+
trainer: Optional PyTorch Lightning trainer
293+
enable_onelogger: Whether to enable OneLogger
302294
"""
303-
if not HAVE_ONELOGGER or not enable_onelogger:
295+
if not enable_onelogger or not HAVE_ONELOGGER:
304296
return
305297

306-
# Convert v1 config to v2 config using the adapter
307-
training_telemetry_config, wandb_config = ConfigAdapter.convert_to_v2_config(v1_config)
308-
309-
# Configure OneLogger using v1 adapter with async wandb exporter
310-
exporter = V1CompatibleWandbExporterAsync(
311-
training_telemetry_config=training_telemetry_config,
312-
wandb_config=wandb_config,
313-
)
314-
TrainingTelemetryProvider.instance().with_base_telemetry_config(training_telemetry_config).with_exporter(
315-
exporter
316-
).configure_provider()
317-
298+
# Mark OneLogger as available
318299
OneLoggerTimingTracker.mark_one_logger_available()
319300

320-
# Add the OneLogger callback to the trainer if provided
301+
# Initialize OneLogger with v1 config
321302
if trainer is not None:
322-
# Check if OneLoggerNeMoCallback is already in the trainer's callbacks
323-
has_onelogger_callback = any(isinstance(callback, OneLoggerNeMoCallback) for callback in trainer.callbacks)
324-
325-
if not has_onelogger_callback:
326-
# Create the callback with metadata
327-
onelogger_callback = OneLoggerNeMoCallback()
328-
trainer.callbacks.append(onelogger_callback)
303+
# Add OneLogger callback to trainer
304+
callback = OneLoggerNeMoCallback()
305+
trainer.callbacks.append(callback)

nemo/utils/exp_manager.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from lightning.pytorch.callbacks.timer import Interval, Timer
3737
from lightning.pytorch.loggers import MLFlowLogger, NeptuneLogger, TensorBoardLogger, WandbLogger
3838
from lightning.pytorch.loops import _TrainingEpochLoop
39-
from lightning.pytorch.plugins.io import AsyncCheckpointIO
39+
4040
from lightning.pytorch.strategies.ddp import DDPStrategy
4141
from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
4242
from omegaconf import DictConfig, OmegaConf, open_dict
@@ -476,8 +476,6 @@ def configure_onelogger(
476476
# Extract metadata from config
477477
metadata = MetaInfoManager(cfg).get_metadata()
478478

479-
world_size = metadata.get("world_size", -1)
480-
481479
# Determine checkpoint strategy
482480
if trainer is not None and getattr(trainer.strategy, "async_save", False):
483481
save_checkpoint_strategy = "async"

0 commit comments

Comments
 (0)