158 changes: 84 additions & 74 deletions python/paddle/fluid/incubate/fleet/base/fleet_base.py
@@ -15,18 +15,21 @@
from __future__ import print_function

import abc
import sys

from enum import Enum

from paddle.fluid.optimizer import SGD
from paddle.fluid.executor import Executor
Review comment (Member):
    why do you import Executor here?

Reply (Collaborator, author):
    for a value check: isinstance of Executor
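
For context, the check that this reply refers to appears further down in this diff, inside Fleet.init(). A minimal standalone snippet of the pattern (the import and the error message come from the diff; the helper name is illustrative only):

    from paddle.fluid.executor import Executor

    def _check_executor(executor):
        # Reject anything that is not a paddle.fluid Executor before fleet uses it.
        if not isinstance(executor, Executor):
            raise ValueError("executor must be an instance of Executor")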


from role_maker import RoleMakerBase, Role
from role_maker import RoleMakerBase
from role_maker import MPISymetricRoleMaker
from role_maker import UserDefinedRoleMaker


class Mode(Enum):
"""
There are various modes for fleet; each of them is designed for a different model.
"""
TRANSPILER = 1,
PSLIB = 2,
COLLECTIVE = 3
@@ -46,17 +49,11 @@ class Fleet(object):

def __init__(self, mode):
assert isinstance(mode, Mode)
self.is_initialized = False
self.mode = mode
self.workers = 0
self.servers = 0
self.worker_endpoints = []
self.server_endpoints = []
self.role = Role.WORKER
self.current_endpoint = None
self.current_id = 0
self.optimizer = None
self.role_maker_ = None
self._is_initialized = False
self._mode = mode
self._optimizer = None
self._role_maker = None
self._executor = None

def is_first_worker(self):
"""
@@ -66,25 +63,25 @@ def is_first_worker(self):
bool: True if this is the first node of worker,
False if not.
"""
return self.is_worker() and self.current_id == 0
return self._role_maker.is_first_worker()

def worker_id(self):
def worker_index(self):
"""
Get current worker id.
Get current worker index.

Returns:
int: node id
"""
return self.current_id
return self._role_maker.worker_index()

def get_workers(self):
def worker_num(self):
"""
Get current total worker number.

Returns:
int: worker number
"""
return self.workers
return len(self._role_maker.get_trainer_endpoints())

def is_worker(self):
"""
@@ -94,7 +91,51 @@ def is_worker(self):
bool: True if this is a node of worker,
False if not.
"""
return self.role == Role.WORKER
return self._role_maker.is_worker()

def worker_endpoints(self, to_string=False):
"""
Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

Returns:
list/string: worker endpoints
"""

if to_string:
return ",".join(self._role_maker.get_trainer_endpoints())
else:
return self._role_maker.get_trainer_endpoints()

def server_num(self):
"""
Get current total server number.

Returns:
int: server number
"""
return len(self._role_maker.get_pserver_endpoints())

def server_index(self):
"""
Get current server index.

Returns:
int: node id
"""
return self._role_maker.server_index()

def server_endpoints(self, to_string=False):
"""
Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].

Returns:
list/string: server endpoints
"""

if to_string:
return ",".join(self._role_maker.get_pserver_endpoints())
else:
return self._role_maker.get_pserver_endpoints()

def is_server(self):
"""
@@ -104,7 +145,7 @@ def is_server(self):
bool: True if this is a node of server,
False if not.
"""
return self.role == Role.SERVER
return self._role_maker.is_server()

def split_files(self, files):
"""
@@ -119,8 +160,8 @@ def split_files(self, files):
list: files that belong to this worker.
"""
file_num = len(files)
trainer_id = self.worker_id()
trainer_num = self.get_workers()
trainer_id = self.worker_index()
trainer_num = self.worker_num()
if trainer_num > file_num:
raise ValueError("trainer_num should be <= file_num : "
"%s > %s" % (trainer_num, file_num))
@@ -132,74 +173,57 @@ def init(self, role_maker=None):
end += length
return files[start:end]
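
As an illustration, the contiguous split that split_files() performs can be sketched as a standalone function. Only the length check and the final slice are visible in this diff (the middle of the method is collapsed), so the block/remainder arithmetic below is an assumption about the intended behaviour, not a copy of the implementation:

    def split_files(files, trainer_id, trainer_num):
        # Return the contiguous slice of `files` owned by worker `trainer_id`.
        file_num = len(files)
        if trainer_num > file_num:
            raise ValueError("trainer_num should be <= file_num : "
                             "%s > %s" % (trainer_num, file_num))
        length = file_num // trainer_num      # base number of files per worker
        remainder = file_num % trainer_num    # the first `remainder` workers take one extra file
        start = trainer_id * length + min(trainer_id, remainder)
        end = start + length + (1 if trainer_id < remainder else 0)
        return files[start:end]

    # Example: 5 files over 3 workers.
    files = ["part-%02d" % i for i in range(5)]
    print([split_files(files, wid, 3) for wid in range(3)])
    # -> [['part-00', 'part-01'], ['part-02', 'part-03'], ['part-04']]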

def init(self, role_maker=None):
def init(self, executor, role_maker=None):
"""
should be called only once in a user's python script;
init() will initialize the RoleMaker which is used for identifying
the current node's role, e.g. worker, server, etc.

Args:
executor(Executor): The executor to run fleet.
role_maker(RoleMakerBase): subclass of RoleMakerBase.

Returns:
None
"""
if not isinstance(executor, Executor):
raise ValueError("executor must be an instance of Executor")

if role_maker and not isinstance(role_maker, RoleMakerBase):
raise ValueError("role_maker must be an instance of RoleMakerBase")

self.role_maker_ = role_maker

if isinstance(role_maker, MPISymetricRoleMaker):
self.role_maker_._generate_role()
self.role = Role.WORKER if role_maker._is_worker() else Role.SERVER
self.workers = role_maker._worker_num()
self.servers = role_maker._server_num()
self.server_endpoints = role_maker._get_pserver_endpoints()
self.worker_endpoints = role_maker._get_trainer_endpoints()
self.current_id = role_maker._worker_index(
) if role_maker._is_worker() else role_maker._server_index()
self.current_endpoint = self.worker_endpoints[self.current_id] \
if role_maker._is_worker() else self.server_endpoints[self.current_id]
self._role_maker = role_maker
self._role_maker.generate_role()

elif isinstance(role_maker, UserDefinedRoleMaker):
self.current_id = role_maker.current_id
self.current_endpoint = role_maker.current_endpoint
self.workers = role_maker.workers
self.worker_endpoints = role_maker.worker_endpoints
self.servers = role_maker.servers
self.server_endpoints = role_maker.server_endpoints
self.role = role_maker.role
self._role_maker = role_maker

else:
raise ValueError(
"role_maker must be an instance of UserDefinedRoleMaker/MPISymetricRoleMaker"
)

self.is_initialized = True
self._is_initialized = True

@abc.abstractmethod
def init_worker(self, executor):
def init_worker(self):
pass

@abc.abstractmethod
def run_worker(self, executor, main_program=None):
def init_server(self, model_dir=None):
pass

@abc.abstractmethod
def init_server(self, executor, model_dir=None):
pass

@abc.abstractmethod
def run_server(self, executor):
def run_server(self):
pass

@abc.abstractmethod
def stop_worker(self):
pass

@abc.abstractmethod
def stop(self, executor):
def stop(self):
pass

@abc.abstractmethod
@@ -208,7 +232,6 @@ def distributed_optimizer(self, optimizer, strategy=None):

@abc.abstractmethod
def save_inference_model(self,
executor,
dirname,
feeded_var_names,
target_vars,
@@ -217,21 +240,9 @@ def save_inference_model(self,
pass

@abc.abstractmethod
def save_persistables(self, executor, dirname, main_program=None):
def save_persistables(self, dirname, main_program=None):
pass

def to_string(self):
infos = """
mode = {}
workers = {}
server_endpoints = {}
role = {}
current_endpoint = {}
current_id = {}
""".format(self.mode, self.workers, self.server_endpoints, self.role,
self.current_endpoint, self.current_id)
return infos


class DistributedOptimizer(object):
"""
@@ -245,7 +256,7 @@ class DistributedOptimizer(object):

Args:
optimizer(Optimizer): subclass of Optimizer.
strategy(dict): the user-defined config for the Optimizer.
strategy(any): the user-defined config for the Optimizer.

Returns:
None
@@ -257,9 +268,6 @@ def __init__(self, optimizer, strategy=None):
if not isinstance(optimizer, SGD.__bases__):
raise ValueError("optimizer must be an instance of Optimizer")

if strategy and not isinstance(strategy, dict):
raise ValueError("strategy must be an instance of Dict")

self._optimizer = optimizer
self._strategy = strategy

@@ -317,8 +325,9 @@ def apply_gradients(self, params_grads):

@abc.abstractmethod
def minimize(self,
loss,
startup_program=None,
losses,
scopes=None,
startup_programs=None,
parameter_list=None,
no_grad_set=None):
"""
@@ -328,8 +337,9 @@ def minimize(self,
`apply_gradients()` into one.

Args:
loss (Variable): loss variable to run optimizations.
startup_program (Program): startup_program for initializing parameters
losses (Variable|Variable List): loss variable to run optimizations.
scopes (Scope| Scope List): scope instance.
startup_programs (Program|Program List): startup_program for initializing parameters
in `parameter_list`.
parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored.
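
Finally, the design change at the heart of this diff is that Fleet no longer caches workers, servers, and endpoints itself but delegates every query to its role maker. That delegation can be exercised with a small self-contained mock (pure Python, no paddle dependency; the StubRoleMaker class is invented here purely for illustration):

    class StubRoleMaker(object):
        # Minimal stand-in exposing the role-maker calls that Fleet uses in this diff.
        def __init__(self, trainers, pservers, index, is_worker):
            self._trainers = trainers
            self._pservers = pservers
            self._index = index
            self._is_worker = is_worker

        def get_trainer_endpoints(self):
            return self._trainers

        def get_pserver_endpoints(self):
            return self._pservers

        def worker_index(self):
            return self._index

        def is_worker(self):
            return self._is_worker

        def is_first_worker(self):
            return self._is_worker and self._index == 0

    role = StubRoleMaker(["127.0.0.1:1001", "127.0.0.1:1002"],
                         ["127.0.0.1:2001"], index=0, is_worker=True)

    print(len(role.get_trainer_endpoints()))       # what worker_num() now returns -> 2
    print(",".join(role.get_trainer_endpoints()))  # worker_endpoints(to_string=True)
    print(len(role.get_pserver_endpoints()))       # what server_num() now returns -> 1
    print(role.is_first_worker())                  # is_first_worker() -> True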