 from .node import DownpourWorker, DownpourServer
 from . import ps_pb2 as pslib
 import os
+import logging
 
 OpRole = core.op_proto_and_checker_maker.OpRole
 # this dict stores info about pull/push sparse ops.
     "scale_sparse_grad": None,
 }
 
+logging.basicConfig(
+    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 class DistributedOptimizerImplBase(object):
     """
@@ -300,6 +305,74 @@ def _generate_multi_dense_table(self, |
 
         return dense_tables, cond2denseid, lists_params, lists_grads, root_params_list, root_grads_list
 
+    def _gen_distributed_emb_to_size_dict(self, program):
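+        """
+        Scan the program for distributed embedding ops and return a
+        {table_name: embedding_size} dict. Raises ValueError if the same
+        table is used with inconsistent embedding sizes.
+        """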
+        d_size = dict()
+        local_vars = program.current_block().vars
+
+        for op in program.global_block().ops:
+            if op.type in self.supported_embedding_types:
+                if op.attr('is_distributed') is True:
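+                    # "W" is the embedding table; its second dim is the width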
+                    table_name = op.input("W")[0]
+                    emb_size = local_vars[table_name].shape[1]
+                    if d_size.get(table_name) is None:
+                        d_size[table_name] = emb_size
+                    elif d_size[table_name] != emb_size:
+                        raise ValueError("embedding size error: %s vs %s" %
+                                         (emb_size, d_size[table_name]))
+
+        return d_size
+
+    def _check_config_fleet_with_program_op(self, strategy, table_name,
+                                            emb_to_size):
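+        """
+        Fill in or validate sparse_embedx_dim for one sparse table in the
+        fleet strategy, so the configured dim always matches the embedding
+        size found in the program. Returns the updated strategy.
+        """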
+        if strategy.get(table_name) is None:
+            strategy[table_name] = dict()
+        st = strategy[table_name]
+
+        # default to DownpourCtrAccessor when no accessor is configured
+        accessor = st.get("sparse_accessor_class")
+        if accessor is None:
+            accessor = "DownpourCtrAccessor"
+
+        # set sparse_embedx_dim in strategy, so users do not have to
+        # set it in config_fleet
+        if accessor in ("DownpourFeatureValueAccessor", "DownpourCtrAccessor",
+                        "DownpourDoubleUnitAccessor", "DownpourUnitAccessor"):
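+            # for these accessors the embedx dim is the embedding size
+            # minus 3 (3 slots of each row are reserved by the accessor)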
+            if st.get("sparse_embedx_dim") is not None \
+                    and st["sparse_embedx_dim"] != emb_to_size[table_name] - 3:
+                raise ValueError("fleet config sparse_embedx_dim=%s is not"
+                                 " equal to embedding size - 3 = %s" %
+                                 (st["sparse_embedx_dim"],
+                                  emb_to_size[table_name] - 3))
+            if st.get("sparse_embedx_dim") is None:
+                logger.warning(
+                    "sparse_embedx_dim is not set in config_fleet.py for "
+                    "sparse table '{}', whose embedding size is {}; "
+                    "automatically setting sparse_embedx_dim = {} - 3.".format(
+                        table_name, emb_to_size[table_name],
+                        emb_to_size[table_name]))
+                st["sparse_embedx_dim"] = emb_to_size[table_name] - 3
+        elif accessor == "DownpourSparseValueAccessor":
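+            # the plain value accessor stores only the embedding itself,
+            # so embedx dim must equal the full embedding size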
+            if st.get("sparse_embedx_dim") is not None \
+                    and st["sparse_embedx_dim"] != emb_to_size[table_name]:
+                raise ValueError("fleet config sparse_embedx_dim=%s is not"
+                                 " equal to embedding size = %s" %
+                                 (st["sparse_embedx_dim"],
+                                  emb_to_size[table_name]))
+            if st.get("sparse_embedx_dim") is None:
+                logger.warning(
+                    "sparse_embedx_dim is not set in config_fleet.py for "
+                    "sparse table '{}', whose embedding size is {}; "
+                    "automatically setting sparse_embedx_dim = {}.".format(
+                        table_name, emb_to_size[table_name],
+                        emb_to_size[table_name]))
+                st["sparse_embedx_dim"] = emb_to_size[table_name]
+
+        return strategy
+
     def _minimize(self,
                   losses,
                   startup_program=None,
@@ -397,6 +470,10 @@ def _minimize(self, |
                 sparse_table_to_index[tn] = sparse_table_index
                 sparse_table_index += 1
 
+            # get {table_name: emb_size} dict from program ops
+            emb_to_size = self._gen_distributed_emb_to_size_dict(
+                loss.block.program)
+
             # get inputs_dict
             inputs_dict = self._find_distributed_lookup_table_inputs(
                 loss.block.program, sparse_table)
@@ -511,8 +588,10 @@ def _minimize(self, |
             # add all sparse tables to ServerParameter
             for tn in sparse_table_to_index:
                 sparse_table_index = sparse_table_to_index[tn]
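+                # auto-fill / validate sparse_embedx_dim before the table
+                # is registered on the server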
-                if strategy.get(tn) is not None:
-                    server.add_sparse_table(sparse_table_index, strategy[tn])
+                st = self._check_config_fleet_with_program_op(strategy, tn,
+                                                              emb_to_size)
+                if st.get(tn) is not None:
+                    server.add_sparse_table(sparse_table_index, st[tn])
                 else:
                     server.add_sparse_table(sparse_table_index, None)
 