Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions python/paddle/distributed/auto_parallel/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import copy
from copy import deepcopy
import time
from numpy import sort

from paddle.fluid import core
from paddle.fluid import framework
Expand Down Expand Up @@ -728,6 +729,18 @@ def _update_process_mesh(self):
self._update_process_mesh_between_graphs()

def _prepare(self):
def _find_nearest_parent_nodes(sorted_parent_nodes, child_idx):
before_node = None
after_node = None
pos = -1
for pos, (parent_idx, parent_node) in enumerate(sorted_parent_nodes):
if parent_idx > child_idx:
after_node = parent_node
break
if pos > 0:
_, before_node = sorted_parent_nodes[pos - 1]
return before_node, after_node

if self._has_prepared:
return
self._while_op_nodes = {}
Expand All @@ -751,20 +764,16 @@ def _prepare(self):
self._array_nodes[array_var_name] = []
self._array_nodes[array_var_name].append(node)
self._array_nodes[array_var_name].append(node.outputs[0])
# TODO: Use dict and name as the key to store the nodes,
# and use the id comparsion to deal with the before or after position
if node.is_var() and node.var() is not None:
if node.node.graph_id() != 0:
for before_node in reversed(all_nodes[:idx]):
if before_node.is_var() and before_node.var() is not None \
and before_node.node.graph_id() == node.node.graph_id() - 1 \
and before_node.var().name() == node.var().name():
self._node_pairs_between_graphs.append(
(before_node, node))
for after_node in all_nodes[idx + 1:]:
if after_node.is_var() and after_node.var() is not None \
and after_node.node.graph_id() == node.node.graph_id() - 1 \
and after_node.var().name() == node.var().name():
parent_nodes = self._dist_context._tensor_nodes_with_same_name[node.node.graph_id() - 1].get(node.var().name(), None)
if parent_nodes is not None:
sorted_parent_nodes = sorted(parent_nodes, key=lambda x: x[0])
for _, parent_node in sorted_parent_nodes:
self._node_pairs_between_graphs.append(
(after_node, node))
(parent_node, node))
self._has_prepared = True

def complete_forward_annotation(self, serial_main_program=None):
Expand All @@ -787,14 +796,22 @@ def complete_forward_annotation(self, serial_main_program=None):

# self._dist_context.validate_dist_attr_for_program()

start_time = time.time()
self._prepare()
print("completion-prepare: ", time.time() - start_time, flush=True)

start_time = time.time()
self._update_process_mesh()
print("completion-mesh: ", time.time() - start_time, flush=True)

start_time = time.time()
self._update_dims_mapping()
print("graph-dims: ", time.time() - start_time, flush=True)

start_time = time.time()
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
print("completion-copy: ", time.time() - start_time, flush=True)
else:
self._dist_context.initialize(with_graph=False)

Expand Down
Loading