Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ message AsyncConfig {
optional int32 send_wait_times = 7 [ default = 1 ];
optional bool runtime_split_send_recv = 8 [ default = false ];
optional bool launch_barrier = 9 [ default = true ];
optional string heter_worker_device_guard = 10 [ default = 'cpu' ];
}

message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
Expand Down
174 changes: 89 additions & 85 deletions python/paddle/distributed/fleet/base/role_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,13 +530,6 @@ def _get_heter_worker_endpoint(self):
return self._heter_trainer_endpoints[(self._current_id) %
self._heter_worker_num()]

def _get_heter_worker_device(self):
"""
Returns:
string: heter_trainer's device of current node, e.g: CPU/GPU/XPU
"""
return self._heter_trainer_device.upper()


class PaddleCloudRoleMaker(RoleMakerBase):
def __init__(self, is_collective=False, **kwargs):
Expand Down Expand Up @@ -677,88 +670,99 @@ def _is_heter_worker(self):
return self._role == Role.HETER_WORKER

def _ps_env(self):
try:
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")

if self._server_endpoints is None:
# back to non_distributed execution.
self._server_endpoints = ""
self._trainers_num = 1
self._role = Role.WORKER
self._current_id = 0
self._nodes_num = 1
self._heter_trainers_num = 0
self._heter_trainer_endpoints = None
self._non_distributed = True
return

self._server_endpoints = self._server_endpoints.split(",")

self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
if self._worker_endpoints:
self._worker_endpoints = self._worker_endpoints.split(",")
else:
self._worker_endpoints = []
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)

if self._server_endpoints is None:
# back to non_distributed execution.
self._server_endpoints = ""
self._trainers_num = 1
self._role = Role.WORKER
self._current_id = 0
self._nodes_num = 1
self._heter_trainers_num = 0
self._heter_trainer_endpoints = None
self._non_distributed = True
return

self._server_endpoints = self._server_endpoints.split(",")

self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
if self._worker_endpoints != None:
self._worker_endpoints = self._worker_endpoints.split(",")
else:
self._worker_endpoints = []

trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
if trainers_num == None:
raise ValueError(
"Can not find PADDLE_TRAINERS_NUM, please check your environment."
)
trainers_num = int(trainers_num)

trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
training_role = os.environ["TRAINING_ROLE"]
training_role = os.getenv("TRAINING_ROLE", None)
if training_role == None:
raise ValueError(
"Can not find TRAINING_ROLE, please check your environment.")

if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
raise ValueError(
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.".
format(training_role))

# For heter parameter server env setting
heter_trainer_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST",
"")
if heter_trainer_eplist != "":
try:
heter_trainer_eplist = os.environ[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",")
except:
raise ValueError(
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.".
format(training_role))

# For heter parameter server env setting
heter_trainer_eplist = os.getenv(
"PADDLE_HETER_TRAINER_IP_PORT_LIST", None)
heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE",
None)
if heter_trainer_eplist and heter_trainer_device:
try:
heter_trainer_eplist = os.environ[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",")
except:
raise ValueError(
"Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
)

self._is_heter_parameter_server_mode = True
heter_trainers_num = len(heter_trainer_eplist)
current_node_device = heter_trainer_device.upper()
if current_node_device not in ["CPU", "GPU", "XPU"]:
raise ValueError(
"Heter Trainer doesn't support {} device now, please use CPU / GPU / XPU(KunLun)".
format(heter_trainer_device))
self._heter_trainer_device = current_node_device
else:
self._is_heter_parameter_server_mode = False
heter_trainers_num = 0

if training_role == "TRAINER":
role = Role.WORKER
current_id = int(os.environ["PADDLE_TRAINER_ID"])
if len(self._worker_endpoints) > 0:
self._cur_endpoint = self._worker_endpoints[current_id]
elif training_role == "PSERVER":
role = Role.SERVER
port = os.environ["PADDLE_PORT"]
ip = os.environ["POD_IP"]
self._cur_endpoint = ip + ":" + port
current_id = self._server_endpoints.index(self._cur_endpoint)
elif training_role == "HETER_TRAINER":
role = Role.HETER_WORKER
cur_ip = os.environ["POD_IP"]
cur_port = os.environ["PADDLE_PORT"]
curr_endpoint = ":".join([cur_ip, cur_port])
current_id = heter_trainer_eplist.index(curr_endpoint)
else:
"Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
)

self._is_heter_parameter_server_mode = True
heter_trainers_num = len(heter_trainer_eplist)
else:
self._is_heter_parameter_server_mode = False
heter_trainers_num = 0

if training_role == "TRAINER":
role = Role.WORKER
current_id = os.getenv("PADDLE_TRAINER_ID", None)
if current_id == None:
raise ValueError(
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER")
except ValueError as e:
raise ValueError(
"Something wrong with PaddleCloud, please check environment")
"Can not find PADDLE_TRAINER_ID, please check your environment."
)
current_id = int(current_id)
if len(self._worker_endpoints) > 0:
self._cur_endpoint = self._worker_endpoints[current_id]
elif training_role == "PSERVER":
role = Role.SERVER
port = os.getenv("PADDLE_PORT", None)
if port == None:
raise ValueError(
"Can not find PADDLE_PORT, please check your environment.")
ip = os.getenv("POD_IP", None)
if ip == None:
raise ValueError(
"Can not find POD_IP, please check your environment.")
self._cur_endpoint = ip + ":" + port
current_id = self._server_endpoints.index(self._cur_endpoint)
elif training_role == "HETER_TRAINER":
role = Role.HETER_WORKER
cur_port = os.getenv("PADDLE_PORT", None)
if cur_port == None:
raise ValueError(
"Can not find PADDLE_PORT, please check your environment.")
cur_ip = os.getenv("POD_IP", None)
if cur_ip == None:
raise ValueError(
"Can not find POD_IP, please check your environment.")
curr_endpoint = ":".join([cur_ip, cur_port])
current_id = heter_trainer_eplist.index(curr_endpoint)

self._trainers_num = trainers_num
self._role = role
Expand Down
Loading