Skip to content

Commit 70cc347

Browse files
authored
[pir+auto parallel] add reshard op for input when needed (#63072)
* add reshard op for input when needed * fix unary grad inferspmd
1 parent c6891f0 commit 70cc347

9 files changed

Lines changed: 86 additions & 13 deletions

File tree

paddle/fluid/pir/dialect/distributed/ir/dist_api.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,12 @@ pir::Value reshard(const pir::Value& x,
5959
return reshard_op.result(0);
6060
}
6161

62+
pir::Value reshard(const pir::Value& x,
63+
const TensorDistAttribute& tensor_dist_attr) {
64+
auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build<ReShardOp>(
65+
x, tensor_dist_attr);
66+
return reshard_op.result(0);
67+
}
68+
6269
} // namespace dialect
6370
} // namespace paddle

paddle/fluid/pir/dialect/distributed/ir/dist_api.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <vector>
1818

19+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
1920
#include "paddle/phi/common/data_type.h"
2021
#include "paddle/phi/common/place.h"
2122
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
@@ -31,5 +32,9 @@ pir::Value shard_tensor(const pir::Value& x,
3132
pir::Value reshard(const pir::Value& x,
3233
const phi::distributed::ProcessMesh& process_mesh,
3334
const std::vector<int64_t>& dims_mapping);
35+
36+
pir::Value reshard(const pir::Value& x,
37+
const TensorDistAttribute& tensor_dist_attr);
38+
3439
} // namespace dialect
3540
} // namespace paddle

paddle/fluid/pir/dialect/distributed/ir/dist_type.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@ common::DDim InferLocalDDim(const common::DDim& global_ddim,
4343
TensorDistAttribute dist_attr) {
4444
auto& mesh_dim = dist_attr.process_mesh_attr().shape();
4545
auto& dim_mapping = dist_attr.dims_mapping();
46-
PADDLE_ENFORCE_EQ(
47-
global_ddim.size(),
48-
dim_mapping.size(),
49-
::common::errors::PreconditionNotMet(
50-
"The global ddim size must equal to dim_mapping's size!"));
46+
PADDLE_ENFORCE_EQ(global_ddim.size(),
47+
dim_mapping.size(),
48+
::common::errors::PreconditionNotMet(
49+
"The global ddim size must equal to dim_mapping's "
50+
"size, but bot %d vs %d",
51+
global_ddim.size(),
52+
dim_mapping.size()));
5153
common::DDim local_ddim(global_ddim);
5254
for (size_t i = 0; i < dim_mapping.size(); ++i) {
5355
if (dim_mapping[i] != -1) {

paddle/fluid/pybind/dist_api.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <Python.h>
1616
#include "pybind11/stl.h"
1717

18+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h"
1819
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
1920
#include "paddle/fluid/pybind/dist_api.h"
2021
#include "paddle/fluid/pybind/dist_static_op_function.h"
@@ -60,6 +61,10 @@ void BindTensorDistAttribute(py::module *m) {
6061
print_stream << self;
6162
return print_stream.str();
6263
})
64+
.def("__eq__",
65+
[](TensorDistAttribute &self, const TensorDistAttribute &other) {
66+
return self == other;
67+
})
6368
.def_property_readonly("process_mesh",
6469
[](TensorDistAttribute &self) {
6570
return self.process_mesh_attr().process_mesh();
@@ -86,12 +91,20 @@ void BindDistOpsAPI(pybind11::module *module) {
8691
}
8792
}
8893

94+
void BindOpsFunction(py::module *m) {
95+
m->def("reshard_v2",
96+
[](const pir::Value &x, const TensorDistAttribute &dist_attr) {
97+
return reshard(x, dist_attr);
98+
});
99+
}
100+
89101
void BindDistApi(pybind11::module *module) {
90102
auto ir_module = module->def_submodule("pir");
91103
BindOperationDistAttribute(&ir_module);
92104
BindTensorDistAttribute(&ir_module);
93105
auto ops_modules = ir_module.def_submodule("ops");
94106
BindDistOpsAPI(&ops_modules);
107+
BindOpsFunction(&ops_modules);
95108
}
96109

97110
} // namespace pybind

paddle/fluid/pybind/dist_static_op_function.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ static PyMethodDef DistOpsAPI[] = {
8989
(PyCFunction)(void (*)(void))static_api_reshard,
9090
METH_VARARGS | METH_KEYWORDS,
9191
"C++ interface function for reshard."},
92-
9392
{nullptr, nullptr, 0, nullptr}};
9493

9594
} // namespace pybind

paddle/phi/infermeta/spmd_rules/elementwise.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,14 +365,17 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x,
365365

