diff --git a/composer/callbacks/speed_monitor.py b/composer/callbacks/speed_monitor.py index 047f6e8251..a24463259b 100644 --- a/composer/callbacks/speed_monitor.py +++ b/composer/callbacks/speed_monitor.py @@ -258,6 +258,7 @@ def __init__( self.history_flops: deque[float] = deque(maxlen=window_size + 1) self.gpu_flops_available = gpu_flops_available + self.time_unit = time_unit self.divider = 1 if time_unit == 'seconds': @@ -360,10 +361,18 @@ def batch_end(self, state: State, logger: Logger): # Log the time # `state.timestamp` excludes any time spent in evaluation train_wct = state.timestamp.total_wct.total_seconds() + secs_per_step = 0 + if len(self.history_wct) > 1: + secs_per_step = self.history_wct[-1] - self.history_wct[-2] + logger.log_metrics({ 'time/train': train_wct / self.divider, 'time/val': self.total_eval_wct / self.divider, 'time/total': (train_wct + self.total_eval_wct) / self.divider, + f'time_{self.time_unit}/train': train_wct / self.divider, + f'time_{self.time_unit}/val': self.total_eval_wct / self.divider, + f'time_{self.time_unit}/total': (train_wct + self.total_eval_wct) / self.divider, + 'time/secs_per_step': secs_per_step, }) def eval_end(self, state: State, logger: Logger):