Skip to content

Commit 9b4c363

Browse files
committed
[BUG] fixed memory leak in BaseModel by detach some tensor
1 parent 81b5303 commit 9b4c363

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

pytorch_forecasting/models/base/_base_model.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@
5454
apply_to_list,
5555
concat_sequences,
5656
create_mask,
57+
detach,
5758
get_embedding_size,
5859
groupby_apply,
60+
move_to_device,
5961
to_list,
6062
)
6163
from pytorch_forecasting.utils._classproperty import classproperty
@@ -308,6 +310,8 @@ def on_predict_batch_end(
308310
else:
309311
raise ValueError(f"Unknown mode {self.mode} - see docs for valid arguments")
310312

313+
out = move_to_device(detach(out), "cpu")
314+
x = move_to_device(detach(x), "cpu")
311315
self._output.append(out)
312316
out = dict(output=out)
313317
if self.return_x:
@@ -720,7 +724,7 @@ def training_step(self, batch, batch_idx):
720724
"""
721725
x, y = batch
722726
log, out = self.step(x, y, batch_idx)
723-
self.training_step_outputs.append(log)
727+
self.training_step_outputs.append(detach(log))
724728
return log
725729

726730
def on_train_epoch_end(self):
@@ -739,7 +743,7 @@ def validation_step(self, batch, batch_idx):
739743
x, y = batch
740744
log, out = self.step(x, y, batch_idx)
741745
log.update(self.create_log(x, y, out, batch_idx))
742-
self.validation_step_outputs.append(log)
746+
self.validation_step_outputs.append(detach(log))
743747
return log
744748

745749
def on_validation_epoch_end(self):
@@ -750,7 +754,7 @@ def test_step(self, batch, batch_idx):
750754
x, y = batch
751755
log, out = self.step(x, y, batch_idx)
752756
log.update(self.create_log(x, y, out, batch_idx))
753-
self.testing_step_outputs.append(log)
757+
self.testing_step_outputs.append(detach(log))
754758
return log
755759

756760
def on_test_epoch_end(self):
@@ -934,7 +938,7 @@ def step(
934938
loss.requires_grad_(True)
935939
self.log(
936940
f"{self.current_stage}_loss",
937-
loss,
941+
detach(loss),
938942
on_step=self.training,
939943
on_epoch=True,
940944
prog_bar=True,

0 commit comments

Comments
 (0)