Skip to content

Commit 9c9c0ff

Browse files
committed
FIX: accumulation should not affect value of loss
1 parent 98bc785 commit 9c9c0ff

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

ignite/engine/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
102102
y_pred = model(x)
103103
loss = loss_fn(y_pred, y)
104104
if gradient_accumulation_steps > 1:
105-
loss = loss / gradient_accumulation_steps
105+
loss = loss / gradient_accumulation_steps  # scale so accumulated gradients average over the window
106106
loss.backward()
107107
if engine.state.iteration % gradient_accumulation_steps == 0:
108108
optimizer.step()
109-
return output_transform(x, y, y_pred, loss)
109+
return output_transform(x, y, y_pred, loss * gradient_accumulation_steps)
110110

111111
return update
112112

@@ -192,7 +192,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
192192
loss.backward()
193193
if engine.state.iteration % gradient_accumulation_steps == 0:
194194
optimizer.step()
195-
return output_transform(x, y, y_pred, loss)
195+
return output_transform(x, y, y_pred, loss * gradient_accumulation_steps)
196196

197197
return update
198198

@@ -269,7 +269,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
269269
scaled_loss.backward()
270270
if engine.state.iteration % gradient_accumulation_steps == 0:
271271
optimizer.step()
272-
return output_transform(x, y, y_pred, loss)
272+
return output_transform(x, y, y_pred, loss * gradient_accumulation_steps)
273273

274274
return update
275275

@@ -340,7 +340,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
340340
y_pred = model(x)
341341
loss = loss_fn(y_pred, y)
342342
if gradient_accumulation_steps > 1:
343-
loss = loss / gradient_accumulation_steps
343+
loss = loss / gradient_accumulation_steps  # scale so accumulated gradients average over the window
344344
loss.backward()
345345
if engine.state.iteration % gradient_accumulation_steps == 0:
346346
xm.optimizer_step(optimizer, barrier=True)

tests/ignite/engine/test_create_supervised.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def _():
9797
_x, _y = trainer.state.batch
9898
_x, _y = _x.to(model_device), _y.to(model_device)
9999
accumulation[0] += 0.2 * _x.item() * (theta[0] * _x.item() - _y.item())
100-
# loss is not accumulated !
101-
loss[0] = mse_loss(model(_x), _y).item() / gradient_accumulation_steps
100+
# value of loss should not be accumulated
101+
loss[0] = mse_loss(model(_x), _y).item()
102102

103103
@trainer.on(Events.ITERATION_COMPLETED(every=gradient_accumulation_steps))
104104
def _():

0 commit comments

Comments
 (0)