@@ -123,7 +123,7 @@ def _create_accumulators(self, block, parameters):
         """
         pass
 
-    def _finish_update(self, block, parameters):
+    def _finish_update(self, block, parameters_and_grads):
         """Finish any custom updates needed
         before completing an optimization step
 
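The hunk above changes the `_finish_update` hook to receive the full `(param, grad)` pairs instead of bare parameters. A minimal sketch of the contract a subclass now implements, in plain Python with no Paddle imports (the helper name is a hypothetical stand-in, not Paddle API):

# Illustrative sketch only: the argument shape _finish_update now
# receives, and the convention of skipping pairs with no gradient.
def _finish_update_sketch(parameters_and_grads):
    finished = []
    for param, grad in parameters_and_grads:
        if grad is None:          # parameters with no gradient are skipped
            continue
        finished.append(param)    # stand-in for appending finishing ops
    return finished

assert _finish_update_sketch([("w", "dw"), ("frozen", None)]) == ["w"]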
@@ -226,18 +226,18 @@ def _create_optimization_pass(self,
 
         optimize_ops = []
         for param_and_grad in parameters_and_grads:
+            if param_and_grad[1] is None:
+                continue
             with param_and_grad[0].block.program.optimized_guard(
-                    param_and_grad[0]):
-                if param_and_grad[0].trainable is True and param_and_grad[
-                        1] is not None:
+                    param_and_grad):
+                if param_and_grad[0].trainable is True:
                     optimize_op = self._append_optimize_op(loss.block,
                                                            param_and_grad)
                     optimize_ops.append(optimize_op)
 
         # Get custom finish ops for subclasses
         # FIXME: Need to fix this once we figure out how to handle dependencies
-        self._finish_update(loss.block,
-                            [p[0] for p in parameters_and_grads])
+        self._finish_update(loss.block, parameters_and_grads)
 
         end = len(global_block.ops)
         return global_block.slice_ops(start, end)
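A gradient slot can be `None` when a parameter takes no part in the loss (for example, it is excluded from the backward pass), so the optimization pass now filters those pairs before entering `optimized_guard` or appending any optimize op. A plain-Python sketch of that filtering, illustrative only:

# Pairs whose gradient is None never reach optimized_guard or
# _append_optimize_op; only the remaining parameters get update ops.
params_and_grads = [("w1", "dw1"), ("w2", None), ("w3", "dw3")]
updated = [p for p, g in params_and_grads if g is not None]
assert updated == ["w1", "w3"]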
@@ -564,13 +564,15 @@ def _append_optimize_op(self, block, param_and_grad):
 
         return adam_op
 
-    def _finish_update(self, block, parameters):
+    def _finish_update(self, block, param_and_grads):
         """Update Beta1 and Beta2 Power accumulators
         """
         assert isinstance(block, framework.Block)
         main_block = block.program.global_block()
-        for param in parameters:
-            with param.block.program.optimized_guard(param):
+        for param, grad in param_and_grads:
+            if grad is None:
+                continue
+            with param.block.program.optimized_guard([param, grad]):
                 beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                       param)
                 beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
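The ops appended under the guard maintain the running powers beta1^t and beta2^t that Adam's bias correction divides by. A toy sketch of the quantity each accumulator tracks (plain Python, not Paddle ops):

# Toy sketch (not Paddle ops): the values held by the beta1/beta2 power
# accumulators after t steps, and how Adam's bias correction uses them.
beta1, beta2 = 0.9, 0.999
beta1_pow = beta2_pow = 1.0
for t in range(1, 4):
    beta1_pow *= beta1            # == beta1 ** t after step t
    beta2_pow *= beta2            # == beta2 ** t after step t
    bias_fix1 = 1.0 - beta1_pow   # corrects the first-moment estimate
    bias_fix2 = 1.0 - beta2_pow   # corrects the second-moment estimate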
@@ -691,13 +693,15 @@ def _append_optimize_op(self, block, param_and_grad):
 
         return adamax_op
 
-    def _finish_update(self, block, parameters):
+    def _finish_update(self, block, parameters_and_grads):
         """Update Beta1 Power accumulator
         """
         assert isinstance(block, framework.Block)
         main_block = block.program.global_block()
-        for param in parameters:
-            with param.block.program.optimized_guard(param):
+        for param, grad in parameters_and_grads:
+            if grad is None:
+                continue
+            with param.block.program.optimized_guard([param, grad]):
                 beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                       param)
                 main_block.append_op(
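Adamax tracks only the beta1 power: its second moment is an infinity norm, which needs no bias correction (Kingma & Ba, 2015). A toy sketch of the update, plain Python and illustrative only:

# Toy sketch of the Adamax moments (not Paddle ops): only beta1_pow
# needs an accumulator; the max-based u_t is never bias-corrected.
beta1, beta2, lr, eps = 0.9, 0.999, 0.002, 1e-8
m = u = 0.0
beta1_pow = 1.0
for g in (0.5, -0.2, 0.1):            # toy gradient sequence
    m = beta1 * m + (1 - beta1) * g   # biased first moment
    u = max(beta2 * u, abs(g))        # infinity-norm second moment
    beta1_pow *= beta1                # maintained by _finish_update
    step = lr / (1.0 - beta1_pow) * m / (u + eps)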
@@ -1158,7 +1162,9 @@ def __init__(self,
             self.params_grads.append((param, grad))
 
         for param, grad in self.params_grads:
-            with param.block.program.optimized_guard(param):
+            if grad is None:
+                continue
+            with param.block.program.optimized_guard([param, grad]):
                 self._append_average_accumulate_op(param)
 
         self.apply_program = Program()
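The same skip-and-guard pattern is applied in `ModelAverage.__init__` before the average-accumulate op. Conceptually that op maintains running sums of each parameter's values so `apply()` can later swap in the mean; the sketch below shows only this general idea in plain Python (the real op keeps several windowed partial sums, not a single total):

# Rough, illustrative sketch of what an average accumulator tracks
# per parameter across training steps.
param_snapshots = [1.0, 1.2, 0.8]     # toy parameter values over steps
param_sum, num_accumulates = 0.0, 0
for value in param_snapshots:
    param_sum += value
    num_accumulates += 1
averaged_param = param_sum / num_accumulates   # used when apply() runs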