Skip to content

Commit 39b0fdc

Browse files
committed
Transpiler: fix pserver crash due to split var name check.
In notest_dist_label_semantic_roles.py, "emb" is matched with "embedding_1.w_0", but they are two irrevalent vars. Fixes: #7701
1 parent f9fe48e commit 39b0fdc

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

python/paddle/v2/fluid/distribute_transpiler.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def __str__(self):
3333
return "%s:%d:%d" % (self.varname, self.offset, self.size)
3434

3535

36+
def same_or_split_var(p_name, var_name):
37+
return p_name == var_name or p_name.startswith(var_name + ".block")
38+
39+
3640
def split_dense_variable(var_list,
3741
pserver_count,
3842
min_block_size=1024,
@@ -303,8 +307,8 @@ def _is_op_on_pserver(self, endpoint, all_ops, idx):
303307
return True
304308
else:
305309
for n in param_names:
306-
if n.startswith(op.inputs["Param"].name+".block") and \
307-
n != op.inputs["Param"].name:
310+
if same_or_split_var(n, op.inputs[
311+
"Param"].name) and n != op.inputs["Param"].name:
308312
return True
309313
return False
310314
else:
@@ -335,7 +339,7 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
335339
if key == "Grad":
336340
grad_block = None
337341
for g in self.param_grad_ep_mapping[endpoint]["grads"]:
338-
if g.name.startswith(var.name):
342+
if same_or_split_var(g.name, var.name):
339343
grad_block = g
340344
break
341345
if not grad_block:
@@ -365,7 +369,7 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
365369
# param is already created on global program
366370
param_block = None
367371
for p in self.param_grad_ep_mapping[endpoint]["params"]:
368-
if p.name.startswith(var.name):
372+
if same_or_split_var(p.name, var.name):
369373
param_block = p
370374
break
371375
if not param_block:
@@ -502,7 +506,7 @@ def get_startup_program(self, endpoint, pserver_program):
502506
def _get_splited_name_and_shape(varname):
503507
for idx, splited_param in enumerate(params):
504508
pname = splited_param.name
505-
if pname.startswith(varname) and varname != pname:
509+
if same_or_split_var(pname, varname) and varname != pname:
506510
return pname, splited_param.shape
507511
return "", []
508512

0 commit comments

Comments
 (0)