@@ -720,7 +720,10 @@ def training_step(self, batch, batch_idx):
720720 """
721721 x , y = batch
722722 log , out = self .step (x , y , batch_idx )
723- self .training_step_outputs .append (log )
723+ detached_log = {
724+ k : v .detach () if isinstance (v , torch .Tensor ) else v for k , v in log .items ()
725+ }
726+ self .training_step_outputs .append (detached_log )
724727 return log
725728
726729 def on_train_epoch_end (self ):
@@ -739,7 +742,10 @@ def validation_step(self, batch, batch_idx):
739742 x , y = batch
740743 log , out = self .step (x , y , batch_idx )
741744 log .update (self .create_log (x , y , out , batch_idx ))
742- self .validation_step_outputs .append (log )
745+ detached_log = {
746+ k : v .detach () if isinstance (v , torch .Tensor ) else v for k , v in log .items ()
747+ }
748+ self .validation_step_outputs .append (detached_log )
743749 return log
744750
745751 def on_validation_epoch_end (self ):
@@ -750,7 +756,10 @@ def test_step(self, batch, batch_idx):
750756 x , y = batch
751757 log , out = self .step (x , y , batch_idx )
752758 log .update (self .create_log (x , y , out , batch_idx ))
753- self .testing_step_outputs .append (log )
759+ detached_log = {
760+ k : v .detach () if isinstance (v , torch .Tensor ) else v for k , v in log .items ()
761+ }
762+ self .testing_step_outputs .append (detached_log )
754763 return log
755764
756765 def on_test_epoch_end (self ):
@@ -934,7 +943,7 @@ def step(
934943 loss .requires_grad_ (True )
935944 self .log (
936945 f"{ self .current_stage } _loss" ,
937- loss ,
946+ loss . detach () ,
938947 on_step = self .training ,
939948 on_epoch = True ,
940949 prog_bar = True ,
0 commit comments