5454 apply_to_list ,
5555 concat_sequences ,
5656 create_mask ,
57+ detach ,
5758 get_embedding_size ,
5859 groupby_apply ,
60+ move_to_device ,
5961 to_list ,
6062)
6163from pytorch_forecasting .utils ._classproperty import classproperty
@@ -308,6 +310,8 @@ def on_predict_batch_end(
308310 else :
309311 raise ValueError (f"Unknown mode { self .mode } - see docs for valid arguments" )
310312
313+ out = move_to_device (detach (out ), "cpu" )
314+ x = move_to_device (detach (x ), "cpu" )
311315 self ._output .append (out )
312316 out = dict (output = out )
313317 if self .return_x :
@@ -720,7 +724,7 @@ def training_step(self, batch, batch_idx):
720724 """
721725 x , y = batch
722726 log , out = self .step (x , y , batch_idx )
723- self .training_step_outputs .append (log )
727+ self .training_step_outputs .append (detach ( log ) )
724728 return log
725729
726730 def on_train_epoch_end (self ):
@@ -739,7 +743,7 @@ def validation_step(self, batch, batch_idx):
739743 x , y = batch
740744 log , out = self .step (x , y , batch_idx )
741745 log .update (self .create_log (x , y , out , batch_idx ))
742- self .validation_step_outputs .append (log )
746+ self .validation_step_outputs .append (detach ( log ) )
743747 return log
744748
745749 def on_validation_epoch_end (self ):
@@ -750,7 +754,7 @@ def test_step(self, batch, batch_idx):
750754 x , y = batch
751755 log , out = self .step (x , y , batch_idx )
752756 log .update (self .create_log (x , y , out , batch_idx ))
753- self .testing_step_outputs .append (log )
757+ self .testing_step_outputs .append (detach ( log ) )
754758 return log
755759
756760 def on_test_epoch_end (self ):
@@ -934,7 +938,7 @@ def step(
934938 loss .requires_grad_ (True )
935939 self .log (
936940 f"{ self .current_stage } _loss" ,
937- loss ,
941+ detach ( loss ) ,
938942 on_step = self .training ,
939943 on_epoch = True ,
940944 prog_bar = True ,
0 commit comments