Skip to content

Commit 2fa77ad

Browse files
authored
[DistDialect] Add pir nd_mesh reshard function (#64223)
* add nd_mesh_reshard_function for pir, partial_to_replicate part * add s_to_r part in nd_mesh_reshard_func * add infer local for all_gather op * formalize the unit test and remove print code * revert s_to_r unit test * add cross_mesh unit test * add pir_reshard to CI * add pir_reshard to CI * fix code style
1 parent 06ac6ab commit 2fa77ad

10 files changed

Lines changed: 907 additions & 1 deletion

File tree

python/paddle/distributed/auto_parallel/static/reshard_funcs/base_reshard_func.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import paddle
16+
1517
# all registered reshard functions
1618
_g_reshard_func_list = []
1719

@@ -64,3 +66,23 @@ def is_replicated(dist_attr):
6466
):
6567
return True
6668
return False
69+
70+
71+
def copy_dist_attr_with_new_member(
    src_dist_attr,
    new_process_mesh=None,
    new_dims_mapping=None,
    new_partial_status=None,
):
    """Create a tensor dist attribute based on ``src_dist_attr``.

    Every ``new_*`` argument that is left as ``None`` falls back to the
    corresponding member of ``src_dist_attr``; any other value replaces it.

    Args:
        src_dist_attr: the dist attribute supplying the default members
            (``process_mesh``, ``dims_mapping``, ``partial_status``).
        new_process_mesh: replacement process mesh, or ``None`` to keep.
        new_dims_mapping: replacement dims mapping, or ``None`` to keep.
        new_partial_status: replacement partial status, or ``None`` to keep.

    Returns:
        The object produced by
        ``paddle.base.libpaddle.pir.create_tensor_dist_attribute``.
    """
    mesh = (
        src_dist_attr.process_mesh
        if new_process_mesh is None
        else new_process_mesh
    )
    mapping = (
        src_dist_attr.dims_mapping
        if new_dims_mapping is None
        else new_dims_mapping
    )
    partial = (
        src_dist_attr.partial_status
        if new_partial_status is None
        else new_partial_status
    )

    create_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute
    return create_attr(mesh, mapping, partial)
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import copy
16+
17+
import paddle
18+
import paddle.distributed as dist
19+
20+
from .base_reshard_func import (
21+
ReshardFunction,
22+
copy_dist_attr_with_new_member,
23+
is_partial,
24+
)
25+
from .p_to_r_reshard_func import PToRReshardFunction
26+
from .r_to_s_reshard_func import RToSReshardFunction
27+
from .s_to_r_reshard_func import SToRReshardFunction
28+
from .same_status_reshard_func import SameStatusReshardFunction
29+
30+
31+
def find_first_diff_shard_axis(src_dist_attr, dst_dist_attr):
    """Return the highest tensor axis whose dims_mapping entry differs.

    The axes are scanned from the last one down to axis 0, so "first" means
    the first difference met when walking from high to low.

    Args:
        src_dist_attr: object exposing a ``dims_mapping`` sequence.
        dst_dist_attr: object exposing a ``dims_mapping`` sequence; it is
            indexed with the axes of ``src_dist_attr.dims_mapping``.

    Returns:
        int: the axis index, or ``-1`` when no entry differs.
    """
    src_mapping = src_dist_attr.dims_mapping
    dst_mapping = dst_dist_attr.dims_mapping

    axis = len(src_mapping) - 1
    while axis >= 0:
        if src_mapping[axis] != dst_mapping[axis]:
            return axis
        axis -= 1
    return -1
