diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index f420a06cfbc743..0df0dcb6504dfd 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -15,6 +15,7 @@
 import copy
 from copy import deepcopy
 import time
+from numpy import sort
 
 from paddle.fluid import core
 from paddle.fluid import framework
@@ -728,6 +729,18 @@ def _update_process_mesh(self):
         self._update_process_mesh_between_graphs()
 
     def _prepare(self):
+        def _find_nearest_parent_nodes(sorted_parent_nodes, child_idx):
+            before_node = None
+            after_node = None
+            pos = -1
+            for pos, (parent_idx, parent_node) in enumerate(sorted_parent_nodes):
+                if parent_idx > child_idx:
+                    after_node = parent_node
+                    break
+            if pos > 0:
+                _, before_node = sorted_parent_nodes[pos - 1]
+            return before_node, after_node
+
         if self._has_prepared:
             return
         self._while_op_nodes = {}
@@ -751,20 +764,16 @@ def _prepare(self):
                         self._array_nodes[array_var_name] = []
                     self._array_nodes[array_var_name].append(node)
                     self._array_nodes[array_var_name].append(node.outputs[0])
+            # TODO: Use dict and name as the key to store the nodes,
+            # and use the id comparison to deal with the before or after position
             if node.is_var() and node.var() is not None:
                 if node.node.graph_id() != 0:
-                    for before_node in reversed(all_nodes[:idx]):
-                        if before_node.is_var() and before_node.var() is not None \
-                            and before_node.node.graph_id() == node.node.graph_id() - 1 \
-                            and before_node.var().name() == node.var().name():
-                            self._node_pairs_between_graphs.append(
-                                (before_node, node))
-                    for after_node in all_nodes[idx + 1:]:
-                        if after_node.is_var() and after_node.var() is not None \
-                            and after_node.node.graph_id() == node.node.graph_id() - 1 \
-                            and after_node.var().name() == node.var().name():
+                    parent_nodes = self._dist_context._tensor_nodes_with_same_name[node.node.graph_id() - 1].get(node.var().name(), None)
+                    if parent_nodes is not None:
+                        sorted_parent_nodes = sorted(parent_nodes, key=lambda x: x[0])
+                        for _, parent_node in sorted_parent_nodes:
                             self._node_pairs_between_graphs.append(
-                                (after_node, node))
+                                (parent_node, node))
         self._has_prepared = True
 
     def complete_forward_annotation(self, serial_main_program=None):
@@ -787,14 +796,22 @@ def complete_forward_annotation(self, serial_main_program=None):
 
             # self._dist_context.validate_dist_attr_for_program()
 
+            start_time = time.time()
             self._prepare()
+            print("completion-prepare: ", time.time() - start_time, flush=True)
 
+            start_time = time.time()
             self._update_process_mesh()
+            print("completion-mesh: ", time.time() - start_time, flush=True)
 
+            start_time = time.time()
             self._update_dims_mapping()
+            print("graph-dims: ", time.time() - start_time, flush=True)
 
+            start_time = time.time()
             # Copy the corresponding distributed attribute from graph to serial_main_program
             self._dist_context.copy_dist_attr_from_graph_to_program()
+            print("completion-copy: ", time.time() - start_time, flush=True)
         else:
             self._dist_context.initialize(with_graph=False)
 
diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py
index deab633b274794..d4c2a2b9673907 100644
--- a/python/paddle/distributed/auto_parallel/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/dist_context.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
+import time
 import copy
 from collections import defaultdict
 import paddle.fluid
@@ -24,7 +25,7 @@
 from .dist_tensor import DistributedTensor
 from .dist_op import DistributedOperator
 from .process_mesh import ProcessMesh
-from .utils import is_loss_grad_op, is_loss_op
+from .utils import is_loss_grad_op, is_loss_op, is_valid_list_index
 
 # There always exists a default context for user. And user can set it to another one.
 _g_default_distributed_context = None
@@ -437,13 +438,18 @@ def initialize(self, with_graph=True):
 
         if with_graph:
             set_flags({"FLAGS_convert_all_blocks": True})
+            start_time = time.time()
             self._serial_graph = framework.IrGraph(
                 core.Graph(self._serial_main_program.desc)
             )
+            print("context-graph-build: ", time.time() - start_time, flush=True)
             self._init_dist_attr_for_graph()
+            start_time = time.time()
             self._need_copy_dist_attr_to_graph = False
+            print("context-graph-dist: ", time.time() - start_time, flush=True)
 
         if self._need_copy_dist_attr_to_graph and with_graph:
+            # print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% here 1234", flush=True)
             self.copy_dist_attr_from_program_to_graph()
 
     def add_process_mesh(self, process_mesh):
@@ -629,29 +635,40 @@ def _init_dist_attr_for_program(self, no_default=False):
             )
 
     def _order_nodes_by_program_order(self):
-        def _contains(nodes, target_node):
-            for node in nodes:
-                if _node_id(node) == _node_id(target_node):
-                    return True
-            return False
+        # def _contains(nodes, target_node):
+        #     for node in nodes:
+        #         if _node_id(node) == _node_id(target_node):
+        #             return True
+        #     return False
 
+        start_time = time.time()
         serial_ordered_tensor_nodes = []
         serial_ordered_op_nodes = []
         all_nodes = []
+        visited = {}
         for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
             for node in graph.all_nodes():
                 all_nodes.append(node)
+        print("context-graph-dist-ordering-0: ", time.time() - start_time, flush=True)
+        start_time = time.time()
         for node in all_nodes:
            if node.is_var() and node.var() is not None:
                 serial_ordered_tensor_nodes.append(node)
+                visited[_node_id(node)] = False
            if node.is_op() and node.op() is not None:
                 serial_ordered_op_nodes.append(node)
+        print("context-graph-dist-ordering-1: ", time.time() - start_time, flush=True)
+        start_time = time.time()
         serial_ordered_tensor_nodes.sort(
             key=lambda node: node.node.original_desc_id()
         )
+        print("context-graph-dist-ordering-2: ", time.time() - start_time, flush=True)
+        start_time = time.time()
         serial_ordered_op_nodes.sort(
             key=lambda node: node.node.original_desc_id()
         )
+        print("context-graph-dist-ordering-3: ", time.time() - start_time, flush=True)
+        start_time = time.time()
         num_nodes_before = len(serial_ordered_tensor_nodes) + len(
             serial_ordered_op_nodes
         )
@@ -659,72 +676,140 @@ def _contains(nodes, target_node):
         new_serial_ordered_tensor_nodes = []
         new_serial_ordered_op_nodes = []
         new_serial_ordered_nodes = []
+        tmp_time = 0
+        # TODO: use a counter for the following sort
         for op_node in serial_ordered_op_nodes:
             tensor_nodes = []
             for tensor_node in op_node.inputs:
-                if (
-                    tensor_node.is_var()
-                    and tensor_node.var() is not None
-                    and not _contains(new_serial_ordered_nodes, tensor_node)
-                ):
+                # if (
+                #     tensor_node.is_var()
+                #     and tensor_node.var() is not None
+                #     and not _contains(new_serial_ordered_nodes, tensor_node)
+                # ):
+                #     tensor_nodes.append(tensor_node)
+                #     new_serial_ordered_tensor_nodes.append(tensor_node)
+                if (tensor_node.is_var()
+                        and tensor_node.var() is not None
+                        and not visited[_node_id(tensor_node)]):
                     tensor_nodes.append(tensor_node)
                     new_serial_ordered_tensor_nodes.append(tensor_node)
+                    visited[_node_id(tensor_node)] = True
+
+            inner_start_time = time.time()
             tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
+            tmp_time += time.time() - inner_start_time
             new_serial_ordered_nodes.extend(tensor_nodes)
             new_serial_ordered_nodes.append(op_node)
             new_serial_ordered_op_nodes.append(op_node)
             tensor_nodes = []
             for tensor_node in op_node.outputs:
-                if (
-                    tensor_node.is_var()
-                    and tensor_node.var() is not None
-                    and not _contains(new_serial_ordered_nodes, tensor_node)
-                ):
+                # if (
+                #     tensor_node.is_var()
+                #     and tensor_node.var() is not None
+                #     and not _contains(new_serial_ordered_nodes, tensor_node)
+                # ):
+                #     tensor_nodes.append(tensor_node)
+                #     new_serial_ordered_tensor_nodes.append(tensor_node)
+                if (tensor_node.is_var()
+                        and tensor_node.var() is not None
+                        and not visited[_node_id(tensor_node)]):
                     tensor_nodes.append(tensor_node)
                     new_serial_ordered_tensor_nodes.append(tensor_node)
+                    visited[_node_id(tensor_node)] = True
+            inner_start_time = time.time()
             tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
+            tmp_time += time.time() - inner_start_time
             new_serial_ordered_nodes.extend(tensor_nodes)
+        print("context-graph-dist-ordering-4: ", tmp_time, flush=True)
+        print("context-graph-dist-ordering-5: ", time.time() - start_time, flush=True)
+        start_time = time.time()
         new_serial_ordered_tensor_nodes.sort(
             key=lambda node: node.node.original_desc_id()
         )
+        print("context-graph-dist-ordering-6: ", time.time() - start_time, flush=True)
+        start_time = time.time()
         new_serial_ordered_op_nodes.sort(
             key=lambda node: node.node.original_desc_id()
         )
+        print("context-graph-dist-ordering-7: ", time.time() - start_time, flush=True)
+        start_time = time.time()
         self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
         self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
         self._serial_ordered_nodes = new_serial_ordered_nodes
         assert len(self._serial_ordered_nodes) == len(
             self._serial_ordered_tensor_nodes
         ) + len(self._serial_ordered_op_nodes)
+        # TODO: Use [graph_id][tensor_name][node] to store the tensor nodes for completion preparation
+        # graph_id -> tensor_name -> node_lists
+        self._tensor_nodes_with_same_name = defaultdict(dict)
+        for idx, node in enumerate(self._serial_ordered_nodes):
+            if node.is_var() and node.var() is not None:
+                graph_id = node.node.graph_id()
+                tensor_name = node.var().name()
+                if self._tensor_nodes_with_same_name[graph_id].get(tensor_name, None) is None:
+                    self._tensor_nodes_with_same_name[graph_id][tensor_name] = []
+                self._tensor_nodes_with_same_name[graph_id][tensor_name].append((idx, node))
+        # for graph_id, graph_nodes in self._tensor_nodes_with_same_name.items():
+        #     print("graph nodes: ", graph_id, flush=True)
+        #     for tensor_name, tensor_nodes in graph_nodes.items():
+        #         print("tensor nodes: ", tensor_name, tensor_nodes, flush=True)
+
+        print("context-graph-dist-ordering-8: ", time.time() - start_time, flush=True)
+
+        start_time = time.time()
         self._serial_orphan_tensor_nodes = []
         for tensor_node in serial_ordered_tensor_nodes:
-            if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
+            # if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
+            if not visited[_node_id(tensor_node)]:
                 self._serial_orphan_tensor_nodes.append(tensor_node)
         if len(self._serial_ordered_nodes) != num_nodes_before:
             print(
                 "WARNING: there are some orphan tensors or ops which are not used in the execution."
             )
+        print("context-graph-dist-ordering-9: ", time.time() - start_time, flush=True)
+        for node in serial_ordered_tensor_nodes:
+            print("[before ordering] t: ", _node_id(node), node.var().name(), flush=True)
+        for node in serial_ordered_op_nodes:
+            print("[before ordering] o: ", _node_id(node), node.op().type(), flush=True)
+        for node in new_serial_ordered_tensor_nodes:
+            print("[after ordering] t: ", _node_id(node), node.var().name(), flush=True)
+        for node in new_serial_ordered_op_nodes:
+            print("[after ordering] o: ", _node_id(node), node.op().type(), flush=True)
+        for node in self._serial_orphan_tensor_nodes:
+            print("[after ordering] a: ", _node_id(node), node.var().name(), flush=True)
+        for node in new_serial_ordered_nodes:
+            print("[after ordering] o: ", _node_id(node), flush=True)
 
     def _init_dist_attr_for_graph(self):
         # Convert program to graph and initialize the distributed attributes
+        start_time = time.time()
         self._order_nodes_by_program_order()
+        print("context-graph-dist-ordering: ", time.time() - start_time, flush=True)
+        start_time = time.time()
+        self._tensor_original_id_to_id = {}
+        self._op_original_id_to_id = {}
+        for tensor_id, tensor in self._dist_tensors_for_program.items():
+            original_id = tensor.serial_tensor.desc.original_id()
+            self._tensor_original_id_to_id[original_id] = tensor_id
+        for op_id, op in self._dist_ops_for_program.items():
+            original_id = op.serial_op.desc.original_id()
+            self._op_original_id_to_id[original_id] = op_id
+        print("context-graph-dist-mapping: ", time.time() - start_time, flush=True)
+        start_time = time.time()
         for node in self.serial_ordered_nodes:
             if node.is_var() and node.var() is not None:
                 dist_tensor = None
                 tensor_id = node.node.original_desc_id()
-                for (
-                    cur_tensor_id,
-                    cur_dist_tensor,
-                ) in self._dist_tensors_for_program.items():
-                    if (
-                        tensor_id == cur_tensor_id
-                        or tensor_id
-                        == cur_dist_tensor.serial_tensor.desc.original_id()
-                    ):
-                        dist_tensor = cur_dist_tensor
-                        self._node_id_to_tensor_id[
-                            _node_id(node)
-                        ] = cur_tensor_id
+                cur_dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
+                if cur_dist_tensor is not None:
+                    cur_tensor_id = tensor_id
+                else:
+                    cur_tensor_id = self._tensor_original_id_to_id[tensor_id]
+                    cur_dist_tensor = self._dist_tensors_for_program.get(cur_tensor_id, None)
+                dist_tensor = cur_dist_tensor
+                self._node_id_to_tensor_id[
+                    _node_id(node)
+                ] = cur_tensor_id
                 assert (
                     dist_tensor is not None
                 ), "Tensor must have a distributed tensor after the initialization for program."
@@ -738,16 +823,16 @@ def _init_dist_attr_for_graph(self):
             if node.is_op() and node.op() is not None:
                 dist_op = None
                 op_id = node.node.original_desc_id()
-                for (
-                    cur_op_id,
-                    cur_dist_op,
-                ) in self._dist_ops_for_program.items():
-                    if (
-                        op_id == cur_op_id
-                        or op_id == cur_dist_op.serial_op.desc.original_id()
-                    ):
-                        dist_op = cur_dist_op
-                        self._node_id_to_op_id[_node_id(node)] = cur_op_id
+                cur_dist_op = self._dist_ops_for_program.get(op_id, None)
+                if cur_dist_op is not None:
+                    cur_op_id = op_id
+                else:
+                    cur_op_id = self._op_original_id_to_id[op_id]
+                    cur_dist_op = self._dist_ops_for_program.get(cur_op_id, None)
+                dist_op = cur_dist_op
+                self._node_id_to_op_id[
+                    _node_id(node)
+                ] = cur_op_id
                 assert (
                     dist_op is not None
                 ), "Operator must have a distributed operator after the initialization for program."
@@ -756,6 +841,15 @@ def _init_dist_attr_for_graph(self):
                     dist_op.serial_op, dist_op.dist_attr
                 )
                 self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
+        print("context-graph-dist-init: ", time.time() - start_time, flush=True)
+        for node_id, dist_tensor in self._dist_tensors_for_graph.items():
+            print("graph dist tensor: ", node_id, dist_tensor.serial_tensor.desc.id(), flush=True)
+        for node_id, dist_op in self._dist_ops_for_graph.items():
+            print("graph dist op: ", node_id, dist_op.serial_op.desc.id(), flush=True)
+        for node_id, id in self._node_id_to_tensor_id.items():
+            print("graph dist tensor node_id: ", node_id, id, flush=True)
+        for node_id, id in self._node_id_to_op_id.items():
+            print("graph dist op node_id: ", node_id, id, flush=True)
 
     def clear_dist_info_for_program(self):
         self._dist_tensors_for_program.clear()
@@ -770,16 +864,13 @@ def copy_dist_attr_from_program_to_graph(self):
             if node.is_var() and node.var() is not None:
                 dist_tensor = None
                 tensor_id = node.node.original_desc_id()
-                for (
-                    cur_tensor_id,
-                    cur_dist_tensor,
-                ) in self._dist_tensors_for_program.items():
-                    if (
-                        tensor_id == cur_tensor_id
-                        or tensor_id
-                        == cur_dist_tensor.serial_tensor.desc.original_id()
-                    ):
-                        dist_tensor = cur_dist_tensor
+                cur_dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
+                if cur_dist_tensor is not None:
+                    cur_tensor_id = tensor_id
+                else:
+                    cur_tensor_id = self._tensor_original_id_to_id[tensor_id]
+                    cur_dist_tensor = self._dist_tensors_for_program.get(cur_tensor_id, None)
+                dist_tensor = cur_dist_tensor
                 assert (
                     dist_tensor is not None
                 ), "Tensor must have a distributed tensor after the initialization for program."
@@ -793,15 +884,13 @@ def copy_dist_attr_from_program_to_graph(self):
             if node.is_op() and node.op() is not None:
                 dist_op = None
                 op_id = node.node.original_desc_id()
-                for (
-                    cur_op_id,
-                    cur_dist_op,
-                ) in self._dist_ops_for_program.items():
-                    if (
-                        op_id == cur_op_id
-                        or op_id == cur_dist_op.serial_op.desc.original_id()
-                    ):
-                        dist_op = cur_dist_op
+                cur_dist_op = self._dist_ops_for_program.get(op_id, None)
+                if cur_dist_op is not None:
+                    cur_op_id = op_id
+                else:
+                    cur_op_id = self._op_original_id_to_id[op_id]
+                    cur_dist_op = self._dist_ops_for_program.get(cur_op_id, None)
+                dist_op = cur_dist_op
                 assert (
                     dist_op is not None
                 ), "Operator must have a distributed operator after the initialization for program."
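
Note (reviewer sketch, not part of the patch): the rewritten lookups in _init_dist_attr_for_graph and copy_dist_attr_from_program_to_graph replace a per-node linear scan over _dist_tensors_for_program / _dist_ops_for_program with a one-off original_id -> current_id index, so each node resolves in O(1) instead of O(N). The standalone Python below illustrates that pattern under assumed names (DistRecord, build_original_id_index, lookup are not Paddle APIs); like the patch, it assumes every desc id resolves either directly or through the index.

from dataclasses import dataclass

@dataclass
class DistRecord:
    # Stand-in for DistributedTensor / DistributedOperator in the patch (assumed shape).
    original_id: int
    payload: str

def build_original_id_index(records):
    # Built once over the program-side dict (current_id -> record): O(N).
    return {rec.original_id: cur_id for cur_id, rec in records.items()}

def lookup(records, original_index, desc_id):
    # Per-node resolution: try desc_id as a current id, else fall back to the index.
    rec = records.get(desc_id, None)
    if rec is not None:
        return desc_id, rec
    cur_id = original_index[desc_id]  # a KeyError here plays the role of the assert
    return cur_id, records.get(cur_id, None)

# Usage sketch with made-up ids.
records = {10: DistRecord(original_id=7, payload="matmul_out"),
           11: DistRecord(original_id=11, payload="relu_out")}
index = build_original_id_index(records)
assert lookup(records, index, 7) == (10, records[10])   # resolved via original id
assert lookup(records, index, 11) == (11, records[11])  # resolved directly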
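
Note (reviewer sketch, not part of the patch): the visited bookkeeping added in _order_nodes_by_program_order turns the old _contains() membership scan into an O(1) check while keeping first-seen order. A minimal standalone version of that dedup-while-ordering idea follows; it uses a set where the patch pre-fills a dict with False, and plain strings stand in for IR nodes and _node_id.

def order_with_visited(op_nodes, inputs, outputs):
    # op_nodes: iterable of op keys; inputs/outputs: op key -> list of tensor keys.
    # Each tensor is appended only the first time it is seen; membership is O(1).
    visited = set()
    ordered = []
    for op in op_nodes:
        for t in inputs.get(op, []):
            if t not in visited:
                visited.add(t)
                ordered.append(t)
        ordered.append(op)
        for t in outputs.get(op, []):
            if t not in visited:
                visited.add(t)
                ordered.append(t)
    return ordered

# Usage sketch: "y" is produced by op1 and consumed by op2 but appears only once.
print(order_with_visited(["op1", "op2"],
                         {"op1": ["x"], "op2": ["y"]},
                         {"op1": ["y"], "op2": ["z"]}))
# ['x', 'op1', 'y', 'op2', 'z']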