@@ -33,6 +33,10 @@ def __str__(self):
3333 return "%s:%d:%d" % (self .varname , self .offset , self .size )
3434
3535
36+ def same_or_split_var (p_name , var_name ):
37+ return p_name == var_name or p_name .startswith (var_name + ".block" )
38+
39+
3640def split_dense_variable (var_list ,
3741 pserver_count ,
3842 min_block_size = 1024 ,
@@ -303,8 +307,8 @@ def _is_op_on_pserver(self, endpoint, all_ops, idx):
303307 return True
304308 else :
305309 for n in param_names :
306- if n . startswith ( op .inputs ["Param" ]. name + ".block" ) and \
307- n != op .inputs ["Param" ].name :
310+ if same_or_split_var ( n , op .inputs [
311+ "Param" ]. name ) and n != op .inputs ["Param" ].name :
308312 return True
309313 return False
310314 else :
@@ -335,7 +339,7 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
335339 if key == "Grad" :
336340 grad_block = None
337341 for g in self .param_grad_ep_mapping [endpoint ]["grads" ]:
338- if g .name . startswith ( var .name ):
342+ if same_or_split_var ( g .name , var .name ):
339343 grad_block = g
340344 break
341345 if not grad_block :
@@ -365,7 +369,7 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
365369 # param is already created on global program
366370 param_block = None
367371 for p in self .param_grad_ep_mapping [endpoint ]["params" ]:
368- if p .name . startswith ( var .name ):
372+ if same_or_split_var ( p .name , var .name ):
369373 param_block = p
370374 break
371375 if not param_block :
@@ -502,7 +506,7 @@ def get_startup_program(self, endpoint, pserver_program):
502506 def _get_splited_name_and_shape (varname ):
503507 for idx , splited_param in enumerate (params ):
504508 pname = splited_param .name
505- if pname . startswith ( varname ) and varname != pname :
509+ if same_or_split_var ( pname , varname ) and varname != pname :
506510 return pname , splited_param .shape
507511 return "" , []
508512
0 commit comments