39+
40+
41+
def get_1D_sub_process_mesh(process_mesh, mesh_dim):
    """
    Get the 1-D sub process mesh on specific mesh_dim which:
      1) is where the reshard should be performed.
      2) contains the current process.

    Args:
      process_mesh (ProcessMesh): the global process mesh.
      mesh_dim (int): the mesh dimension where the dist_tensor is
        sharded or partial.

    Returns:
      A ``dist.ProcessMesh`` built from the process ids lying along
      ``mesh_dim`` at the current rank's coordinates, named with the
      matching entry of ``process_mesh.dim_names``.

    e.g.
      1) process_mesh = [[0, 1, 2], [3, 4, 5]], mesh_dim = 0:
         process rank id      returned sub mesh
         0 or 3               [0, 3]
         1 or 4               [1, 4]
         2 or 5               [2, 5]
      2) process_mesh = [[0, 1, 2], [3, 4, 5]], mesh_dim = 1:
         process rank id      returned sub mesh
         0 or 1 or 2          [0, 1, 2]
         3 or 4 or 5          [3, 4, 5]
    """
    import numpy as np

    mesh_shape = process_mesh.shape
    dim_names = process_mesh.dim_names
    process_ids = np.array(process_mesh.process_ids).reshape(mesh_shape)

    # Locate the current rank inside the mesh: one index array per mesh dim.
    rank_id = dist.get_rank()
    coord = list(np.where(process_ids == rank_id))
    # Keep the rank's position on every other dim, but span the whole
    # extent of ``mesh_dim`` so the selection is the 1-D line through it.
    coord[mesh_dim] = range(mesh_shape[mesh_dim])
    sub_process_ids = process_ids[tuple(coord)].flatten()
    sub_mesh_name = dim_names[mesh_dim]

    return dist.ProcessMesh(sub_process_ids, [sub_mesh_name])
