
Commit 2f76bb8

[CPU-PSLIB] Add consistency inspection of op's embedding name and sparse table name in config_fleet.py, test=develop (#34249)
1 parent 038883f

File tree

1 file changed: +81 −2 lines


python/paddle/fluid/incubate/fleet/parameter_server/pslib/optimizer_factory.py

Lines changed: 81 additions & 2 deletions
@@ -25,6 +25,7 @@
 from .node import DownpourWorker, DownpourServer
 from . import ps_pb2 as pslib
 import os
+import logging
 
 OpRole = core.op_proto_and_checker_maker.OpRole
 # this dict is for store info about pull/push sparse ops.
@@ -41,6 +42,10 @@
     "scale_sparse_grad": None,
 }
 
+logging.basicConfig(
+    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 class DistributedOptimizerImplBase(object):
     """
@@ -300,6 +305,74 @@ def _generate_multi_dense_table(self,
 
         return dense_tables, cond2denseid, lists_params, lists_grads, root_params_list, root_grads_list
 
+    def _gen_distributed_emb_to_size_dict(self, program):
+        d_size = dict()
+        local_vars = program.current_block().vars
+
+        for op in program.global_block().ops:
+            if op.type in self.supported_embedding_types:
+                if op.attr('is_distributed') is True:
+                    table_name = op.input("W")[0]
+                    emb_size = local_vars[table_name].shape[1]
+                    if d_size.get(table_name) is None:
+                        d_size[table_name] = emb_size
+                    elif d_size[table_name] != emb_size:
+                        raise ValueError("embedding size error: %s vs %s" %
+                                         (emb_size, d_size[table_name]))
+
+        return d_size
+
+    def _check_config_fleet_with_program_op(self, strategy, table_name,
+                                            emb_to_size):
+        if strategy.get(table_name) is None:
+            strategy[table_name] = dict()
+        st = strategy[table_name]
+
+        accessor = None
+        if st.get("sparse_accessor_class") is not None:
+            accessor = st["sparse_accessor_class"]
+
+        if accessor is None:
+            accessor = "DownpourCtrAccessor"
+
+        # set sparse_embedx_dim in strategy,
+        # user do not have to set it in config_fleet
+        if accessor == "DownpourFeatureValueAccessor" \
+                or accessor == "DownpourCtrAccessor" \
+                or accessor == "DownpourDoubleUnitAccessor" \
+                or accessor == "DownpourUnitAccessor":
+            if st.get("sparse_embedx_dim") is not None \
+                    and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
+                raise ValueError("fleet config sparse_embedx_dim=%s not"
+                                 " equal to embedding size - 3 = %s" %
+                                 (st["sparse_embedx_dim"],
+                                  emb_to_size[table_name] - 3))
+            if st.get("sparse_embedx_dim") is None:
+                logger.warning(
+                    "sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim "
+                    "with same sparse table name is not set in config_fleet.py. "
+                    "Hence automatically set sparse_embedx_dim = {} - 3.".
+                    format(table_name, emb_to_size[table_name], emb_to_size[
+                        table_name]))
+                st["sparse_embedx_dim"] = emb_to_size[table_name] - 3
+        elif accessor == "DownpourSparseValueAccessor":
+            if st.get("sparse_embedx_dim") is not None \
+                    and st["sparse_embedx_dim"] != emb_to_size[table_name]:
+                raise ValueError("fleet config sparse_embedx_dim=%s not"
+                                 " equal to embedding size = %s" %
+                                 (st["sparse_embedx_dim"],
+                                  emb_to_size[table_name]))
+            if st.get("sparse_embedx_dim") is None:
+                logger.warning(
+                    "sparse embedding size for table name '{}' is: {}, while sparse_embedx_dim "
+                    "with same sparse table name is not set in config_fleet.py. "
+                    "Hence automatically set sparse_embedx_dim = {}.".format(
+                        table_name, emb_to_size[table_name], emb_to_size[
+                            table_name]))
+                st["sparse_embedx_dim"] = emb_to_size[table_name]
+
+        return strategy
+
     def _minimize(self,
                   losses,
                   startup_program=None,
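
Aside (not part of the diff): to make the new check concrete, below is a simplified, self-contained re-implementation of the sparse_embedx_dim rule. The name check_embedx_dim and all sample values are hypothetical, and it collapses the four CTR-family accessors into the default branch. Per the patch, DownpourFeatureValueAccessor, DownpourCtrAccessor, DownpourDoubleUnitAccessor, and DownpourUnitAccessor require sparse_embedx_dim to equal the embedding width minus 3, while DownpourSparseValueAccessor requires the width exactly:

def check_embedx_dim(strategy, table_name, emb_size):
    # Toy version of _check_config_fleet_with_program_op for one table.
    # emb_size is the second dimension of the embedding variable's shape.
    st = strategy.setdefault(table_name, dict())
    accessor = st.get("sparse_accessor_class", "DownpourCtrAccessor")
    # Per the patch: SparseValue expects the full embedding width, the
    # CTR-family accessors expect width - 3.
    if accessor == "DownpourSparseValueAccessor":
        expected = emb_size
    else:
        expected = emb_size - 3
    configured = st.get("sparse_embedx_dim")
    if configured is not None and configured != expected:
        raise ValueError("sparse_embedx_dim=%s does not match expected %s" %
                         (configured, expected))
    st["sparse_embedx_dim"] = expected
    return strategy


# Auto-fill: nothing configured, default accessor, width 11 -> dim 8.
s = check_embedx_dim({}, "embedding_0.w_0", 11)
assert s["embedding_0.w_0"]["sparse_embedx_dim"] == 8

# Mismatch: a configured value that disagrees with the program raises.
try:
    check_embedx_dim({"embedding_0.w_0": {"sparse_embedx_dim": 9}},
                     "embedding_0.w_0", 11)
except ValueError as exc:
    print(exc)  # sparse_embedx_dim=9 does not match expected 8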
@@ -397,6 +470,10 @@ def _minimize(self,
                 sparse_table_to_index[tn] = sparse_table_index
                 sparse_table_index += 1
 
+            # get {table_name: emb_size} dict from program ops
+            emb_to_size = self._gen_distributed_emb_to_size_dict(
+                loss.block.program)
+
             # get inputs_dict
             inputs_dict = self._find_distributed_lookup_table_inputs(
                 loss.block.program, sparse_table)
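
Aside (not part of the diff): the helper called here returns a plain {table_name: emb_size} mapping. A toy stand-in over (op_type, is_distributed, table_name, shape) tuples shows the contract, including the error when one table appears with two different widths; the op list and the supported type names are fabricated for illustration:

def gen_emb_to_size(ops, supported_types=("lookup_table",)):
    # Toy stand-in for _gen_distributed_emb_to_size_dict: record the width
    # (shape[1]) of every distributed embedding table, rejecting a table
    # that is used with two different widths.
    d_size = dict()
    for op_type, is_distributed, table_name, shape in ops:
        if op_type in supported_types and is_distributed:
            emb_size = shape[1]
            if d_size.get(table_name) is None:
                d_size[table_name] = emb_size
            elif d_size[table_name] != emb_size:
                raise ValueError("embedding size error: %s vs %s" %
                                 (emb_size, d_size[table_name]))
    return d_size


ops = [("lookup_table", True, "embedding_0.w_0", (100000, 11)),
       ("lookup_table", True, "embedding_0.w_0", (100000, 11)),
       ("mul", False, "fc_0.w_0", (11, 32))]
print(gen_emb_to_size(ops))  # {'embedding_0.w_0': 11}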
@@ -511,8 +588,10 @@ def _minimize(self,
         # ServerParameter add all sparse tables
         for tn in sparse_table_to_index:
             sparse_table_index = sparse_table_to_index[tn]
-            if strategy.get(tn) is not None:
-                server.add_sparse_table(sparse_table_index, strategy[tn])
+            st = self._check_config_fleet_with_program_op(strategy, tn,
+                                                          emb_to_size)
+            if st.get(tn) is not None:
+                server.add_sparse_table(sparse_table_index, st[tn])
             else:
                 server.add_sparse_table(sparse_table_index, None)

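Design note: before this change, a sparse table missing from config_fleet.py was registered with None (the default table config), so a mismatch between an op's embedding name or width and the configured sparse table could pass silently. With this hunk, every table discovered in the program is routed through _check_config_fleet_with_program_op first, so sparse_embedx_dim is either validated against the op's embedding shape or filled in automatically with a logged warning before server.add_sparse_table runs.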