Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 114 additions & 25 deletions python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,119 @@ def minimize_impl(self,
self._wait()
return optimize_ops, params_grads

def _init_pair_comm(self, pair, ring_id):
    """Initialize the two-member communicator for one pipeline stage pair.

    The endpoints of both stages in ``pair`` are looked up in
    ``self.pp_group_endpoints`` and handed to
    ``self._collective_helper._init_communicator`` together with this
    process's position inside the pair (0 for ``pair[0]``, 1 for
    ``pair[1]``).

    Args:
        pair: Two-element sequence of pipeline ranks ``(first, second)``;
            it must contain ``self.pp_rank``.
        ring_id: Ring id the communicator is created on.

    Raises:
        ValueError: If ``self.pp_rank`` is neither ``pair[0]`` nor
            ``pair[1]``.
    """
    first, second = pair[0], pair[1]

    # Without this check a non-member rank would silently be treated as
    # local rank 1 of a group whose endpoints do not include it.
    if self.pp_rank != first and self.pp_rank != second:
        raise ValueError(
            "pp rank {} is not a member of pipeline pair {}".format(
                self.pp_rank, pair))

    pp_group_endpoints = [
        self.pp_group_endpoints[first],
        self.pp_group_endpoints[second],
    ]
    pp_rank = 0 if self.pp_rank == first else 1
    self._collective_helper._init_communicator(
        self._startup_program,
        self.current_endpoint,
        pp_group_endpoints,
        pp_rank,
        ring_id,
        False,
        global_ring_id=self.global_ring_id,
        sync=False)

def _init_npu_pipeline_comm(self, startup_block):
    """Create the pipeline pair communicators for NPU, one pair at a time.

    Each pair is initialised through ``_init_pair_comm`` and immediately
    followed by an ``append_naive_sync`` call on ``startup_block``. The
    order depends on the parity of ``self.pp_rank``:

        0. even -> odd  (0->1, 2->3)
        1. even <- odd  (1->0, 3->2)
        2. odd -> even  (1->2, 3->0)   -- only when pp_degree > 2
        3. odd <- even  (2->1, 0->3)   -- only when pp_degree > 2

    Pairs missing from ``self.pp_ring_map`` in steps 2 and 3 (the
    wrap-around pairs such as 3->0 and 0->3) fall back to ring ids just
    above the largest ring id seen while scanning ``self.pipeline_pair``.

    Args:
        startup_block: Block passed to ``append_naive_sync`` after each
            pair is initialised.

    Raises:
        AssertionError: If ``self.pp_degree`` is odd, or if a pair from
            ``self.pipeline_pair`` that contains this rank was not consumed
            by the steps above.
    """
    # NOTE(wangxi): some bug with hccl, must set pp_degree be even number
    assert (self.pp_degree % 2) == 0

    rank = self.pp_rank
    degree = self.pp_degree

    # Collect the pairs this rank takes part in and the largest ring id.
    top_ring_id = -1
    pending = []
    for stage_pair in self.pipeline_pair:
        ring = self.pp_ring_map[stage_pair[0] * 1000 + stage_pair[1]]
        top_ring_id = max(top_ring_id, ring)
        logger.info("pp pair:{}, ring_id: {}".format(stage_pair, ring))

        if rank in stage_pair:
            pending.append(stage_pair)

    # for example: self.pp_rank=2, self.pp_degree=4
    next_rank = (rank + 1) % degree
    prev_rank = (rank - 1 + degree) % degree
    to_next = (rank, next_rank)  # 2->3
    from_next = (next_rank, rank)  # 3->2
    from_prev = (prev_rank, rank)  # 1->2
    to_prev = (rank, prev_rank)  # 2->1

    is_even = (rank % 2) == 0
    # The first and last stages own the wrap-around pairs, which were never
    # added to `pending`, so they must not try to remove them.
    is_inner_rank = rank not in (0, degree - 1)

    def _connect(stage_pair, ring, log_template, consume):
        # Init one pair, sync, optionally tick it off, then log it.
        self._init_pair_comm(stage_pair, ring)
        append_naive_sync(startup_block, self.startup_prog_sync_var,
                          self.global_ring_id)
        if consume:
            pending.remove(stage_pair)
        logger.info(log_template.format(stage_pair, ring))

    # 1. even send to next, odd recv from prev, 0->1, 2->3
    chosen = to_next if is_even else from_prev
    _connect(chosen, self.pp_ring_map[chosen[0] * 1000 + chosen[1]],
             "pair0(even->odd): pp pair:{}, ring_id: {}", True)

    # 2. even recv from next, odd send to prev, 1->0, 3->2
    chosen = from_next if is_even else to_prev
    _connect(chosen, self.pp_ring_map[chosen[0] * 1000 + chosen[1]],
             "pair1(even<-odd): pp pair:{}, ring_id: {}", True)

    # if pp_degree is 2, only need pair(0->1, 1->0)
    if degree > 2:
        # 3. odd send to next, even recv from prev, 1->2, 3->0
        chosen = from_prev if is_even else to_next
        ring = self.pp_ring_map.get(
            chosen[0] * 1000 + chosen[1],
            top_ring_id + 1)  # 3->0 not in pp_ring_map
        _connect(chosen, ring,
                 "pair2(odd->even): pp pair:{}, ring_id: {}", is_inner_rank)

        # 4. odd recv from next, even send to prev, 2->1, 0->3
        chosen = to_prev if is_even else from_next
        ring = self.pp_ring_map.get(
            chosen[0] * 1000 + chosen[1],
            top_ring_id + 2)  # 0->3 not in pp_ring_map
        _connect(chosen, ring,
                 "pair3(odd<-even): pp pair:{}, ring_id: {}", is_inner_rank)

    assert len(pending) == 0, (
        "Current pipeline does not support cross stage communication, "
        "please check unexpected pair {}".format(pending))

def _init_pipeline_comm(self, startup_block):
    """Set up the pipeline pair communicators for this rank.

    When ``core.is_compiled_with_npu()`` is true the work is delegated to
    ``_init_npu_pipeline_comm``. Otherwise every pair in
    ``self.pipeline_pair`` that contains ``self.pp_rank`` is initialised
    with ``_init_pair_comm`` on the ring id stored in ``self.pp_ring_map``.

    Args:
        startup_block: Block passed on to ``append_naive_sync`` (or to the
            NPU-specific initialiser).

    Raises:
        AssertionError: If ``self.pp_rank_`` and ``self.pp_rank`` differ.
    """
    # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
    assert self.pp_rank_ == self.pp_rank, \
        "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
            self.pp_rank_, self.pp_rank)

    if not core.is_compiled_with_npu():
        # GPU
        for stage_pair in self.pipeline_pair:
            ring = self.pp_ring_map[stage_pair[0] * 1000 + stage_pair[1]]
            logger.info("pp pair:{}, ring_id: {}".format(stage_pair, ring))
            if self.pp_rank in stage_pair:
                self._init_pair_comm(stage_pair, ring)
        # NOTE(review): one sync after all pairs are initialised (the NPU
        # path syncs after every pair) -- confirm this nesting level.
        append_naive_sync(startup_block, self.startup_prog_sync_var,
                          self.global_ring_id)
        return

    self._init_npu_pipeline_comm(startup_block)

def _init_comm(self):

# config sharding & dp groups
Expand Down Expand Up @@ -435,31 +548,7 @@ def _init_comm(self):

# pp ring
if self.pp_degree > 1:
# TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)

for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank in pair:
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
pp_group_endpoints,
pp_rank,
ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
self._init_pipeline_comm(startup_block)

# pure dp ring
if self.dp_degree > 1:
Expand Down