Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions python/paddle/fluid/parallel_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,31 @@ def __init__(self,

if build_strategy is None:
build_strategy = BuildStrategy()
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trainer_id and num_trainers of build_strategy should not be overwritten by the passed trainer_id and num_trainers of ParallelExecutor.


# TODO(paddle-dev): trainer_id and num_trainers should be removed from parameter list.
if num_trainers != 1 and build_strategy.num_trainers != num_trainers:
sys.stderr.write(
'The value of build_strategy.num_trainers[%d] is overwritten '
'by the passed num_trainers[%d].\n' %
(build_strategy.num_trainers, num_trainers))
build_strategy.num_trainers = num_trainers
if trainer_id != 0 and build_strategy.trainer_id != trainer_id:
sys.stderr.write(
'The value of build_strategy.trainer_id[%d] is overwritten '
'by the passed trainer_id[%d].\n' %
(build_strategy.trainer_id, trainer_id))
build_strategy.trainer_id = trainer_id

self._places = framework.cuda_places(
) if use_cuda else framework.cpu_places()
self._scope = scope if scope is not None else executor.global_scope()

if main_program is not None and main_program._enable_dgc:
assert num_trainers > 1, "dgc is not useful when num_trainers <= 1"
assert build_strategy.num_trainers > 1, "dgc is not useful when num_trainers <= 1"
assert build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy.AllReduce, "dgc \
only used for allreduce"

assert num_trainers * len(
assert build_strategy.num_trainers * len(
self._places) > 1, "dgc is not useful for single card training"
assert use_cuda, "dgc only used under cuda"

Expand Down