77+
78+
79+
class NdMeshReshardFunction(ReshardFunction):
    """Reshard on a single N-d process mesh by chaining 1-D reshard functions."""

    def is_suitable(self, src_dist_attr, dst_dist_attr):
        """Return True when src and dst share one N-d (ndim > 1) mesh but differ.

        Args:
            src_dist_attr: dist attribute of the input value.
            dst_dist_attr: requested dist attribute.
        """
        in_mesh = src_dist_attr.process_mesh
        out_mesh = dst_dist_attr.process_mesh

        if in_mesh != out_mesh:
            return False
        if out_mesh.ndim <= 1:
            return False
        # check dims_mapping and partial_status
        if src_dist_attr == dst_dist_attr:
            return False

        return True

    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
        """
        Reshard on N-d mesh:
        1. Find the tensor dimensions where the dims_mapping values
           differ between src_dist_attr and dst_dist_attr.
        2. From higher to lower, convert the non-replicated dimensions
           in step1 to replicated using corresponding 1-D mesh functions.
        3. Convert the replicated dimensions in step2 to the status in
           dst_dist_attr with corresponding 1-D mesh functions.

        Args:
            src_dist_attr: dist attribute of ``src_value``.
            dst_dist_attr: requested dist attribute.
            src_value: the value to reshard.
            dst_type: the type set on the returned value.

        Returns:
            The resharded value, with its type set to ``dst_type``.

        Raises:
            NotImplementedError: when a replicated-to-partial or a
                partial-to-shard conversion would be required.
        """
        create_dist_attr = (
            paddle.base.libpaddle.pir.create_tensor_dist_attribute
        )

        # Step1. find first dimension with different shard status in
        # src_dist_attr and dst_dist_attr.
        first_diff_axis = find_first_diff_shard_axis(
            src_dist_attr, dst_dist_attr
        )
        ori_dst_dist_attr = copy_dist_attr_with_new_member(dst_dist_attr)
        out_value = src_value  # intermediate result
        src_type = src_value.type()
        tensor_ndim = len(src_value.shape)
        process_mesh = dst_dist_attr.process_mesh

        # Step2. Convert the non-replicated dimensions to replicated.
        # Step2.1. convert partial status to replicated
        # ``real_out_dist_attr`` tracks the dist attr of ``out_value`` after
        # each 1-D conversion.
        real_out_dist_attr = copy_dist_attr_with_new_member(src_dist_attr)
        if is_partial(src_dist_attr):
            in_partial_status = copy.deepcopy(src_dist_attr.partial_status)
            out_partial_status = dst_dist_attr.partial_status  # read-only
            # convert each partial dim to replicated with corresponding
            # 1-D mesh function
            for partial_dim, partial_type in in_partial_status.items():
                # ``partial_dim`` is a mesh dimension (it is handed to
                # get_1D_sub_process_mesh below), while dims_mapping is
                # indexed by tensor axis and holds mesh dimensions as values.
                # Skip when dst stays partial on this mesh dim or shards some
                # tensor axis on it; the latter case is handled in Step3.2.
                if (
                    partial_dim in out_partial_status
                    or partial_dim in ori_dst_dist_attr.dims_mapping
                ):
                    continue

                # get the partial status after converting
                real_out_partial_status = copy.deepcopy(
                    real_out_dist_attr.partial_status
                )
                real_out_partial_status.pop(partial_dim)
                real_out_dist_attr = copy_dist_attr_with_new_member(
                    real_out_dist_attr,
                    new_partial_status=real_out_partial_status,
                )

                # get the process_mesh on specific axis
                sub_mesh = get_1D_sub_process_mesh(process_mesh, partial_dim)

                # calculate corresponding 1-D dist_attr of src_dist_attr
                in_one_dim_partial_status = {0: partial_type}
                in_one_dim_dist_attr = create_dist_attr(
                    sub_mesh,
                    [-1] * tensor_ndim,
                    in_one_dim_partial_status,
                )

                # calculate corresponding 1-D dist_attr of dst_dist_attr
                out_one_dim_dist_attr = create_dist_attr(
                    sub_mesh,
                    [-1] * tensor_ndim,
                    {},
                )

                one_dim_func = PToRReshardFunction()
                out_value = one_dim_func.reshard(
                    in_one_dim_dist_attr,
                    out_one_dim_dist_attr,
                    out_value,
                    src_type,
                )

                out_value.update_dist_attr(real_out_dist_attr)

        # Step2.2 convert shard status to replicated
        for i in range(first_diff_axis, -1, -1):
            in_mesh_axis = real_out_dist_attr.dims_mapping[i]
            if in_mesh_axis == -1:
                continue

            # calculate the dist_attr after converting
            real_out_dims_mapping = copy.deepcopy(
                real_out_dist_attr.dims_mapping
            )
            real_out_dims_mapping[i] = -1
            real_out_dist_attr = copy_dist_attr_with_new_member(
                real_out_dist_attr, new_dims_mapping=real_out_dims_mapping
            )

            # get the process_mesh on specific axis
            sub_mesh = get_1D_sub_process_mesh(process_mesh, in_mesh_axis)

            # calculate corresponding 1-D dist_attr of src_dist_attr
            in_one_dim_dims_mapping = [-1] * tensor_ndim
            in_one_dim_dims_mapping[i] = 0
            in_one_dim_dist_attr = create_dist_attr(
                sub_mesh, in_one_dim_dims_mapping, {}
            )

            # calculate corresponding 1-D dist_attr of dst_dist_attr
            out_one_dim_dims_mapping = [-1] * tensor_ndim
            out_one_dim_dist_attr = create_dist_attr(
                sub_mesh, out_one_dim_dims_mapping, {}
            )

            one_dim_func = SToRReshardFunction()
            out_value = one_dim_func.reshard(
                in_one_dim_dist_attr, out_one_dim_dist_attr, out_value, src_type
            )

            out_value.update_dist_attr(real_out_dist_attr)

        # Step3. Convert the replicated status to the status in dst_dist_attr
        # Step3.1 convert replicated to partial
        if is_partial(ori_dst_dist_attr):
            # ``dist_attr`` is invoked as a method, matching Step3.2 below.
            in_partial_status = out_value.dist_attr().partial_status
            out_partial_status = ori_dst_dist_attr.partial_status
            for partial_dim in out_partial_status:
                if partial_dim in in_partial_status:
                    continue

                raise NotImplementedError(
                    "RToPReshardFunction is not implemented"
                )

        # Step3.2 convert replicated/partial to shard
        for i in range(first_diff_axis, -1, -1):
            out_mesh_axis = ori_dst_dist_attr.dims_mapping[i]
            if out_mesh_axis == -1:
                continue
            in_partial_status = out_value.dist_attr().partial_status
            need_p2s = out_mesh_axis in in_partial_status
            dims_mapping = copy.deepcopy(real_out_dist_attr.dims_mapping)
            dims_mapping[i] = out_mesh_axis
            # None keeps the current partial status in
            # copy_dist_attr_with_new_member.
            partial_status = None
            if out_mesh_axis in real_out_dist_attr.partial_status:
                partial_status = copy.deepcopy(
                    real_out_dist_attr.partial_status
                )
                partial_status.pop(out_mesh_axis)

            real_out_dist_attr = copy_dist_attr_with_new_member(
                real_out_dist_attr,
                new_dims_mapping=dims_mapping,
                new_partial_status=partial_status,
            )

            # get the process_mesh on specific axis
            sub_mesh = get_1D_sub_process_mesh(process_mesh, out_mesh_axis)

            # calculate the corresponding 1-D input dist attr
            in_one_dim_dims_mapping = [-1] * tensor_ndim
            in_one_dim_dist_attr = create_dist_attr(
                sub_mesh, in_one_dim_dims_mapping, {}
            )

            # calculate the corresponding 1-D output dist attr
            out_one_dim_dims_mapping = [-1] * tensor_ndim
            out_one_dim_dims_mapping[i] = 0
            out_one_dim_dist_attr = create_dist_attr(
                sub_mesh, out_one_dim_dims_mapping, {}
            )

            if need_p2s:
                raise NotImplementedError(
                    "PToSReshardFunction is not implemented"
                )
            else:
                one_dim_func = RToSReshardFunction()
                out_value = one_dim_func.reshard(
                    in_one_dim_dist_attr,
                    out_one_dim_dist_attr,
                    out_value,
                    dst_type,
                )
                out_value.update_dist_attr(real_out_dist_attr)

        out_value.set_type(dst_type)
        return out_value
