@@ -315,9 +315,7 @@ def _create_pure_fp16_program(self, is_infer_mode=False):
315315 def _create_forward_backward_train_program (self ):
316316 whole_program = self ._train_program
317317 # _, forward_end_op_index = self._infer_info('fp32', self._create_program)
318- forward_end_op_index = self ._forward_end_index_map [
319- _hash_with_id (whole_program , self )
320- ]
318+ forward_end_op_index = self .get_forward_end_op_idx (whole_program )
321319 assert forward_end_op_index >= 0
322320
323321 return self ._get_forward_backward_program_form (
@@ -438,11 +436,14 @@ def _infer_pure_fp16_program_id(self):
def _param_grad_names(self):
    """Look up the gradient names associated with ``self._params``.

    Delegates to the module-level ``_param_grad_names`` helper, handing it
    the desc of the train program together with ``self._params``.
    """
    # NOTE(review): assumes this def lives in a class body, so the bare name
    # below resolves to the module-level helper rather than to this method.
    train_desc = self._train_program.desc
    return _param_grad_names(train_desc, self._params)
440438
def get_forward_end_op_idx(self, program):
    """Return the forward-end op index stored for ``program``.

    The value is read from ``self._forward_end_index_map`` using the key
    ``_hash_with_id(program, self)``.

    Raises:
        KeyError: if no entry has been stored under that key.
    """
    map_key = _hash_with_id(program, self)
    return self._forward_end_index_map[map_key]
441+
@LazyInitialized
def _out_grad_names(self):
    """Compute the output-gradient names for the train program.

    Calls the module-level ``_out_grad_names`` helper with three arguments:
    the train program's desc, the forward-end op index stored for the
    train program, and the number of entries in ``self._outputs.var_ids``.

    The forward-end index comes from ``get_forward_end_op_idx`` rather than
    from building the inference program again only to count its ops.
    """
    # Bind once so the desc and the index lookup refer to the same program.
    train_program = self._train_program
    return _out_grad_names(
        train_program.desc,
        self.get_forward_end_op_idx(train_program),
        len(self._outputs.var_ids),
    )
448449
@@ -642,6 +643,7 @@ def _append_backward_desc(self, main_program):
642643 if isinstance (out , framework .Variable ):
643644 targets .append (program .global_block ().var (out .name ))
644645
646+ start_idx = len (program .block (0 ).ops ) + len (self ._outputs .tolist ())
645647 if targets :
646648 # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
647649 core .check_and_set_prim_all_enabled ()
@@ -652,12 +654,11 @@ def _append_backward_desc(self, main_program):
652654 program , start_idx = self ._hooker .after_append_backward (
653655 self , program , start_idx
654656 )
655- self ._forward_end_index_map [
656- _hash_with_id (program , self )
657- ] = start_idx - len (self ._outputs .tolist ())
658- # TODO: prim make this complicate
659657 self .prepare_gradient_aggregation (start_idx , main_program , program )
660658
659+ self ._forward_end_index_map [
660+ _hash_with_id (program , self )
661+ ] = start_idx - len (self ._outputs .tolist ())
661662 return program
662663
663664 def _prune_unused_params (self , program ):
@@ -1155,5 +1156,8 @@ def add_build_strategy_for(
11551156 if hasattr (compiled_program ._program , 'lr_sheduler' ):
11561157 builded_program .lr_sheduler = compiled_program ._program .lr_sheduler
11571158 else :
1159+ # A fresh, empty Program is not enough here: the variable descs of `program` must be copied into it as well.
11581160 builded_program = paddle .static .Program ()
1161+ for var in program .block (0 ).vars .values ():
1162+ builded_program .block (0 )._clone_variable (var , False )
11591163 return builded_program
0 commit comments