@@ -29,8 +29,6 @@ limitations under the License. */
2929#include " paddle/operators/detail/simple_block_queue.h"
3030#include " paddle/string/printf.h"
3131
32- #define LISTEN_TERMINATE_MESSAGE " TERMINATE@RECV"
33-
3432namespace paddle {
3533namespace operators {
3634
@@ -95,46 +93,57 @@ class RecvOp : public framework::OperatorBase {
9593 auto param_list = Attr<std::vector<std::string>>(" ParamList" );
9694 auto grad_list = Attr<std::vector<std::string>>(" GradList" );
9795 auto fan_in = Attr<int >(" Fanin" );
98- size_t param_count = param_list.size ();
9996
10097 auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock );
10198 auto *program = block->Program ();
10299 framework::Executor executor (dev_place);
103100
104101 // TODO(typhoonzero): change this to a while_op for every cluster-batch.
105102 bool exit_flag = false ;
106- size_t barrier_size = param_count * fan_in;
107103 while (!exit_flag) {
108104 // Get from multiple trainers, we don't care about the order in which
109105 // the gradients arrive, just add suffix 0~n and merge the gradient.
110106 rpc_service_->SetCond (0 );
111- for (size_t i = 0 ; i < barrier_size; ++i) {
107+ size_t recv_var_cnt = 0 ;
108+ int batch_barrier = 0 ;
109+ while (batch_barrier != fan_in) {
112110 const detail::MessageWithName &v = rpc_service_->Get ();
113111 auto grad_var_name = v.first ;
114112 if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
115113 LOG (INFO) << " received terminate message and exit" ;
116114 exit_flag = true ;
117115 break ;
118- }
119- auto it = std::find (grad_list.begin (), grad_list.end (), grad_var_name);
120- std::string param_var_name;
121- if (it != grad_list.end ()) {
122- param_var_name = param_list[it - grad_list.begin ()];
116+ } else if (grad_var_name == BATCH_BARRIER_MESSAGE) {
117+ VLOG (3 ) << " recv batch barrier message" ;
118+ batch_barrier++;
119+ continue ;
123120 } else {
124- LOG (ERROR) << " grad has no paired param:" << grad_var_name;
125- }
126- VLOG (3 ) << " received grad: " << grad_var_name
127- << " updating param: " << param_var_name;
128- if (fan_in > 1 ) {
129- grad_var_name = this ->GetGradVarNameForTrainer (grad_var_name);
130- }
131- auto *var = recv_scope.FindVar (grad_var_name);
132- if (var == nullptr ) {
133- LOG (ERROR) << " Can not find server side var: " << grad_var_name;
134- PADDLE_THROW (" Can not find server side var" );
121+ // receive a variable
122+ recv_var_cnt++;
123+ auto it =
124+ std::find (grad_list.begin (), grad_list.end (), grad_var_name);
125+ std::string param_var_name;
126+ if (it != grad_list.end ()) {
127+ param_var_name = param_list[it - grad_list.begin ()];
128+ } else {
129+ LOG (ERROR) << " grad has no paired param:" << grad_var_name;
130+ }
131+ VLOG (3 ) << " received grad: " << grad_var_name
132+ << " updating param: " << param_var_name;
133+
134+ if (fan_in > 1 ) {
135+ grad_var_name = this ->GetGradVarNameForTrainer (grad_var_name);
136+ }
137+ auto *var = recv_scope.FindVar (grad_var_name);
138+ if (var == nullptr ) {
139+ LOG (ERROR) << " Can not find server side var: " << grad_var_name;
140+ PADDLE_THROW (" Can not find server side var" );
141+ }
142+ detail::DeserializeFromMessage (v.second , dev_ctx, var);
135143 }
136- detail::DeserializeFromMessage (v.second , dev_ctx, var);
137144 }
145+ VLOG (3 ) << " recv " << recv_var_cnt << " parmeters for one barrier." ;
146+ // TODO(Yancey1989): merge SelectedRows variables here
138147 if (exit_flag) {
139148 break ;
140149 }
@@ -146,7 +155,7 @@ class RecvOp : public framework::OperatorBase {
146155 LOG (ERROR) << " run sub program error " << e.what ();
147156 }
148157 rpc_service_->SetCond (1 );
149- rpc_service_->WaitClientGet (barrier_size );
158+ rpc_service_->WaitClientGet (recv_var_cnt );
150159 grads_counter_.clear ();
151160 } // while(true)
152161 }
0 commit comments