285+
286+
287+
class NdMeshReshardFunctionCrossMesh(ReshardFunction):
    """Reshard between two different N-d process meshes of the same shape."""

    def is_suitable(self, src_dist_attr, dst_dist_attr):
        """Return True when the meshes differ, share a shape and are N-d (ndim > 1).

        Args:
            src_dist_attr: dist attribute of the input value.
            dst_dist_attr: requested dist attribute.
        """
        in_mesh = src_dist_attr.process_mesh
        out_mesh = dst_dist_attr.process_mesh

        if in_mesh == out_mesh:
            return False
        if in_mesh.shape != out_mesh.shape:
            return False
        if out_mesh.ndim <= 1:
            return False
        if src_dist_attr == dst_dist_attr:
            return False

        return True

    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
        """Move the value to the destination mesh, then reshard within it.

        First ``SameStatusReshardFunction`` is applied towards an
        intermediate dist attribute that uses the destination mesh with the
        source dims_mapping and partial status. Then, on ranks that belong
        to the destination mesh, ``NdMeshReshardFunction`` converts that
        intermediate status to ``dst_dist_attr``.

        Args:
            src_dist_attr: dist attribute of ``src_value``.
            dst_dist_attr: requested dist attribute.
            src_value: the value to reshard.
            dst_type: the type requested for the final value.

        Returns:
            The resharded value, or ``None`` when the same-status step
            returns ``None`` or the current rank is not in the destination
            mesh.

        Raises:
            AssertionError: when ``NdMeshReshardFunction`` is not suitable
                for the intermediate-to-destination conversion.
        """
        same_status_func = SameStatusReshardFunction()
        tmp_dist_attr = paddle.base.libpaddle.pir.create_tensor_dist_attribute(
            dst_dist_attr.process_mesh,
            src_dist_attr.dims_mapping,
            src_dist_attr.partial_status,
        )
        tmp_dst_type = paddle.base.libpaddle.pir.cvt_to_dist_type(
            src_value.type(), tmp_dist_attr
        )
        out_value = same_status_func.reshard(
            src_dist_attr, tmp_dist_attr, src_value, tmp_dst_type
        )

        if out_value is None:
            return None

        curr_global_rank = paddle.distributed.get_rank()
        if curr_global_rank in dst_dist_attr.process_mesh.process_ids:
            nd_mesh_func = NdMeshReshardFunction()
            # The message names the function that is actually validated here.
            assert nd_mesh_func.is_suitable(
                tmp_dist_attr, dst_dist_attr
            ), f"Invoke the nd mesh reshard function is not valid from {tmp_dist_attr} to {dst_dist_attr}"
            return nd_mesh_func.reshard(
                tmp_dist_attr, dst_dist_attr, out_value, dst_type
            )
        return None

python/paddle/distributed/auto_parallel/static/reshard_funcs/reshard_func_register.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
# limitations under the License.
1414

1515
from .base_reshard_func import register_reshard_func
16+
from .nd_mesh_reshard_func import (
17+
NdMeshReshardFunction,
18+
NdMeshReshardFunctionCrossMesh,
19+
)
1620
from .p_to_r_reshard_func import (
1721
PToRReshardFunction,
1822
PToRReshardFunctionCrossMesh,
@@ -36,6 +40,8 @@ def register_reshard_funcs():
3640
register_reshard_func(SameStatusReshardFunction())
3741
register_reshard_func(SToRReshardFunction())
3842
register_reshard_func(SToRReshardFunctionCrossMesh())
43+
register_reshard_func(NdMeshReshardFunction())
44+
register_reshard_func(NdMeshReshardFunctionCrossMesh())
3945

4046

4147
register_reshard_funcs()

0 commit comments

Comments
 (0)