@@ -36,48 +36,73 @@ DECLARE_bool(sort_sum_gradient);
 namespace paddle {
 namespace imperative {
 
-void BasicEngine::Init(VarBase* var, bool retain_graph) {
+void BasicEngine::Init(
+    const std::vector<std::shared_ptr<VarBase>>& tensors,
+    const std::vector<std::shared_ptr<VarBase>>& grad_tensors,
+    bool retain_graph) {
   retain_graph_ = retain_graph;
-  init_node_ = var->GradVarBase()->GradNode();
-  PADDLE_ENFORCE_EQ(var->GradVarBase()->GraphIsFreed(), false,
-                    platform::errors::Unavailable(
-                        "%s trying to backward through the same graph a second "
-                        "time, but this graph have already been freed. Please "
-                        "specify Tensor.backward(retain_graph=True) when "
-                        "calling backward at the first time.",
-                        var->Name()));
-
-  if (!retain_graph) {
-    VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
-            << " because of retain_graph=False when calling backward";
-    var->GradVarBase()->SetGraphIsFreed(true);
-    var->GradVarBase()->ClearGradNode();
-  }
 
-  if (init_node_ == nullptr || var->OverridedStopGradient()) {
-    VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
-               "stop_gradient=True: "
-            << var->Name();
-    return;
-  }
+  PADDLE_ENFORCE_EQ(
+      tensors.size(), grad_tensors.size(),
+      platform::errors::Unavailable(
+          "The size of tensors do not equal the size of grad_tensors,"
+          "the size of tensors is %s, but the size of grad_tensors is %s.",
+          tensors.size(), grad_tensors.size()));
+
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    auto var = tensors[i];
+    auto grad_tensor = grad_tensors[i];
+
+    auto init_node = var->GradVarBase()->GradNode();
+    PADDLE_ENFORCE_EQ(
+        var->GradVarBase()->GraphIsFreed(), false,
+        platform::errors::Unavailable(
+            "%s trying to backward through the same graph a second "
+            "time, but this graph have already been freed. Please "
+            "specify Tensor.backward(retain_graph=True) when "
+            "calling backward at the first time.",
+            var->Name()));
+
+    if (!retain_graph) {
+      VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
+              << " because of retain_graph=False when calling backward";
+      var->GradVarBase()->SetGraphIsFreed(true);
+      var->GradVarBase()->ClearGradNode();
+    }
 
-  VLOG(3) << "Init first node of backward";
+    if (init_node == nullptr || var->OverridedStopGradient()) {
+      VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
+                 "stop_gradient=True: "
+              << var->Name();
+      continue;
+    }
 
-  PADDLE_ENFORCE_EQ(
-      var->HasGradVar(), true,
-      platform::errors::NotFound("Grad variable not exist for variable %s",
-                                 var->Name()));
-
-  auto& fwd_var = var->Var().Get<framework::LoDTensor>();
-  auto* grad_var =
-      var->GradVarBase()->MutableVar()->GetMutable<framework::LoDTensor>();
-  VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
-          << " as stop_gradient false";
-  var->GradVarBase()->InnerSetOverridedStopGradient(false);
-  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(fwd_var.place());
-  grad_var->Resize(fwd_var.dims());
-  grad_var->mutable_data(fwd_var.place(), fwd_var.type());
-  operators::math::set_constant(*dev_ctx, grad_var, 1.0);
+    VLOG(3) << "Init node of backward";
+
+    PADDLE_ENFORCE_EQ(
+        var->HasGradVar(), true,
+        platform::errors::NotFound("Tensor %s has no gradient", var->Name()));
+
+    auto& fwd_var = var->Var().Get<framework::LoDTensor>();
+    auto* grad_var =
+        var->GradVarBase()->MutableVar()->GetMutable<framework::LoDTensor>();
+    VLOG(6) << "init loss grad:" << var->GradVarBase()->Name()
+            << " as stop_gradient false";
+    var->GradVarBase()->InnerSetOverridedStopGradient(false);
+    auto* dev_ctx =
+        platform::DeviceContextPool::Instance().Get(fwd_var.place());
+    if (grad_tensor == nullptr) {
+      grad_var->Resize(fwd_var.dims());
+      grad_var->mutable_data(fwd_var.place(), fwd_var.type());
+      operators::math::set_constant(*dev_ctx, grad_var, 1.0);
+    } else {
+      paddle::framework::TensorCopy(
+          grad_tensor->Var().Get<framework::LoDTensor>(), fwd_var.place(),
+          *dev_ctx, grad_var);
+    }
+
+    init_nodes_.push_back(init_node);
+  }
 }
 
 void BasicEngine::CheckBackwardInputs(const OpBase& op) {
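
Note on this hunk: Init() no longer seeds backward from a single loss VarBase with an all-ones gradient. It now walks parallel vectors of starting tensors and optional starting gradients; for each tensor it either fills the gradient buffer with ones (when the matching grad_tensor is null) or copies the user-supplied gradient, and collects every resolved grad node into init_nodes_. Below is a minimal sketch of that per-tensor seeding rule, redone with plain std::vector<float> buffers instead of LoDTensor/VarBase; SeedGradients and its parameters are hypothetical names for illustration, not Paddle APIs.

// Sketch only: the seeding decision from Init(), with std containers.
#include <cstddef>
#include <vector>

// One gradient buffer per output tensor: copy the user-provided gradient if
// present, otherwise fill with ones (d(out)/d(out) == 1), mirroring the
// set_constant(..., 1.0) vs. TensorCopy branch above.
std::vector<std::vector<float>> SeedGradients(
    const std::vector<std::vector<float>>& outputs,
    const std::vector<const std::vector<float>*>& grad_tensors) {
  std::vector<std::vector<float>> grads(outputs.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (grad_tensors[i] == nullptr) {
      grads[i].assign(outputs[i].size(), 1.0f);  // default seed: ones
    } else {
      grads[i] = *grad_tensors[i];               // user-supplied seed
    }
  }
  return grads;
}
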
@@ -224,8 +249,10 @@ void BasicEngine::PrepareDeps() {
   std::queue<GradOpNode*> q;
   std::unordered_set<GradOpNode*> visited;
 
-  q.push(init_node_.get());
-  visited.insert(init_node_.get());
+  for (size_t i = 0; i < init_nodes_.size(); ++i) {
+    q.push(init_nodes_[i].get());
+    visited.insert(init_nodes_[i].get());
+  }
 
   while (!q.empty()) {
     auto* cur_node = q.front();
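
Note on this hunk: PrepareDeps() now seeds its breadth-first traversal from every entry of init_nodes_ rather than the single init_node_, so dependency counts cover ops reachable from any of the starting tensors. A minimal sketch of multi-root dependency counting, assuming a plain adjacency list keyed by int in place of GradOpNode; CountDeps and the graph shape are hypothetical, for illustration only.

// Sketch only: multi-root BFS that counts incoming edges per node, the role
// node_deps_ plays in PrepareDeps().
#include <cstddef>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

std::unordered_map<int, int> CountDeps(
    const std::unordered_map<int, std::vector<int>>& next,  // node -> successors
    const std::vector<int>& roots) {                        // one root per starting tensor
  std::unordered_map<int, int> deps;
  std::queue<int> q;
  std::unordered_set<int> visited;
  for (int r : roots) {                 // seed from every root, like init_nodes_
    if (visited.insert(r).second) q.push(r);
  }
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    auto it = next.find(cur);
    if (it == next.end()) continue;
    for (int n : it->second) {
      ++deps[n];                        // each incoming edge is one pending dependency
      if (visited.insert(n).second) q.push(n);
    }
  }
  return deps;
}
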
@@ -276,14 +303,16 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
 }
 
 void BasicEngine::Execute() {
-  if (init_node_ == nullptr) {
+  if (init_nodes_.empty()) {
     return;
   }
 
   PrepareDeps();
   // Start execute Computation graph
   std::queue<std::shared_ptr<GradOpNode>> q;
-  q.push(std::move(init_node_));
+  for (size_t i = 0; i < init_nodes_.size(); ++i) {
+    q.push(std::move(init_nodes_[i]));
+  }
 
   size_t op_num = 0;
 
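
Note on this hunk: the ready queue in Execute() now starts with every init node, and std::move transfers each shared_ptr out of init_nodes_ rather than copying it, so the engine gives up its own references as execution begins. A tiny sketch of that ownership transfer with plain shared_ptr<int> values, not Paddle types.

// Sketch only: moving shared_ptr roots from a vector into a work queue.
#include <cassert>
#include <cstddef>
#include <memory>
#include <queue>
#include <vector>

int main() {
  std::vector<std::shared_ptr<int>> roots = {std::make_shared<int>(1),
                                             std::make_shared<int>(2)};
  std::queue<std::shared_ptr<int>> q;
  for (size_t i = 0; i < roots.size(); ++i) {
    q.push(std::move(roots[i]));  // transfer ownership; no refcount bump
  }
  // Moved-from shared_ptrs are left empty; the queue now owns both roots.
  assert(roots[0] == nullptr && roots[1] == nullptr);
  assert(q.size() == 2);
  return 0;
}
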
@@ -505,7 +534,7 @@ void BasicEngine::Execute() {
 }
 
 void BasicEngine::Clear() {
-  init_node_.reset();
+  init_nodes_.clear();
   node_deps_.clear();
   accumulators_.clear();
   accumulators_with_grad_node_.clear();