Skip to content

Commit 1709a16

Browse files
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into load_jit
2 parents d557759 + 766535e commit 1709a16

File tree

26 files changed

+547
-385
lines changed

26 files changed

+547
-385
lines changed

paddle/cinn/hlir/op/contrib/reciprocal.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,53 @@ std::shared_ptr<OpStrategy> StrategyForReciprocal(
125125
return strategy;
126126
}
127127

128+
std::shared_ptr<OpStrategy> StrategyForReciprocalSymbolic(
129+
const framework::NodeAttr &attrs,
130+
const std::vector<ir::Tensor> &inputs,
131+
const std::vector<Type> &out_type,
132+
const std::vector<std::vector<ir::Dim>> &output_shapes,
133+
const Target &target) {
134+
std::string op_name("reciprocal");
135+
136+
framework::CINNCompute reciprocal_compute(
137+
[=](lang::Args args, lang::RetValue *ret) {
138+
CHECK(!args.empty()) << "The input argument of " << op_name
139+
<< " compute is empty! Please check.\n";
140+
CINNValuePack pack_args = args[0];
141+
CHECK(!pack_args.empty())
142+
<< "at least one input tensor for " << op_name << " compute\n";
143+
144+
CHECK_EQ(pack_args.size(), 2);
145+
CHECK(pack_args[1].is_string());
146+
std::string tensor_name = pack_args[1].operator std::string();
147+
148+
Expr A = pack_args[0];
149+
CHECK(A.as_tensor());
150+
CHECK(!output_shapes.empty());
151+
auto tensor_A = A.as_tensor_ref();
152+
auto stages = CreateStages({tensor_A});
153+
VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ")
154+
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
155+
156+
CHECK_EQ(pack_args.size(), 2U);
157+
tensor_name = pack_args[1].operator std::string();
158+
159+
ir::Tensor out = Reciprocal(tensor_A, tensor_name);
160+
std::vector<CINNValue> res;
161+
stages->InsertLazily(out);
162+
res.push_back(CINNValue(out));
163+
CHECK(!out_type.empty())
164+
<< "Output type of Reciprocal is empty! Please check.\n";
165+
res.push_back(CINNValue(stages));
166+
*ret = CINNValuePack{res};
167+
});
168+
169+
auto strategy = std::make_shared<framework::OpStrategy>();
170+
strategy->AddImpl(
171+
reciprocal_compute, lang::PackedFunc(), "strategy.reciprocal.x86", 1);
172+
return strategy;
173+
}
174+
128175
std::vector<framework::shape_t> InferShapeForReciprocal(
129176
const std::vector<framework::shape_t> &inputs_shape,
130177
const framework::AttrMapType &attrs) {
@@ -153,6 +200,8 @@ CINN_REGISTER_HELPER(reciprocal_ops) {
153200
.set_num_outputs(1)
154201
.set_attr<cinn::hlir::framework::StrategyFunction>(
155202
"CINNStrategy", cinn::hlir::op::StrategyForReciprocal)
203+
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
204+
"CINNStrategySymbolic", cinn::hlir::op::StrategyForReciprocalSymbolic)
156205
.set_attr("infershape",
157206
MakeOpFunction(cinn::hlir::op::InferShapeForReciprocal))
158207
.set_attr("inferdtype",

paddle/cinn/optim/eliminate_common_factor_of_local_index.cc

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -135,21 +135,36 @@ CollectLocalVarToIndexes(ir::Expr* expr) {
135135
gather_prohibited_local_var_visitor.prohibited_local_vars());
136136
}
137137

138-
int ExtractNumberFromExpr(const ir::Expr& expr) {
138+
int ExtractMulNumberFromExpr(const ir::Expr& expr) {
139139
ir::Expr simplied_expr = cinn::common::AutoSimplify(expr);
140140
if (simplied_expr.is_constant()) {
141141
return static_cast<int>(simplied_expr.get_constant());
142142
} else if (expr.As<ir::Mul>()) {
143143
auto mul = expr.As<ir::Mul>();
144-
return std::max(ExtractNumberFromExpr(mul->a()),
145-
ExtractNumberFromExpr(mul->b()));
144+
return ExtractMulNumberFromExpr(mul->a()) *
145+
ExtractMulNumberFromExpr(mul->b());
146146
} else {
147147
VLOG(6) << "Not supported for calculating gcd, expr = " << expr;
148148
return 1;
149149
}
150150
PADDLE_THROW(phi::errors::Fatal("Dead code"));
151151
}
152152

153+
int ExtractAddNumberFromExpr(const ir::Expr& expr) {
154+
ir::Expr simplied_expr = cinn::common::AutoSimplify(expr);
155+
if (simplied_expr.is_constant()) {
156+
return static_cast<int>(simplied_expr.get_constant());
157+
} else if (expr.As<ir::Add>()) {
158+
auto add = expr.As<ir::Add>();
159+
return ExtractAddNumberFromExpr(add->a()) +
160+
ExtractAddNumberFromExpr(add->b());
161+
} else {
162+
VLOG(6) << "Not supported for calculating offset, expr = " << expr;
163+
return 0;
164+
}
165+
PADDLE_THROW(phi::errors::Fatal("Dead code"));
166+
}
167+
153168
int gcd(int a, int b) {
154169
if (b == 0) {
155170
return a == 0 ? 1 : a;
@@ -170,7 +185,7 @@ struct CommonFactorTrait<Gcd> {
170185
// Note (Hongyu Jia): Currently, we only calculates gcd of int factors.
171186
static ir::Expr Calculate(const ir::Expr& expr1, const ir::Expr& expr2) {
172187
return ir::Expr(
173-
gcd(ExtractNumberFromExpr(expr1), ExtractNumberFromExpr(expr2)));
188+
gcd(ExtractMulNumberFromExpr(expr1), ExtractMulNumberFromExpr(expr2)));
174189
}
175190

176191
static ir::Expr Simplify(const ir::Expr& expr, const ir::Expr& factor) {
@@ -188,11 +203,8 @@ struct CommonFactorTrait<Offset> {
188203
static const ir::Expr unit;
189204

190205
static ir::Expr Calculate(const ir::Expr& expr1, const ir::Expr& expr2) {
191-
int offset1 =
192-
expr1.is_constant() ? static_cast<int>(expr1.get_constant()) : 0;
193-
int offset2 =
194-
expr2.is_constant() ? static_cast<int>(expr2.get_constant()) : 0;
195-
return ir::Expr(std::min(offset1, offset2));
206+
return ir::Expr(std::min(ExtractAddNumberFromExpr(expr1),
207+
ExtractAddNumberFromExpr(expr2)));
196208
}
197209

198210
static ir::Expr Simplify(const ir::Expr& expr, const ir::Expr& factor) {

paddle/fluid/ir_adaptor/translator/program_translator.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ static std::vector<std::string> GetExternalInputs(const BlockDesc& block) {
126126
std::unordered_set<std::string> inner_outputs;
127127
for (auto op_desc : block.AllOps()) {
128128
for (const auto& n : op_desc->Inputs()) {
129+
if (op_desc->Type() == "transpose2_grad" && n.first == "XShape") {
130+
continue;
131+
}
129132
const auto& input_var_names = n.second;
130133
for (const auto& var_name : input_var_names) {
131134
if (inner_outputs.count(var_name) == 0) {

paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,18 @@ OperationDistAttribute OperationDistAttribute::get(
121121
const std::vector<pir::Attribute>& result_attrs) {
122122
auto check_dist_attr = [=](pir::Attribute attr) {
123123
auto dist_attr = attr.dyn_cast<TensorDistAttribute>();
124-
PADDLE_ENFORCE_EQ(mesh,
125-
dist_attr.process_mesh_attr(),
126-
common::errors::PreconditionNotMet(
127-
"operand_dist_attrs element's mesh(%s) not equal "
128-
"to input mesh(%s)"));
124+
auto ids = mesh.process_ids();
125+
for (const auto& id : dist_attr.process_mesh_attr().process_ids()) {
126+
PADDLE_ENFORCE_EQ(std::find(ids.begin(), ids.end(), id) != ids.end(),
127+
true,
128+
common::errors::PreconditionNotMet(
129+
"operand_dist_attrs element's mesh(%s) not belong "
130+
"to input mesh(%s)",
131+
dist_attr.process_mesh_attr(),
132+
mesh));
133+
}
129134
};
135+
130136
for (auto attr : operand_attrs) {
131137
// NOTE: The operand dist attr maybe empty while the corresponding input is
132138
// optional.

paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ void DistDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const {
7474
phi::distributed::auto_parallel::str_join(
7575
tensor_dist_attr.process_mesh_attr().shape()) +
7676
"]";
77+
os << ",process_ids:[" +
78+
phi::distributed::auto_parallel::str_join(
79+
tensor_dist_attr.process_mesh_attr().process_ids()) +
80+
"]";
7781
os << ",dims_mappings:[" +
7882
phi::distributed::auto_parallel::str_join(
7983
tensor_dist_attr.dims_mapping()) +
@@ -111,6 +115,10 @@ void DistDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const {
111115
phi::distributed::auto_parallel::str_join(
112116
dist_attr.process_mesh_attr().shape()) +
113117
"],";
118+
os << "process_ids:[" +
119+
phi::distributed::auto_parallel::str_join(
120+
dist_attr.process_mesh_attr().process_ids()) +
121+
"],";
114122
}
115123
os << "dims_maping:[" +
116124
phi::distributed::auto_parallel::str_join(
@@ -145,6 +153,10 @@ void DistDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const {
145153
phi::distributed::auto_parallel::str_join(
146154
dist_attr.process_mesh_attr().shape()) +
147155
"],";
156+
os << "process_ids:[" +
157+
phi::distributed::auto_parallel::str_join(
158+
dist_attr.process_mesh_attr().process_ids()) +
159+
"],";
148160
}
149161
os << "dims_maping:[" +
150162
phi::distributed::auto_parallel::str_join(

paddle/fluid/pir/dialect/distributed/ir/dist_op.cc

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,11 @@ void ReshardOp::VerifySig() {
214214
1u,
215215
common::errors::PreconditionNotMet(
216216
"The size %d of inputs must be equal to 1.", input_size));
217-
PADDLE_ENFORCE_EQ((*this)
218-
->operand_source(0)
219-
.type()
220-
.isa<paddle::dialect::DistDenseTensorType>(),
217+
PADDLE_ENFORCE_EQ(!(*this)->operand_source(0) ||
218+
(*this) // reshard allow NULL TYPE as input
219+
->operand_source(0)
220+
.type()
221+
.isa<paddle::dialect::DistDenseTensorType>(),
221222
true,
222223
common::errors::PreconditionNotMet(
223224
"Type validation failed for the 0th input."));
@@ -241,7 +242,13 @@ void ReshardOp::VerifySig() {
241242
common::errors::PreconditionNotMet(
242243
"The size %d of outputs must be equal to 1.", output_size));
243244
PADDLE_ENFORCE_EQ(
244-
(*this)->result(0).type().isa<paddle::dialect::DistDenseTensorType>(),
245+
!(*this)->result(0) ||
246+
(*this)
247+
->result(0)
248+
.type()
249+
.isa<paddle::dialect::DistDenseTensorType>(), // reshard allow
250+
// NULL TYPE as
251+
// output
245252
true,
246253
common::errors::PreconditionNotMet(
247254
"Type validation failed for the 0th output."));
@@ -267,12 +274,34 @@ void ReshardOp::VerifySig() {
267274
VLOG(4) << "End Verifying for: ShardTensorOp.";
268275
}
269276

277+
ProcessMeshAttribute MergeMeshes(const ProcessMeshAttribute& mesh1,
278+
const ProcessMeshAttribute& mesh2) {
279+
if (mesh1 == mesh2) return mesh1;
280+
// Combine the two ids
281+
std::vector<int64_t> merged_ids;
282+
std::vector<int64_t> ids1 = mesh1.process_ids();
283+
std::vector<int64_t> ids2 = mesh2.process_ids();
284+
285+
merged_ids.reserve(ids1.size() + ids2.size());
286+
merged_ids.insert(merged_ids.end(), ids1.begin(), ids1.end());
287+
merged_ids.insert(merged_ids.end(), ids2.begin(), ids2.end());
288+
289+
// Remove duplicates
290+
std::sort(merged_ids.begin(), merged_ids.end());
291+
auto last = std::unique(merged_ids.begin(), merged_ids.end());
292+
merged_ids.erase(last, merged_ids.end());
293+
294+
return ProcessMeshAttribute::get(
295+
pir::IrContext::Instance(),
296+
{static_cast<int64_t>(merged_ids.size())}, // flatten mesh shape
297+
merged_ids,
298+
{"merged"});
299+
}
300+
270301
void ReshardOp::Build(pir::Builder& builder,
271302
pir::OperationArgument& argument,
272303
pir::Value input,
273304
TensorDistAttribute tensor_dist_attr) {
274-
VLOG(4) << "Start build ReshardOp";
275-
276305
paddle::dialect::DistDenseTensorType input_tensor_type;
277306
if (input.type().isa<paddle::dialect::DistDenseTensorType>()) {
278307
input_tensor_type =
@@ -288,7 +317,8 @@ void ReshardOp::Build(pir::Builder& builder,
288317
VLOG(4) << "Builder construction attributes";
289318
pir::Attribute op_dist_attr = OperationDistAttribute::get(
290319
pir::IrContext::Instance(),
291-
input_tensor_type.tensor_dist_attr().process_mesh_attr(),
320+
MergeMeshes(input_tensor_type.tensor_dist_attr().process_mesh_attr(),
321+
tensor_dist_attr.process_mesh_attr()),
292322
std::vector<pir::Attribute>{input_tensor_type.tensor_dist_attr()},
293323
std::vector<pir::Attribute>{tensor_dist_attr});
294324
argument.AddAttribute("op_dist_attr", op_dist_attr);

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ OP_SAME_OPERANDS_AND_RESULT(Print)
100100
OP_SAME_OPERANDS_AND_RESULT(PutAlongAxis)
101101
OP_SAME_OPERANDS_AND_RESULT(PutAlongAxis_)
102102
OP_SAME_OPERANDS_AND_RESULT(Real)
103+
OP_SAME_OPERANDS_AND_RESULT(Reciprocal)
104+
OP_SAME_OPERANDS_AND_RESULT(Reciprocal_)
103105
OP_SAME_OPERANDS_AND_RESULT(Relu)
104106
OP_SAME_OPERANDS_AND_RESULT(Relu6)
105107
OP_SAME_OPERANDS_AND_RESULT(Relu_)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Print)
9191
OP_DECLARE_INFER_SYMBOLIC_SHAPE(PutAlongAxis)
9292
OP_DECLARE_INFER_SYMBOLIC_SHAPE(PutAlongAxis_)
9393
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Real)
94+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reciprocal)
95+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reciprocal_)
9496
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu)
9597
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu6)
9698
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu_)

paddle/fluid/pybind/exception.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void BindException(pybind11::module* m) {
4242
try {
4343
if (p) std::rethrow_exception(p);
4444
} catch (const platform::EOFException& e) {
45-
eof(e.what());
45+
pybind11::set_error(eof, e.what());
4646
} catch (const memory::allocation::BadAlloc& e) {
4747
PyErr_SetString(PyExc_MemoryError, e.what());
4848
} catch (const platform::EnforceNotMet& e) {
@@ -77,7 +77,7 @@ void BindException(pybind11::module* m) {
7777
PyErr_SetString(PyExc_TypeError, e.what());
7878
break;
7979
default:
80-
exc(e.what());
80+
pybind11::set_error(exc, e.what());
8181
break;
8282
}
8383
}

paddle/fluid/pybind/pir.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <Python.h>
1818
#include <algorithm>
19+
#include <iterator>
1920
#include <memory>
2021
#include <sstream>
2122
#include <string>
@@ -1126,6 +1127,43 @@ struct PyInsertionPoint {
11261127
// Registers the Python `InsertionPoint` class, exposing iterator-style
// navigation over the operations of the block the point lives in.
void BindInsertionPoint(pybind11::module *m) {
  py::class_<PyInsertionPoint> ir_insertion_point(*m, "InsertionPoint", R"DOC(
InsertionPoint class represents the insertion point in the Builder.)DOC");

  ir_insertion_point
      .def(
          "next",
          [](PyInsertionPoint &point) -> Operation & {
            // point.value.first is the block; point.value.second is the
            // iterator marking the insertion position.
            auto &iter = point.value.second;
            if (iter == point.value.first->end()) {
              PADDLE_THROW(common::errors::InvalidArgument(
                  "The insertion point is already at the end and can't call "
                  "next()."));
            }
            // Yield the operation at the current position, then advance.
            return *(iter++);
          },
          return_value_policy::reference)
      .def(
          "prev",
          [](PyInsertionPoint &point) -> Operation & {
            auto &iter = point.value.second;
            if (iter == point.value.first->begin()) {
              PADDLE_THROW(common::errors::InvalidArgument(
                  "The insertion point is already at the begin and can't call "
                  "prev()."));
            }
            // NOTE(review): yields the operation at the current position and
            // then steps back (post-decrement) — confirm this ordering is the
            // intended semantics for prev().
            return *(iter--);
          },
          return_value_policy::reference)
      .def(
          "get_operation",
          [](PyInsertionPoint &point) -> Operation & {
            auto &iter = point.value.second;
            if (iter == point.value.first->begin()) {
              PADDLE_THROW(common::errors::InvalidArgument(
                  "The insertion point is already at the begin."));
            } else if (iter == point.value.first->end()) {
              PADDLE_THROW(common::errors::InvalidArgument(
                  "The insertion point is already at the end."));
            }
            // Return the operation at the insertion position without moving.
            return *(iter);
          },
          return_value_policy::reference);
}
11301168

11311169
std::list<Operation *>::const_iterator list_offset(const Block *block,

0 commit comments

Comments
 (0)