Commit a622b70
[Auto Parallel] Logical Partition & Dist Op (#35117)
* support shard reader
* support shard reader
* add parallel mode
* update process mesh
* add method to compute comm_group
* implement dist_embedding forward func
* implement dist matmul forward func
* implement dist reshape forward func
* add transpiler framework
* add transpiler forward
* implement transpiler forward
* implement transpiler backward & update
* add process
* add unitest
* chmod
* chmod
* chmod
* update unitest
* add unitest for gpt
* remove unused print
* rename transpiler --> partitioner
* rename transpiler --> partitioner
* chmod
* chmod
* bug fixed
* remove amp function
* update case for dp mode
* update case for dp mode
1 parent 280d742 commit a622b70

File tree

14 files changed (+3515, -3 lines)


python/paddle/distributed/auto_parallel/attribute.py

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import copy
 from collections import defaultdict
+from paddle.fluid import core
 
 
 class TensorDistributedAttribute:
@@ -77,6 +78,8 @@ def mark_as_parameter(self):
         self._is_parameter = True
 
     def is_valid(self):
+        if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
+            return True
         tensor_shape = self.get_owner_tensor().desc.shape()
         if len(tensor_shape) != len(self.get_dims_mapping()):
             return False
@@ -222,6 +225,8 @@ def mark_as_parameter(self, name):
         self._is_parameters[name] = True
 
     def is_valid(self):
+        if "read" in self.get_owner_op().type:
+            return True
         for name in self.get_owner_op().desc.input_arg_names():
             dims_mapping = self.get_input_dims_mapping(name)
             shape = self.get_input_shape(name)
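
The validity checks above rest on the dims_mapping convention used throughout auto_parallel: each tensor dimension carries either -1 (replicated) or the index of the process-mesh axis it is sharded along, so the mapping must have one entry per tensor dimension. A minimal standalone sketch of that rule, using a hypothetical helper and an assumed 2 x 4 mesh rather than the attribute classes themselves:

def dims_mapping_is_valid(tensor_shape, dims_mapping, mesh_topology):
    # one entry per tensor dimension: -1 means "not sharded",
    # otherwise the entry must be a valid mesh-axis index
    if len(tensor_shape) != len(dims_mapping):
        return False
    return all(m == -1 or 0 <= m < len(mesh_topology) for m in dims_mapping)

# A [batch, hidden] tensor on a 2 x 4 mesh, sharded on batch along mesh axis 0:
print(dims_mapping_is_valid([64, 1024], [0, -1], [2, 4]))   # True
print(dims_mapping_is_valid([64, 1024], [0], [2, 4]))       # False: rank mismatch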

python/paddle/distributed/auto_parallel/context.py

Lines changed: 37 additions & 0 deletions
@@ -15,9 +15,11 @@
 import copy
 from collections import defaultdict
 from paddle.fluid import framework
+from paddle.fluid import core
 from .attribute import TensorDistributedAttribute
 from .attribute import OperatorDistributedAttribute
 from .utils import append_distributed_attr_suffix
+from .interface import _g_process_mesh_map
 
 # There always exists a default context for user. And user can set it to another one.
 DEFAULT_DISTRIBUTED_CONTEXT = None
@@ -49,6 +51,20 @@ def __init__(self):
         self._op_distributed_attr_map_for_program = {}
         self._tensor_distributed_attr_map_for_graph = {}
         self._op_distributed_attr_map_for_graph = {}
+        # The following is a hard code and will be removed in the future
+        self._data_parallel_axis = None
+        self._model_parallel_axis = None
+        self._process_mesh = _g_process_mesh_map.get(0, None)
+        if self._process_mesh is not None:
+            if self._process_mesh.ndim == 1:
+                self._data_parallel_axis = 0
+                self._model_parallel_axis = 0
+            else:
+                self._data_parallel_axis = 0
+                self._model_parallel_axis = 1
+        else:
+            self._data_parallel_axis = -1
+            self._model_parallel_axis = -1
 
     def is_initialized_for_program(self):
         return self._is_initialized_for_program
@@ -99,6 +115,19 @@ def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr):
         op_node_id = op_node.id()
         self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
 
+    def set_process_mesh(self, process_mesh):
+        self._process_mesh = process_mesh
+        if self._process_mesh is not None:
+            if self._process_mesh.ndim == 1:
+                self._data_parallel_axis = 0
+                self._model_parallel_axis = 0
+            else:
+                self._data_parallel_axis = 0
+                self._model_parallel_axis = 1
+        else:
+            self._data_parallel_axis = -1
+            self._model_parallel_axis = -1
+
     def initialize_distributed_attr_for_program(self, program):
         if self._is_initialized_for_program:
             return
@@ -377,3 +406,11 @@ def amend_distributed_attr_for_program(self):
                     if dims_mapping[i] != -1 and process_mesh_shape[
                             dims_mapping[i]] > tensor_shape[i]:
                         dims_mapping[i] = -1
+
+    def _get_data_parallel_info(self):
+        # This function is a hard code, and will be obsoleted in the future
+        return self._data_parallel_axis, self._process_mesh
+
+    def _get_model_parallel_info(self):
+        # This function is a hard code, and will be obsoleted in the future
+        return self._model_parallel_axis, self._process_mesh
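
The hard-coded convention above (data-parallel axis 0 and model-parallel axis 1 for a 2-D mesh, both 0 for a 1-D mesh) is what the dist ops later rely on when they call _get_comm_group to find the ranks they must communicate with. An illustrative reimplementation of that lookup with numpy and hypothetical names, not the actual helper in ..utils:

import numpy as np

def comm_group_along_axis(process_group, topology, axis, rank):
    # ranks that differ from `rank` only along `axis` of the process mesh
    mesh = np.array(process_group).reshape(topology)
    coord = list(zip(*np.where(mesh == rank)))[0]
    index = tuple(slice(None) if i == axis else c for i, c in enumerate(coord))
    return mesh[index].flatten().tolist()

# A 2 x 4 mesh under the convention above: axis 0 = data parallel, axis 1 = model parallel.
topology, process_group = [2, 4], list(range(8))
print(comm_group_along_axis(process_group, topology, axis=1, rank=5))  # [4, 5, 6, 7]
print(comm_group_along_axis(process_group, topology, axis=0, rank=5))  # [1, 5]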

python/paddle/distributed/auto_parallel/interface.py

Lines changed: 14 additions & 0 deletions
@@ -184,6 +184,13 @@ def parent(self):
                 "parent with id %d does not exist." % self._parent_id)
         return _g_process_mesh_map[self._parent_id]
 
+    @property
+    def ndim(self):
+        r"""
+        Get the number of dimension of ProcessMesh.
+        """
+        return len(self._topology)
+
     def set_placement(self, order):
         """
         Set the map from logical processes to physical ones using the
@@ -229,6 +236,13 @@ def set_placement(self, order):
         for idx, l_id in enumerate(logical_order):
             _user_defined_physical_map[l_id] = order[idx]
 
+    def _reset_global_process_mesh_map(self):
+        """
+        Remove all process mesh in _g_process_mesh_map, make it empty.
+        """
+
+        _g_process_mesh_map = dict()
+
     def __eq__(self, other):
         assert other and isinstance(other, ProcessMesh)
         if self.topology != other.topology or self.process_group != other.process_group:

python/paddle/distributed/auto_parallel/operators/common.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ def get_impls(self):
 class DistributedOperatorImpl:
     def __init__(self):
         self._name = None
+        self._forward_implemented = False
+        self._backward_implemented = False
 
     def forward(self, dist_ctx, *args, **kwargs):
         raise NotImplementedError("Please Implement this method in Subclass.")
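
These two flags let a caller ask an implementation whether it ships a specialized forward/backward handler or whether the serial op should simply be kept as-is. A hedged sketch of such a dispatch with hypothetical names (dispatch_forward, fallback); the partitioner code added in this PR is not shown here, so this is only an illustration of how the flags can be consumed:

def dispatch_forward(impl, serial_op, fallback):
    # `impl` is assumed to be a DistributedOperatorImpl instance
    if getattr(impl, "_forward_implemented", False):
        # the dist op provides a dedicated handler (e.g. the row-parallel embedding below)
        return impl.forward(serial_op)
    # otherwise keep the serial op unchanged and let the default path handle it
    return fallback(serial_op)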

python/paddle/distributed/auto_parallel/operators/dist_embedding.py

Lines changed: 112 additions & 0 deletions
@@ -22,6 +22,12 @@
 from ..utils import compute_compatible_dim_mapping
 from ..utils import compute_compatible_dims_mapping
 from ..utils import compute_compatible_and_update_dim_mapping
+from paddle.fluid import core, unique_name
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.framework import Program, Parameter, Variable, program_guard
+from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
+from ..process import new_process_group
+from ..utils import _get_comm_group
 
 
 class DistributedEmbedding(DistributedOperator):
@@ -39,6 +45,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
     def __init__(self, name):
         super(DistributedEmbeddingImpl, self).__init__()
         self._name = name
+        self._forward_implemented = True
+        self._backward_implemented = False
 
     def is_process_mesh_compatible(self, op_dist_attr):
         """ No restriction for now. """
@@ -92,6 +100,110 @@ def update_dims_mapping(self, op_dist_attr):
 
         return changed
 
+    def forward(self, serial_op):
+        def static_handle(dst_block,
+                          src_op,
+                          op_dist_attr,
+                          input_name_mapping,
+                          output_name_mapping,
+                          rank_id=0):
+            assert len(
+                input_name_mapping
+            ) == 2, "row_parallel_embedding takes 2 input variables but got {}".format(
+                input_name_mapping)
+            assert len(
+                output_name_mapping
+            ) == 1, "row_parallel_embedding takes 1 output variable but got {}".format(
+                output_name_mapping)
+            assert len(
+                input_name_mapping['Ids']
+            ) == 1, "row_parallel_embedding input Ids takes 1 variable but got {}".format(
+                input_name_mapping['Ids'])
+            assert len(
+                input_name_mapping['W']
+            ) == 1, "row_parallel_embedding input W takes 1 variable but got {}".format(
+                input_name_mapping['W'])
+            assert len(
+                output_name_mapping['Out']
+            ) == 1, "row_parallel_embedding output Out takes 1 variable but got {}".format(
+                output_name_mapping['Out'])
+
+            Ids_var = dst_block.var(input_name_mapping['Ids'][0])
+            Weight_var = dst_block.var(input_name_mapping['W'][0])
+            Out_var = dst_block.var(output_name_mapping['Out'][0])
+
+            # get dist attribute info
+            embedding_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
+                Weight_var.name)[0]
+            process_mesh_shape = op_dist_attr.get_process_mesh().topology
+            process_mesh_group = op_dist_attr.get_process_mesh().process_group
+
+            # calculate embedding offset
+            # TODO generalize here, using cartesian product to allow any dimensional mesh shape
+            mesh_shape = len(process_mesh_shape)
+            assert mesh_shape <= 2, "row_parallel_embedding only supports a 1- or 2-dimensional process mesh, but got {}".format(
+                process_mesh_shape)
+            num_partition = process_mesh_shape[embedding_row_dim_mapping]
+            # TODO generalize here, support any mesh group
+            if mesh_shape == 1:
+                relative_idx = process_mesh_group.index(rank_id)
+            else:
+                relative_idx = rank_id % num_partition
+
+            per_part_size = Weight_var.shape[0]
+            relative_idx = relative_idx * per_part_size
+
+            # TODO calculate ring id
+            model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
+            )._get_model_parallel_info()
+            group_ranks = _get_comm_group(process_mesh.process_group,
+                                          process_mesh.topology,
+                                          model_parallel_axis, rank_id)
+            group = new_process_group(group_ranks)
+
+            # append op
+            check_variable_and_dtype(Ids_var, 'input', ['int32', 'int64'],
+                                     'c_embedding')
+
+            intermediate_var_0 = dst_block.create_var(
+                name=unique_name.generate_with_ignorable_key(".".join(
+                    ["c_embedding", 'tmp'])),
+                dtype=Weight_var.dtype,
+                shape=Out_var.shape,
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                persistable=False,
+                stop_gradient=Out_var.stop_gradient)
+
+            check_variable_and_dtype(
+                Out_var, 'tensor',
+                ['float16', 'float32', 'float64', 'int32', 'int64'],
+                'c_allreduce_sum')
+
+            dst_block.append_op(
+                type='c_embedding',
+                inputs={'Ids': [Ids_var],
+                        'W': [Weight_var]},
+                outputs={'Out': [intermediate_var_0]},
+                attrs={"start_index": relative_idx})
+
+            # allreduce the partial lookup results across the model-parallel group
+            dst_block.append_op(
+                type='c_allreduce_sum',
+                inputs={'X': [intermediate_var_0]},
+                outputs={'Out': [Out_var]},
+                attrs={
+                    'ring_id': group.id,
+                    'use_calc_stream': True,
+                    'use_model_parallel': True,
+                })
+
+        if in_dygraph_mode():
+            raise NotImplementedError(
+                "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
+                    "embedding", 0))
+        else:
+            return static_handle
+
 
 register_distributed_operator_impl("lookup_table_v2",
                                    DistributedEmbeddingImpl("row_parallel"))
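
Numerically, the two appended ops implement a row-parallel lookup: each rank owns a contiguous slab of embedding rows, shifts its ids by the partition offset (the start_index attribute), looks up only the ids that fall inside its slab (out-of-range ids contribute zero rows), and the allreduce_sum over the model-parallel group reassembles the full result. A numpy-only sketch of that arithmetic with illustrative names, not Paddle APIs:

import numpy as np

vocab, hidden, num_parts = 8, 4, 2
table = np.arange(vocab * hidden, dtype=np.float64).reshape(vocab, hidden)
shards = np.split(table, num_parts, axis=0)     # each rank holds vocab / num_parts rows
ids = np.array([0, 3, 5, 7])

partials = []
for rank, shard in enumerate(shards):
    start_index = rank * shard.shape[0]         # the "start_index" attr of c_embedding
    local_ids = ids - start_index
    in_range = (local_ids >= 0) & (local_ids < shard.shape[0])
    out = np.zeros((len(ids), hidden))
    out[in_range] = shard[local_ids[in_range]]  # out-of-range ids stay all-zero
    partials.append(out)

allreduced = np.sum(partials, axis=0)           # what c_allreduce_sum produces
assert np.array_equal(allreduced, table[ids])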
