Skip to content

Commit 1048206

Browse files
authored
add global to sub mesh reshard func for auto parallel. (#64418)
1 parent 8768654 commit 1048206

10 files changed

Lines changed: 89 additions & 0 deletions

File tree

paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,10 @@ def gen_infermeta_func_str(args, op_info):
937937
spmd_params = op_info.kernel_map['param']
938938
else:
939939
spmd_params = op_info.input_name_list
940+
# TODO(GhostScreaming): specialized case for reshape_grad
941+
# xshape is not kernel params, but inferspmd needs it.
942+
if "reshape_grad" in op_info.kernel_map['func'][0]:
943+
spmd_params = ["xshape"] + spmd_params
940944
op_info.spmd_params = spmd_params
941945

942946
infermeta_inputs_str = get_infermeta_inputs_str(

paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@
701701
infer_meta :
702702
func : KernelWithXShapeInferMeta
703703
param : [xshape, out_grad]
704+
spmd_rule: StaticReshapeGradInferSpmd
704705
kernel :
705706
func : reshape_grad
706707
param : [out_grad]

paddle/fluid/pybind/dist_api.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.h"
2222
#include "paddle/fluid/pybind/dist_api.h"
2323
#include "paddle/fluid/pybind/dist_static_op_function.h"
24+
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
2425
#include "paddle/phi/core/enforce.h"
2526

2627
namespace py = pybind11;
@@ -40,6 +41,7 @@ struct type_caster<paddle::flat_hash_map<Key, Value, Hash, Equal, Alloc>>
4041
} // namespace pybind11
4142

4243
using paddle::dialect::OperationDistAttribute;
44+
using paddle::dialect::ProcessMeshAttribute;
4345
using paddle::dialect::TensorDistAttribute;
4446

4547
namespace paddle {
@@ -122,6 +124,7 @@ OperationDistAttribute CreateOperationDistAttribute(
122124
void BindDistUtils(pybind11::module *m) {
123125
m->def("create_tensor_dist_attribute", CreateTensorDistAttribute);
124126
m->def("create_op_dist_attribute", CreateOperationDistAttribute);
127+
m->def("get_sub_meshes", phi::distributed::GetSubMeshes);
125128
m->def("cvt_to_dist_type", &dialect::CvtToPirDistType);
126129
}
127130

paddle/fluid/pybind/pir.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,14 @@ void BindValue(py::module *m) {
11211121
return_value_policy::reference)
11221122
.def("numel", [](Value self) { return phi::product(GetValueDims(self)); })
11231123
.def("type", &Value::type)
1124+
.def("index",
1125+
[](Value self) -> uint32_t {
1126+
if (auto op_result = self.dyn_cast<OpResult>()) {
1127+
return op_result.index();
1128+
}
1129+
PADDLE_THROW(phi::errors::InvalidArgument(
1130+
"only support accesss index from op_result."));
1131+
})
11241132
.def("is_dense_tensor_type",
11251133
[](Value self) { return self.type().isa<DenseTensorType>(); })
11261134
.def("is_selected_row_type",

paddle/phi/infermeta/spmd_rules/reshape.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,5 +351,12 @@ SpmdInfo ReshapeGradInferSpmd(const DistMetaTensor& x_shape,
351351
return {{out_grad_dist_dst}, {x_shape_dist_dst}};
352352
}
353353

354+
SpmdInfo StaticReshapeGradInferSpmd(const DistMetaTensor& x_shape,
355+
const DistMetaTensor& out_grad) {
356+
auto spmd_info = ReshapeGradInferSpmd(x_shape, out_grad);
357+
spmd_info.first.insert(spmd_info.first.begin(), x_shape.dist_attr());
358+
return spmd_info;
359+
}
360+
354361
} // namespace distributed
355362
} // namespace phi

paddle/phi/infermeta/spmd_rules/reshape.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,8 @@ SpmdInfo ReshapeInferSpmdDynamic(const DistMetaTensor& x,
3535
SpmdInfo ReshapeGradInferSpmd(const DistMetaTensor& x_shape,
3636
const DistMetaTensor& out_grad);
3737

38+
SpmdInfo StaticReshapeGradInferSpmd(const DistMetaTensor& x_shape,
39+
const DistMetaTensor& out_grad);
40+
3841
} // namespace distributed
3942
} // namespace phi

python/paddle/distributed/auto_parallel/static/pir_pass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ def apply_reshard_pass(program):
132132
not var.initialized() or var.dist_attr() == src_dist_attr
133133
), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"
134134

135+
if src_dist_attr == dst_dist_attr:
136+
op.result(0).replace_all_uses_with(var)
137+
op.erase()
138+
continue
135139
reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
136140
assert (
137141
reshard_func is not None
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
17+
from .base_reshard_func import ReshardFunction
18+
19+
20+
class GlobaleToSubMeshFunction(ReshardFunction):
21+
def is_suitable(self, src_dist_attr, dst_dist_attr):
22+
if 0 in src_dist_attr.dims_mapping or 0 in src_dist_attr.partial_status:
23+
return False
24+
in_mesh = src_dist_attr.process_mesh
25+
out_mesh = dst_dist_attr.process_mesh
26+
if in_mesh.ndim != out_mesh.ndim + 1:
27+
return False
28+
sub_meshes = paddle.base.libpaddle.pir.get_sub_meshes(in_mesh)
29+
return out_mesh in sub_meshes
30+
31+
def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
32+
if src_value.has_one_use():
33+
src_value.update_dist_attr(dst_dist_attr)
34+
prev_op = src_value.get_defining_op()
35+
op_dist_attr = prev_op.dist_attr
36+
op_mesh = op_dist_attr.process_mesh
37+
operands = op_dist_attr.operands()
38+
results = op_dist_attr.results()
39+
results[src_value.index()] = dst_dist_attr
40+
prev_op.dist_attr = (
41+
paddle.base.libpaddle.pir.create_op_dist_attribute(
42+
op_mesh, operands, results
43+
)
44+
)
45+
return src_value
46+
else:
47+
dst_value = paddle._C_ops.share_data_(src_value)
48+
share_data_op = dst_value.get_defining_op()
49+
# set dist type and dist attr
50+
dst_value.set_type(dst_type)
51+
share_data_op.dist_attr = (
52+
paddle.base.libpaddle.pir.create_op_dist_attribute(
53+
dst_dist_attr.process_mesh, [src_dist_attr], [dst_dist_attr]
54+
)
55+
)
56+
return dst_value

python/paddle/distributed/auto_parallel/static/reshard_funcs/reshard_func_register.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from .base_reshard_func import register_reshard_func
16+
from .global_to_sub_mesh_func import GlobaleToSubMeshFunction
1617
from .nd_mesh_reshard_func import (
1718
NdMeshReshardFunction,
1819
NdMeshReshardFunctionCrossMesh,
@@ -42,6 +43,7 @@ def register_reshard_funcs():
4243
register_reshard_func(SToRReshardFunctionCrossMesh())
4344
register_reshard_func(NdMeshReshardFunction())
4445
register_reshard_func(NdMeshReshardFunctionCrossMesh())
46+
register_reshard_func(GlobaleToSubMeshFunction())
4547

4648

4749
register_reshard_funcs()

test/dygraph_to_static/test_tensor_attr_consistency.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
'all_used_ops',
9292
'append',
9393
'first_use',
94+
'index',
9495
'get_defining_op',
9596
'has_one_use',
9697
'has_name',

0 commit comments

Comments (0)