@@ -453,6 +453,12 @@ def _run(self, inputs, labels=None):
             if len(name) > 0:
                 rets.insert(i, feed[name])
 
+        # step learning rate scheduler on each batch end
+        if self.model._optimizer and \
+                isinstance(self.model._optimizer._learning_rate,
+                           paddle.optimizer.lr.LRScheduler):
+            self.model._optimizer._learning_rate.step()
+
         # LoDTensor cannot be fetch as numpy directly
         rets = [np.array(v) for v in rets]
         if self.mode == 'test':
@@ -652,6 +658,13 @@ def train_batch(self, inputs, labels=None):
 
         self.model._optimizer.minimize(final_loss)
         self.model.network.clear_gradients()
+
+        # step learning rate scheduler on each batch end
+        if self.model._optimizer and \
+                isinstance(self.model._optimizer._learning_rate,
+                           paddle.optimizer.lr.LRScheduler):
+            self.model._optimizer._learning_rate.step()
+
         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.compute(*(to_list(outputs) + labels))
@@ -1461,11 +1474,6 @@ def fit(
 
             cbks.on_end('eval', eval_logs)
 
-            # step learning rate scheduler on each epcoh end
-            if isinstance(self._optimizer._learning_rate,
-                          paddle.optimizer.lr.LRScheduler):
-                self._optimizer._learning_rate.step()
-
         cbks.on_end('train', logs)
         self._test_dataloader = None
 
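Taken together, these hunks move LRScheduler stepping out of the epoch loop in fit() and into the per-batch paths of both adapters: the static-graph _run() and the dynamic-graph train_batch(). A minimal sketch of the resulting behavior, assuming the public paddle.Model API; the LeNet network and MNIST dataset are illustrative placeholders, and StepDecay stands in for any paddle.optimizer.lr.LRScheduler subclass:

import paddle
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensor

# Illustrative network; any paddle.nn.Layer wrapped in paddle.Model
# behaves the same way here.
model = paddle.Model(paddle.vision.models.LeNet())

# StepDecay is one LRScheduler subclass; with this change, step_size
# is effectively measured in batches rather than epochs.
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1,
                                          step_size=1000,
                                          gamma=0.5)
optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                 parameters=model.parameters())

model.prepare(optimizer=optimizer,
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())

# scheduler.step() now runs once per batch inside fit(), so the learning
# rate halves every 1000 batches instead of every 1000 epochs.
model.fit(MNIST(mode='train', transform=ToTensor()),
          epochs=2,
          batch_size=64,
          verbose=1)

A side effect of stepping inside the adapters rather than in fit() is that the scheduler would also advance when train_batch() is called directly, not only through the fit() loop.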