Skip to content

Commit f1d90ca

Browse files
committed
[BUG] fixed memory leak in BaseModel by detach some tensor
1 parent 2c1ed8d commit f1d90ca

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

pytorch_forecasting/models/base/_base_model.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,10 @@ def training_step(self, batch, batch_idx):
720720
"""
721721
x, y = batch
722722
log, out = self.step(x, y, batch_idx)
723-
self.training_step_outputs.append(log)
723+
detached_log = {
724+
k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in log.items()
725+
}
726+
self.training_step_outputs.append(detached_log)
724727
return log
725728

726729
def on_train_epoch_end(self):
@@ -739,7 +742,10 @@ def validation_step(self, batch, batch_idx):
739742
x, y = batch
740743
log, out = self.step(x, y, batch_idx)
741744
log.update(self.create_log(x, y, out, batch_idx))
742-
self.validation_step_outputs.append(log)
745+
detached_log = {
746+
k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in log.items()
747+
}
748+
self.validation_step_outputs.append(detached_log)
743749
return log
744750

745751
def on_validation_epoch_end(self):
@@ -750,7 +756,10 @@ def test_step(self, batch, batch_idx):
750756
x, y = batch
751757
log, out = self.step(x, y, batch_idx)
752758
log.update(self.create_log(x, y, out, batch_idx))
753-
self.testing_step_outputs.append(log)
759+
detached_log = {
760+
k: v.detach() if isinstance(v, torch.Tensor) else v for k, v in log.items()
761+
}
762+
self.testing_step_outputs.append(detached_log)
754763
return log
755764

756765
def on_test_epoch_end(self):
@@ -934,7 +943,7 @@ def step(
934943
loss.requires_grad_(True)
935944
self.log(
936945
f"{self.current_stage}_loss",
937-
loss,
946+
loss.detach(),
938947
on_step=self.training,
939948
on_epoch=True,
940949
prog_bar=True,

0 commit comments

Comments
 (0)