Merged
Changes from all commits (28 commits)
5 changes: 5 additions & 0 deletions python/paddle/distributed/auto_parallel/attribute.py
@@ -14,6 +14,7 @@

import copy
from collections import defaultdict
from paddle.fluid import core


class TensorDistributedAttribute:
@@ -77,6 +78,8 @@ def mark_as_parameter(self):
self._is_parameter = True

def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
@@ -222,6 +225,8 @@ def mark_as_parameter(self, name):
self._is_parameters[name] = True

def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
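For reference, a minimal standalone sketch of the validity rule these two hunks add: reader tensors and "read"-type ops are treated as always valid, since their shapes do not take part in the dims_mapping check. This is not the PaddlePaddle API; dist_attr_is_valid, tensor_type, and mesh_ndim are illustrative names.

# Standalone sketch of the validity rule; all names here are illustrative.
def dist_attr_is_valid(tensor_type, tensor_shape, dims_mapping, mesh_ndim):
    if tensor_type == "READER":  # reader tensors skip the shape check
        return True
    if len(tensor_shape) != len(dims_mapping):
        return False
    # each dim must map to a valid mesh axis, or -1 for "replicated"
    return all(-1 <= d < mesh_ndim for d in dims_mapping)

assert dist_attr_is_valid("READER", [], [0, 1], mesh_ndim=2)
assert dist_attr_is_valid("LOD_TENSOR", [64, 32], [0, -1], mesh_ndim=2)
assert not dist_attr_is_valid("LOD_TENSOR", [64, 32], [0], mesh_ndim=2)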
37 changes: 37 additions & 0 deletions python/paddle/distributed/auto_parallel/context.py
@@ -15,9 +15,11 @@
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
from .interface import _g_process_mesh_map

# There always exists a default context for the user, and the user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
@@ -49,6 +51,20 @@ def __init__(self):
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
# The following is hard-coded and will be removed in the future
self._data_parallel_axis = None
self._model_parallel_axis = None
self._process_mesh = _g_process_mesh_map.get(0, None)
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
Contributor: delete this dp/mp strategy in the next step

Contributor (Author): Got it, this will be deleted in the next major release in September. Auto Parallel will NOT hold a global view of MP/DP/PP.
self._model_parallel_axis = -1

def is_initialized_for_program(self):
return self._is_initialized_for_program
@@ -99,6 +115,19 @@ def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr):
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr

def set_process_mesh(self, process_mesh):
self._process_mesh = process_mesh
if self._process_mesh is not None:
if self._process_mesh.ndim == 1:
self._data_parallel_axis = 0
self._model_parallel_axis = 0
else:
self._data_parallel_axis = 0
self._model_parallel_axis = 1
else:
self._data_parallel_axis = -1
self._model_parallel_axis = -1

def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
@@ -377,3 +406,11 @@ def amend_distributed_attr_for_program(self):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1

def _get_data_parallel_info(self):
# This function is hard-coded and will be removed in the future
return self._data_parallel_axis, self._process_mesh

def _get_model_parallel_info(self):
# This function is hard-coded and will be removed in the future
return self._model_parallel_axis, self._process_mesh
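As a sanity check on the hard-coded convention above, here is a sketch only (infer_parallel_axes is not part of Paddle): a 1-D mesh uses axis 0 for both data and model parallelism, a 2-D mesh uses axis 0 for data parallelism and axis 1 for model parallelism, and a missing mesh yields -1 for both.

# Illustrative helper mirroring __init__/set_process_mesh above; not Paddle API.
def infer_parallel_axes(mesh_topology):
    if not mesh_topology:          # no process mesh registered
        return -1, -1              # (data_parallel_axis, model_parallel_axis)
    if len(mesh_topology) == 1:    # 1-D mesh: the same axis serves both roles
        return 0, 0
    return 0, 1                    # 2-D mesh: dp on axis 0, mp on axis 1

assert infer_parallel_axes([]) == (-1, -1)
assert infer_parallel_axes([8]) == (0, 0)
assert infer_parallel_axes([2, 4]) == (0, 1)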
14 changes: 14 additions & 0 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -184,6 +184,13 @@ def parent(self):
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]

@property
def ndim(self):
r"""
        Get the number of dimensions of the ProcessMesh.
"""
return len(self._topology)

def set_placement(self, order):
"""
Set the map from logical processes to physical ones using the
@@ -229,6 +236,13 @@ def set_placement(self, order):
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]

    def _reset_global_process_mesh_map(self):
        """
        Remove all process meshes from _g_process_mesh_map, leaving it empty.
        """
        global _g_process_mesh_map
        _g_process_mesh_map = dict()

def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
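A tiny usage sketch of the new ndim property; ProcessMeshSketch is a stand-in class, not the real ProcessMesh, which additionally validates the topology and registers itself in _g_process_mesh_map.

# Stand-in class illustrating the ndim property added above.
class ProcessMeshSketch:
    def __init__(self, topology):
        self._topology = list(topology)

    @property
    def ndim(self):
        # number of mesh dimensions == length of the topology list
        return len(self._topology)

assert ProcessMeshSketch([8]).ndim == 1
assert ProcessMeshSketch([2, 4]).ndim == 2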
2 changes: 2 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/common.py
@@ -33,6 +33,8 @@ def get_impls(self):
class DistributedOperatorImpl:
def __init__(self):
self._name = None
self._forward_implemented = False
self._backward_implemented = False

def forward(self, dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
112 changes: 112 additions & 0 deletions python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -22,6 +22,12 @@
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from ..process import new_process_group
from ..utils import _get_comm_group


class DistributedEmbedding(DistributedOperator):
@@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
self._forward_implemented = True
self._backward_implemented = False

def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
@@ -92,6 +100,110 @@ def update_dims_mapping(self, op_dist_attr):

return changed

def forward(self, serial_op):
def static_handle(dst_block,
src_op,
op_dist_attr,
input_name_mapping,
output_name_mapping,
rank_id=0):
            assert len(
                input_name_mapping
            ) == 2, "row_parallel_embedding takes 2 input variables but got {}".format(
                input_name_mapping)
            assert len(
                output_name_mapping
            ) == 1, "row_parallel_embedding takes 1 output variable but got {}".format(
                output_name_mapping)
            assert len(
                input_name_mapping['Ids']
            ) == 1, "row_parallel_embedding input Ids takes 1 variable but got {}".format(
                input_name_mapping['Ids'])
            assert len(
                input_name_mapping['W']
            ) == 1, "row_parallel_embedding input W takes 1 variable but got {}".format(
                input_name_mapping['W'])
            assert len(
                output_name_mapping['Out']
            ) == 1, "row_parallel_embedding output Out takes 1 variable but got {}".format(
                output_name_mapping['Out'])

Ids_var = dst_block.var(input_name_mapping['Ids'][0])
Weight_var = dst_block.var(input_name_mapping['W'][0])
Out_var = dst_block.var(output_name_mapping['Out'][0])

            # get dist attribute info
embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
Weight_var.name)[0]
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group

            # calculate embedding offset
            # TODO generalize here, using cartesian product to allow any dimensional mesh shape
            mesh_shape = len(process_mesh_shape)
            assert mesh_shape <= 2, "row_parallel_embedding only supports 1 or 2 dimensional process mesh, but got {}".format(
                process_mesh_shape)
num_partition = process_mesh_shape[embedding_row_dim_mapping]
# TODO generalize here, support any mesh group
if mesh_shape == 1:
relative_idx = process_mesh_group.index(rank_id)
else:
relative_idx = rank_id % num_partition

per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size

            # TODO calculate ring id
model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
)._get_model_parallel_info()
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology,
model_parallel_axis, rank_id)
group = new_process_group(group_ranks)

# append op
check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
'c_embedding')

intermediate_var_0 = dst_block.create_var(
name=unique_name.generate_with_ignorable_key(".".join(
["c_embedding", 'tmp'])),
dtype=Weight_var.dtype,
shape=Out_var.shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=Out_var.stop_gradient)

check_variable_and_dtype(
Out_var, 'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'c_allreduce_sum')

dst_block.append_op(
type='c_embedding',
inputs={'Ids': [Ids_var],
'W': [Weight_var]},
outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx})

# use_model_parallel
dst_block.append_op(
type='c_allreduce_sum',
inputs={'X': [intermediate_var_0]},
outputs={'Out': [Out_var]},
attrs={
'ring_id': group.id,
'use_calc_stream': True,
'use_model_parallel': True,
})

        if in_dygraph_mode():
            raise NotImplementedError(
                "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
                    "lookup_table_v2", 0))
else:
return static_handle


register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel"))
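The lowering above splits the embedding table by rows, shifts the ids by start_index on each rank, and sums the partial lookups with c_allreduce_sum. Below is a minimal NumPy sketch of that pattern; row_parallel_embedding and the two-shard setup are illustrative, not Paddle code.

import numpy as np

def row_parallel_embedding(ids, shards):
    # each rank holds a contiguous row shard of the full table
    per_part_size = shards[0].shape[0]
    partials = []
    for rank, shard in enumerate(shards):
        start = rank * per_part_size  # plays the role of start_index above
        local = np.zeros((len(ids), shard.shape[1]))
        for i, idx in enumerate(ids):
            if start <= idx < start + per_part_size:
                local[i] = shard[idx - start]
        partials.append(local)
    # summing the partial results stands in for c_allreduce_sum
    return sum(partials)

table = np.arange(12, dtype=float).reshape(6, 2)
out = row_parallel_embedding([0, 4, 5], [table[:3], table[3:]])
assert np.allclose(out, table[[0, 4, 5]])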