@@ -261,35 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
261261 for (auto &o : Output (kOutputs )) {
262262 block_ins.insert (o);
263263 }
264- std::unordered_set<std::string> extra_inputs ;
264+ std::unordered_set<std::string> output_grads ;
265265 for (const auto *op : grad_block->AllOps ()) {
266266 for (auto &input_name : op->InputArgumentNames ()) {
267267 // If the input of Op has been recorded or is generated by the forward
268268 // block, do not make it as input again.
269269
270+ // The input is located in I/O or other op's outputs or the variable is
271+ // located in grad_block's parents
270272 if (block_ins.find (input_name) != block_ins.end () ||
271- fwd_block->FindVar (input_name) != nullptr ||
272- parent_block->FindVar (input_name) != nullptr ) {
273+ ( fwd_block->FindVarRecursive (input_name) != nullptr ||
274+ parent_block->FindVarRecursive (input_name) != nullptr ) ) {
273275 continue ;
274276 }
275- extra_inputs .insert (input_name);
277+ output_grads .insert (input_name);
276278 }
277279 for (auto &output_name : op->OutputArgumentNames ()) {
278280 block_ins.insert (output_name);
279281 }
280282 }
281283
282- std::vector<std::string> extra_inputs_list ;
283- extra_inputs_list .resize (extra_inputs .size ());
284- std::copy (extra_inputs .begin (), extra_inputs .end (),
285- extra_inputs_list .begin ());
286- while_grad->SetInput (framework::GradVarName (kOutputs ), extra_inputs_list );
284+ std::vector<std::string> output_grads_list ;
285+ output_grads_list .resize (output_grads .size ());
286+ std::copy (output_grads .begin (), output_grads .end (),
287+ output_grads_list .begin ());
288+ while_grad->SetInput (framework::GradVarName (kOutputs ), output_grads_list );
287289
288290 while_grad->SetAttrMap (this ->Attrs ());
289291 while_grad->SetBlockAttr (kStepBlock , *grad_block);
290292 // record the original output gradient names, since the gradient name of
291293 // while operator could be renamed.
292- while_grad->SetAttr (" original_output_grad" , extra_inputs_list );
294+ while_grad->SetAttr (" original_output_grad" , output_grads_list );
293295
294296 return std::unique_ptr<framework::OpDesc>(while_grad);
295297 }
0 commit comments