@@ -217,18 +217,20 @@ RunCustomOpNode::operator()(
217217 VLOG (6 ) << " Prepare Grad outputs for size: " << grad_outputs_names.size ();
218218 for (size_t i = 0 ; i < OutputMeta ().size (); i++) {
219219 if (map[0 ][0 ].find (i) != map[0 ][0 ].end ()) {
220+ int grad_output_idx = map[0 ][0 ][i];
220221 VLOG (7 ) << " Insert grad outputs: " << i
221- << " with size: " << OutputMeta ()[i].size ()
222- << " to tmp_outputs: " << map[0 ][0 ][i];
223- for (size_t j = 0 ; j < OutputMeta ()[i].size (); j++) {
224- outs[i].emplace_back (/* init it incase of copy nullptr of shared_ptr */
225- std::make_shared<phi::DenseTensor>(
226- phi::DataType::UNDEFINED),
227- egr::Controller::Instance ().GenerateUniqueName (
228- " custom_tmp_grad" ));
229- egr::EagerUtils::autograd_meta (&(outs[i][j]));
222+ << " with size: " << OutputMeta ()[grad_output_idx].size ()
223+ << " to tmp_outputs: " << grad_output_idx;
224+ for (size_t j = 0 ; j < OutputMeta ()[grad_output_idx].size (); j++) {
225+ outs[grad_output_idx]
226+ .emplace_back (/* init it incase of copy nullptr of shared_ptr */
227+ std::make_shared<phi::DenseTensor>(
228+ phi::DataType::UNDEFINED),
229+ egr::Controller::Instance ().GenerateUniqueName (
230+ " custom_tmp_grad" ));
231+ egr::EagerUtils::autograd_meta (&(outs[grad_output_idx][j]));
230232 }
231- tmp_outs[map[ 0 ][ 0 ][i]] = outs[i ];
233+ tmp_outs[grad_output_idx] = outs[grad_output_idx ];
232234 }
233235 }
234236 for (size_t i = 0 ; i < tmp_outs.size (); i++) {
0 commit comments