@@ -102,11 +102,11 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
102102 y_pred = model (x )
103103 loss = loss_fn (y_pred , y )
104104 if gradient_accumulation_steps > 1 :
105- loss = loss / gradient_accumulation_steps
105+ loss = loss / gradient_accumulation_steps # scale down so accumulated gradients average over the window
106106 loss .backward ()
107107 if engine .state .iteration % gradient_accumulation_steps == 0 :
108108 optimizer .step ()
109- return output_transform (x , y , y_pred , loss )
109+ return output_transform (x , y , y_pred , loss * gradient_accumulation_steps )
110110
111111 return update
112112
@@ -192,7 +192,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
192192 loss .backward ()
193193 if engine .state .iteration % gradient_accumulation_steps == 0 :
194194 optimizer .step ()
195- return output_transform (x , y , y_pred , loss )
195+ return output_transform (x , y , y_pred , loss * gradient_accumulation_steps )
196196
197197 return update
198198
@@ -269,7 +269,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
269269 scaled_loss .backward ()
270270 if engine .state .iteration % gradient_accumulation_steps == 0 :
271271 optimizer .step ()
272- return output_transform (x , y , y_pred , loss )
272+ return output_transform (x , y , y_pred , loss * gradient_accumulation_steps )
273273
274274 return update
275275
@@ -340,7 +340,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
340340 y_pred = model (x )
341341 loss = loss_fn (y_pred , y )
342342 if gradient_accumulation_steps > 1 :
343- loss = loss / gradient_accumulation_steps
343+ loss = loss / gradient_accumulation_steps # scale down so accumulated gradients average over the window
344344 loss .backward ()
345345 if engine .state .iteration % gradient_accumulation_steps == 0 :
346346 xm .optimizer_step (optimizer , barrier = True )