@@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
231231 while_grad->SetInput (kStepScopes , Output (kStepScopes ));
232232
233233 auto *grad_block = this ->grad_block_ [0 ];
234- auto *fwd_block = grad_block->ParentBlock ();
234+ auto *fwd_block = grad_block->ForwardBlock ();
235+ auto *parent_block = grad_block->ParentBlock ();
235236
236237 // Not all of IGs will be generated by inner gradient operators of while op.
237238 // Ignore IGs that is not generated by the inside block.
@@ -260,33 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
260261 for (auto &o : Output (kOutputs )) {
261262 block_ins.insert (o);
262263 }
263- std::unordered_set<std::string> extra_inputs ;
264+ std::unordered_set<std::string> output_grads ;
264265 for (const auto *op : grad_block->AllOps ()) {
265266 for (auto &input_name : op->InputArgumentNames ()) {
266267 // If the input of Op has been recorded or is generated by the forward
267268 // block, do not make it as input again.
269+
270+ // The input is located in I/O or other op's outputs or the variable is
271+ // located in grad_block's parents
268272 if (block_ins.find (input_name) != block_ins.end () ||
269- fwd_block->FindVar (input_name) != nullptr ) {
273+ (fwd_block->FindVarRecursive (input_name) != nullptr ||
274+ parent_block->FindVarRecursive (input_name) != nullptr )) {
270275 continue ;
271276 }
272- extra_inputs .insert (input_name);
277+ output_grads .insert (input_name);
273278 }
274279 for (auto &output_name : op->OutputArgumentNames ()) {
275280 block_ins.insert (output_name);
276281 }
277282 }
278283
279- std::vector<std::string> extra_inputs_list ;
280- extra_inputs_list .resize (extra_inputs .size ());
281- std::copy (extra_inputs .begin (), extra_inputs .end (),
282- extra_inputs_list .begin ());
283- while_grad->SetInput (framework::GradVarName (kOutputs ), extra_inputs_list );
284+ std::vector<std::string> output_grads_list ;
285+ output_grads_list .resize (output_grads .size ());
286+ std::copy (output_grads .begin (), output_grads .end (),
287+ output_grads_list .begin ());
288+ while_grad->SetInput (framework::GradVarName (kOutputs ), output_grads_list );
284289
285290 while_grad->SetAttrMap (this ->Attrs ());
286291 while_grad->SetBlockAttr (kStepBlock , *grad_block);
287292 // record the original output gradient names, since the gradient name of
288293 // while operator could be renamed.
289- while_grad->SetAttr (" original_output_grad" , extra_inputs_list );
294+ while_grad->SetAttr (" original_output_grad" , output_grads_list );
290295
291296 return std::unique_ptr<framework::OpDesc>(while_grad);
292297 }
0 commit comments