
Commit d9c85d8

add cross mesh r_to_s reshard func for auto parallel.

1 parent: 0e58487

14 files changed: 357 additions & 185 deletions

paddle/fluid/pybind/dist_api.cc

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
 #include "paddle/fluid/pir/dialect/distributed/transforms/dist_to_dense_pass.h"
 #include "paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h"
 #include "paddle/fluid/pybind/dist_api.h"

@@ -122,6 +123,7 @@ OperationDistAttribute CreateOperationDistAttribute(
 void BindDistUtils(pybind11::module *m) {
   m->def("create_tensor_dist_attribute", CreateTensorDistAttribute);
   m->def("create_op_dist_attribute", CreateOperationDistAttribute);
+  m->def("cvt_to_dist_type", &dialect::CvtToPirDistType);
 }
 
 void BindDistPassAPI(pybind11::module *module) {
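The new binding exposes dialect::CvtToPirDistType to Python as cvt_to_dist_type: given a dense tensor type and a TensorDistAttribute, it returns the corresponding dist type, with the local shape derived from the global shape. A minimal usage sketch; the import path and the surrounding PIR program setup are assumptions, not part of this diff:

    # Hedged sketch: assumes the utils are importable from paddle.base.libpaddle.pir
    # and that `w` is a dense tensor Value inside a static-graph (PIR) program.
    import paddle.distributed as dist
    from paddle.base.libpaddle.pir import (  # assumed module path
        create_tensor_dist_attribute,
        cvt_to_dist_type,
    )

    mesh = dist.ProcessMesh([0, 1])
    # shard dim 0 across the 2-way mesh; {} means no partial status
    dist_attr = create_tensor_dist_attribute(mesh, [0, -1], {})
    # retype the value: global shape [8, 4] -> dist type with local shape [4, 4]
    w.set_type(cvt_to_dist_type(w.type(), dist_attr))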

paddle/fluid/pybind/pir.cc

Lines changed: 49 additions & 38 deletions
@@ -33,6 +33,7 @@
 #include "paddle/fluid/ir_adaptor/translator/utils.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
 #include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
 #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
 #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"

@@ -815,45 +816,43 @@ py::str Value2String(Value self) {
   return print_stream.str();
 }
 
-phi::DataType GetValueDtype(Value value) {
-  if (!value.type()) {
+phi::DataType GetTensorDtype(Type type) {
+  if (!type) {
     PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
   }
-  if (value.type().isa<DenseTensorType>()) {
-    return paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<DenseTensorType>().dtype());
-  } else if (value.type().isa<SelectedRowsType>()) {
-    return paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<SelectedRowsType>().dtype());
-  } else if (value.type().isa<DenseTensorArrayType>()) {
-    return paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<DenseTensorArrayType>().dtype());
-  } else if (value.type().isa<DistDenseTensorType>()) {
-    return paddle::dialect::TransToPhiDataType(
-        value.type().dyn_cast<DistDenseTensorType>().dtype());
+  if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
+    return dialect::TransToPhiDataType(dense_tensor_type.dtype());
+  } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
+    return dialect::TransToPhiDataType(select_rows.dtype());
+  } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
+    return dialect::TransToPhiDataType(dense_array.dtype());
   } else {
     PADDLE_THROW(phi::errors::InvalidArgument(
         "Currently, we can only get phi::DataType from DenseTensorType and "
-        "SelectedRowsType, DistDenseTensorType."));
+        "SelectedRowsType, DenseTensorArrayType."));
   }
 }
+phi::DataType GetValueDtype(Value value) {
+  return GetTensorDtype(value.type());
+}
 
-const phi::DDim &GetValueDims(Value value) {
-  if (!value.type()) {
-    PADDLE_THROW(phi::errors::InvalidArgument("The type of value is nullptr."));
+const phi::DDim &GetTensorDims(Type type) {
+  if (!type) {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "The type used to get dims is nullptr."));
   }
-  if (value.type().isa<DenseTensorType>()) {
-    return value.type().dyn_cast<DenseTensorType>().dims();
-  } else if (value.type().isa<SelectedRowsType>()) {
-    return value.type().dyn_cast<SelectedRowsType>().dims();
-  } else if (value.type().isa<DistDenseTensorType>()) {
-    return value.type().dyn_cast<DistDenseTensorType>().global_ddim();
+  if (auto dense_type = type.dyn_cast<DenseTensorType>()) {
+    return dense_type.dims();
+  } else if (auto select_rows_type = type.dyn_cast<SelectedRowsType>()) {
+    return select_rows_type.dims();
   } else {
-    PADDLE_THROW(phi::errors::InvalidArgument(
-        "Currently, we can only get shape for dense and distdense"
-        "tensor."));
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "Currently, we can only get shape for dense and select rows type."));
   }
 }
+const phi::DDim &GetValueDims(Value value) {
+  return GetTensorDims(value.type());
+}
 
 pir::Value apply(Value self, py::object func) {
   py::gil_scoped_acquire gil;

@@ -1100,13 +1099,10 @@ void BindValue(py::module *m) {
         }
         return self.type().dyn_cast<DistTypeInterface>().tensor_dist_attr();
       })
+      // The function will calculate the new local shape based on the global
+      // shape and the dist_attr argument.
       .def("update_dist_attr", [](Value &self, TensorDistAttribute dist_attr) {
-        if (auto dist_type = self.type().dyn_cast<DistTypeInterface>()) {
-          self.set_type(dist_type.CopyWithNewDistAttr(dist_attr));
-        } else {
-          PADDLE_THROW(common::errors::InvalidArgument(
-              "update_dist_attr is only for dist type tensor."));
-        }
+        self.set_type(dialect::CvtToPirDistType(self.type(), dist_attr));
       });
 }
 
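Because update_dist_attr now delegates to dialect::CvtToPirDistType, it no longer requires the value to already carry a dist type, and it recomputes the local shape from the global shape and the given dist_attr. A behavioral sketch; mesh, shapes, and the helper import are illustrative assumptions:

    # Hedged sketch: `value` is a static-graph Value with global shape [8, 4].
    mesh = dist.ProcessMesh([0, 1])
    dist_attr = create_tensor_dist_attribute(mesh, [0, -1], {})  # shard dim 0
    value.update_dist_attr(dist_attr)
    # value is re-typed via CvtToPirDistType; on the 2-way mesh the expected
    # local shape is [4, 4] while the global shape stays [8, 4].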

@@ -1137,11 +1133,26 @@ bool GetValueBoolAttr(Value value, const std::string &attr_name) {
 
 void BindType(py::module *m) {
   py::class_<Type> ir_type(*m, "Type");
-  ir_type.def("__eq__", &Type::operator==).def("__str__", [](Type &self) {
-    std::ostringstream print_stream;
-    print_stream << self;
-    return print_stream.str();
-  });
+  ir_type.def("__eq__", &Type::operator==)
+      .def_property(
+          "shape",
+          [](Type self) { return phi::vectorize(GetTensorDims(self)); },
+          [](Type self, const std::vector<int> &shape) {
+            PADDLE_THROW(phi::errors::InvalidArgument(
+                "can't set shape when building static graph"));
+          })
+      .def_property(
+          "dtype",
+          [](Type self) { return GetTensorDtype(self); },
+          [](Type self, phi::DataType dtype) {
+            PADDLE_THROW(phi::errors::InvalidArgument(
+                "can't set dtype when building static graph"));
+          })
+      .def("__str__", [](Type &self) {
+        std::ostringstream print_stream;
+        print_stream << self;
+        return print_stream.str();
+      });
 
   m->def("create_shaped_type",
          [](Type &type, const std::vector<int> &shape) -> Type {
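With the two def_property bindings, a pir.Type now exposes shape and dtype directly, mirroring the properties on Value; both setters intentionally throw, since shapes and dtypes are immutable while building a static graph. A small sketch; variable names are illustrative:

    t = some_value.type()  # pir.Type of a dense tensor value
    print(t.shape)         # list of dims, computed via GetTensorDims
    print(t.dtype)         # phi DataType, computed via GetTensorDtype
    t.shape = [2, 3]       # raises InvalidArgument: can't set shape when building static graph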

python/paddle/distributed/auto_parallel/static/engine.py

Lines changed: 2 additions & 2 deletions
@@ -672,12 +672,12 @@ def _parallel_pir(self, mode):
         # TODO(JZ-LIANG) Step 3.1: Partition Pass
         # insert reshard op if an operand tensor's placements are different from what the consumer op needs.
         # Partition the computation graph into different pipeline stages if needed.
-        dist_program = apply_partition_pass(dist_program)
+        apply_partition_pass(dist_program)
 
         # TODO(hitywt) Step 3.2: Reshard Pass
         # resolve the reshard op into special collective operation.
         # collect the communicator created during resolution.
-        dist_program = apply_reshard_pass(dist_program)
+        apply_reshard_pass(dist_program)
 
         # Part 4: Optimization Pass
         # NOTE Only those optimization passes related to parallelism (which need dist attr) should be placed here, and all of them should be optional.
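Both passes now rewrite dist_program in place instead of returning a program object, so the call sites drop the assignment. A before/after sketch of the calling convention:

    # before this commit: apply_reshard_pass cloned the program and returned it
    dist_program = apply_reshard_pass(dist_program)

    # after this commit: the pass mutates dist_program in place; keep using the
    # same object afterwards (there is no longer a meaningful return value)
    apply_reshard_pass(dist_program)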

python/paddle/distributed/auto_parallel/static/pir_pass.py

Lines changed: 79 additions & 88 deletions
@@ -61,99 +61,90 @@ def reshard_combine_value(op, operand, attr):
 
 
 def apply_partition_pass(program):
-    with paddle.static.program_guard(program):
-        for op in program.global_block().ops:
-            if op.name() in partition_skip_op_list:
-                continue
-            assert len(op.operands()) == len(
-                op.dist_attr.operands()
-            ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
-
-            for operand, attr in zip(op.operands(), op.dist_attr.operands()):
-                prev_var = operand.source()
-                if prev_var.is_combine():
-                    operand.set_source(reshard_combine_value(op, operand, attr))
-                else:
-                    operand.set_source(reshard_single_value(op, operand, attr))
-                prev_op = prev_var.get_defining_op()
-                if (
-                    prev_op
-                    and prev_op.num_results() == 1
-                    and prev_var.use_empty()
-                ):
-                    prev_op.erase()
-
-            for var, attr in zip(op.results(), op.dist_attr.results()):
-                if (
-                    var.initialized()
-                    and var.is_dist()
-                    and var.dist_attr() != attr
-                ):
-                    paddle.pir.set_insertion_point_after(op)
-                    old_dist_attr = var.dist_attr()
-                    var.update_dist_attr(attr.as_tensor_dist_attr())
-                    # insert reshard
-                    reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
-                    var.replace_all_uses_with(reshard_var)
-                    reshard_var.get_defining_op().operand(0).set_source(var)
-
-        # prune the ops and values that do not belong to cur rank
-        cur_rank = paddle.distributed.get_rank()
-        for op in program.global_block().ops[::-1]:
-            if cur_rank not in op.dist_attr.process_mesh.process_ids:
-                program.global_block().remove_op(op)
-            else:
-                # set the operand to null when it does not belong to cur rank
-                if (
-                    op.name() == 'dist_op.reshard'
-                    and cur_rank
-                    not in op.operand(0)
-                    .source()
-                    .dist_attr()
-                    .process_mesh.process_ids
-                ):
-                    op.operand(0).set_source(None)
-
-        # merge pd.data ops for
-        lr_ops = []
-        for op in program.global_block().ops[::-1]:
-            if (
-                op.name() == 'pd_op.data'
-                and "learning_rate" in op.attrs()["name"]
-            ):
-                lr_ops.append(op)
-
-        if len(lr_ops) > 1:
-            lr_value = lr_ops[0].result(0)
-            for op in lr_ops[1:]:
-                lr = op.result(0)
-                lr.replace_all_uses_with(lr_value)
-                program.global_block().remove_op(op)
-        return program
+    for op in program.global_block().ops:
+        if op.name() in partition_skip_op_list:
+            continue
+        assert len(op.operands()) == len(
+            op.dist_attr.operands()
+        ), f"The number of operands and the number of op_dist_attr's operands are not equal in op: {op}"
+
+        for operand, attr in zip(op.operands(), op.dist_attr.operands()):
+            prev_var = operand.source()
+            if prev_var.is_combine():
+                operand.set_source(reshard_combine_value(op, operand, attr))
+            else:
+                operand.set_source(reshard_single_value(op, operand, attr))
+            prev_op = prev_var.get_defining_op()
+            if prev_op and prev_op.num_results() == 1 and prev_var.use_empty():
+                prev_op.erase()
+
+        for var, attr in zip(op.results(), op.dist_attr.results()):
+            if var.initialized() and var.is_dist() and var.dist_attr() != attr:
+                paddle.pir.set_insertion_point_after(op)
+                old_dist_attr = var.dist_attr()
+                var.update_dist_attr(attr.as_tensor_dist_attr())
+                # insert reshard
+                reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
+                var.replace_all_uses_with(reshard_var)
+                reshard_var.get_defining_op().operand(0).set_source(var)
+
+    # prune the ops and values that do not belong to cur rank
+    cur_rank = paddle.distributed.get_rank()
+    for op in program.global_block().ops[::-1]:
+        if cur_rank not in op.dist_attr.process_mesh.process_ids:
+            op.erase()
+        else:
+            # set the operand to null when it does not belong to cur rank
+            if (
+                op.name() == 'dist_op.reshard'
+                and cur_rank
+                not in op.operand(0)
+                .source()
+                .dist_attr()
+                .process_mesh.process_ids
+            ):
+                op.operand(0).set_source(None)
+
+    # merge pd.data ops for
+    lr_ops = []
+    for op in program.global_block().ops[::-1]:
+        if op.name() == 'pd_op.data' and "learning_rate" in op.attrs()["name"]:
+            lr_ops.append(op)
+
+    if len(lr_ops) > 1:
+        lr_value = lr_ops[0].result(0)
+        for op in lr_ops[1:]:
+            lr = op.result(0)
+            lr.replace_all_uses_with(lr_value)
+            op.erase()
 
 
 def apply_reshard_pass(program):
-    new_program = program.clone()
-    with paddle.base.program_guard(new_program):
-        for op in new_program.global_block().ops:
-            if op.name() == 'dist_op.reshard':
-                var = op.operand_source(0)
-                op_dist_attr = op.dist_attr
-                src_dist_attr = op_dist_attr.operand(0).as_tensor_dist_attr()
-                dst_dist_attr = op_dist_attr.result(0).as_tensor_dist_attr()
-                assert (
-                    not var.initialized() or var.dist_attr() == src_dist_attr
-                ), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"
-
-                reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
-                assert (
-                    reshard_func is not None
-                ), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}'
-                reshard_func.reshard(
-                    new_program, op, src_dist_attr, dst_dist_attr
-                )
-
-    return new_program
+    for op in program.global_block().ops:
+        if op.name() == 'dist_op.reshard':
+            var = op.operand_source(0)
+            op_dist_attr = op.dist_attr
+            src_dist_attr = op_dist_attr.operand(0).as_tensor_dist_attr()
+            dst_dist_attr = op_dist_attr.result(0).as_tensor_dist_attr()
+            assert (
+                not var.initialized() or var.dist_attr() == src_dist_attr
+            ), f"The dist_attr of reshard op's input and operand should be equal, but got {var.dist_attr()} and {src_dist_attr}"
+
+            reshard_func = choose_reshard_func(src_dist_attr, dst_dist_attr)
+            assert (
+                reshard_func is not None
+            ), f'There is no reshard function that matches src_dist_attr: {src_dist_attr} and dst_dist_attr: {dst_dist_attr}'
+            paddle.pir.set_insertion_point_after(op)
+            out_value = reshard_func.reshard(
+                src_dist_attr,
+                dst_dist_attr,
+                op.operand_source(0),
+                op.result(0).type(),
+            )
+            if out_value is not None:
+                op.result(0).replace_all_uses_with(out_value)
+            if op.result(0).use_empty():
+                op.erase()
 
 
 # In sequence_parallel, we need to transpose hidden_states

@@ -183,5 +174,5 @@ def eliminate_transpose_by_reshape(program):
             transpose_var = op.result(0)
             reshape_var = paddle._C_ops.reshape(var, transpose_var.shape)
             transpose_var.replace_all_uses_with(reshape_var)
-            program.global_block().remove_op(op)
+            op.erase()
     return program

python/paddle/distributed/auto_parallel/static/reshard_funcs/base_reshard_func.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ class ReshardFunction:
     def is_suitable(self, dist_tensor, dist_attr):
        raise NotImplementedError
 
-    def reshard(self, program, op, src_tensor, dst_dist_attr):
+    def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
        raise NotImplementedError
 
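Under the new interface a reshard function receives the source and destination dist attrs, the input Value, and the expected result type, and returns the new output Value (or None when there is nothing to rewire); apply_reshard_pass has already set the insertion point just after the dist_op.reshard op. A skeletal subclass; the class name echoes the commit title, and the body is a hedged sketch rather than the committed implementation:

    class CrossMeshRToSReshardFunction(ReshardFunction):
        def is_suitable(self, dist_tensor, dist_attr):
            # hypothetical check: a real implementation would verify that the
            # source is replicated, the destination is sharded, and the two
            # process meshes differ (the cross-mesh case)
            raise NotImplementedError

        def reshard(self, src_dist_attr, dst_dist_attr, src_value, dst_type):
            # build the communication/slice ops at the current insertion point
            # and return the resulting Value typed as dst_type; returning None
            # tells the pass to leave the reshard op's uses untouched
            raise NotImplementedError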