366366
SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x,
367367
const DistMetaTensor& out_grad) {
368-
return {{out_grad.dist_attr(), out_grad.dist_attr()}, {out_grad.dist_attr()}};
368+
auto dist_attr = CopyTensorDistAttrForOutput(out_grad.dist_attr());
369+
dist_attr.set_dims_mapping(out_grad.dist_attr().dims_mapping());
370+
return {{dist_attr, dist_attr}, {dist_attr}};
369371
}
370372

371373
SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x,
372374
const DistMetaTensor& out,
373375
const DistMetaTensor& out_grad) {
374-
return {{out_grad.dist_attr(), out_grad.dist_attr(), out_grad.dist_attr()},
375-
{out_grad.dist_attr()}};
376+
auto dist_attr = CopyTensorDistAttrForOutput(out_grad.dist_attr());
377+
dist_attr.set_dims_mapping(out_grad.dist_attr().dims_mapping());
378+
return {{dist_attr, dist_attr, dist_attr}, {dist_attr}};
376379
}
377380

378381
bool DimsNotEqualOrHasBroadcastDim(const DistMetaTensor& x,

python/paddle/distributed/auto_parallel/static/engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from .dist_saver import DistributedSaver
5555
from .helper import ProgramHelper
5656
from .parallelizer_v2 import Parallelizer
57+
from .pir_pass import apply_partition_pass
5758
from .planner_v2 import Planner
5859
from .process_group import get_all_process_groups, new_process_group
5960

@@ -675,7 +676,7 @@ def _parallel_pir(self, mode):
675676
# TODO(JZ-LIANG) Step 3.1: Partition Pass
676677
# insert reshard op if operand tensor's placements if different from what the cumsumer op need.
677678
# Partition the computation graph into different pipeline stage if need.
678-
# dist_program = apply_partition_pass(dist_program)
679+
dist_program = apply_partition_pass(dist_program)
679680

680681
# TODO(hitywt) Step 3.2: Reshard Pass
681682
# resolute the reshard op into special collective operation.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
17+
18+
def apply_partition_pass(program):
19+
new_program = program.clone()
20+
with paddle.static.program_guard(new_program):
21+
for op in new_program.global_block().ops:
22+
# assert len(op.operands()) == len(op.dist_attr().operand_dist_attrs()), f'The number of operand and operand_dist_attrs are not equal in op: {op}'
23+
for var, operand_dist_attr in zip(
24+
op.operands(), op.dist_attr().operand_dist_attrs()
25+
):
26+
if (
27+
var.source().is_dist_dense_tensor_type()
28+
and var.source().dist_attr() != operand_dist_attr
29+
):
30+
paddle.pir.set_insertion_point(op)
31+
# insert reshard
32+
reshard_var = paddle._pir_ops.reshard_v2(
33+
var.source(), operand_dist_attr
34+
)
35+
var.set_source(reshard_var)
36+
return new_program
37+
38+
39+
def apply_reshard_pass(program):
40+
pass

test/auto_parallel/pir/test_to_static_pir_program.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(self, mesh):
6666
)
6767

6868
def forward(self, x):
69+
x.stop_gradient = False
6970
out = self.relu_0(x) # triggle backward partial allreduce
7071
out = self.linear_0(out)
7172
out = self.relu_1(out)
@@ -138,6 +139,8 @@ def test_to_static_program(self):
138139
backward_op_list = [
139140
"pd_op.sgd_",
140141
"pd_op.sgd_",
142+
"pd_op.relu_grad",
143+
"dist_op.reshard",
141144
"pd_op.matmul_grad",
142145
"pd_op.relu_grad",
143146
"pd_op.matmul_grad",
@@ -225,10 +228,10 @@ def test_to_static_program(self):
225228
tensor._local_shape, [BATCH_SIZE, CLASS_NUM]
226229
)
227230
elif matmul_grad_idx == 1:
228-
self.assertEqual(tensor.dist_attr().dims_mapping, [-1, 0])
229-
self.assertEqual(tensor.dist_attr().partial_dims, set())
231+
self.assertEqual(tensor.dist_attr().dims_mapping, [-1, -1])
232+
self.assertEqual(tensor.dist_attr().partial_dims, {0})
230233
self.assertEqual(
231-
tensor._local_shape, [BATCH_SIZE, IMAGE_SIZE // 2]
234+
tensor._local_shape, [BATCH_SIZE, IMAGE_SIZE]
232235
)
233236
matmul_grad_idx += 1
234237
if op.name() == 'pd_op.sgd_':

0 commit comments

Comments
 (0)