Skip to content

Commit 91c9065

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into custom_stride
2 parents c005dbb + 8a523ee commit 91c9065

31 files changed

Lines changed: 673 additions & 226 deletions

paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
2222
#include "paddle/cinn/hlir/framework/pir/utils.h"
2323
#include "paddle/common/ddim.h"
24+
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h"
2425
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
2526
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
2627
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
@@ -107,8 +108,12 @@ bool MakeGenerateShapeOpAttribute(
107108
std::vector<pir::Attribute>* output_dim_expr_attrs,
108109
GenerateShapeOp::SymbolBindings* symbol_bindings) {
109110
const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(output_shape);
110-
CHECK(shape_or_data_dim_exprs.data().has_value());
111-
const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value();
111+
ExprVec data_vec =
112+
paddle::dialect::details::GetExprVecFromData(shape_or_data_dim_exprs);
113+
// CHECK(shape_or_data_dim_exprs.data().has_value());
114+
CHECK(data_vec.size());
115+
// const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value();
116+
const auto& out_dim_exprs = data_vec;
112117
return MakeGenerateShapeOpAttribute(ir_context,
113118
ShapeOrDataDimExprs4Value,
114119
out_dim_exprs,

paddle/cinn/hlir/framework/pir/group.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ std::shared_ptr<Group> Group::Clone(::pir::Block* target_block,
5252

5353
new_group->input_names = this->input_names;
5454
new_group->output_names = this->output_names;
55-
new_group->output_values = this->output_values;
5655
new_group->fn_name = this->fn_name;
5756
new_group->int_args_map = this->int_args_map;
5857
new_group->alignment_schedule_info = this->alignment_schedule_info;

paddle/fluid/framework/operator.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
17041704
all_kernels_must_compute_runtime_shape_ = true;
17051705
const Scope* cur_scope = &scope;
17061706
CheckWhetherPreparePhiData(Inputs(), Outputs(), scope);
1707+
#if defined(PADDLE_WITH_XPU)
1708+
if (std::getenv("XPU_NEED_PREPARE_PHI_DATA") != nullptr) {
1709+
need_prepare_phi_data_ = atoi(std::getenv("XPU_NEED_PREPARE_PHI_DATA"));
1710+
}
1711+
#endif
17071712
if (!enable_cache_runtime_context_) {
17081713
RuntimeContext ctx(Inputs(), Outputs(), scope);
17091714
RunImpl(scope, place, &ctx);

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ bool InferSymbolicShapeElementWiseBinary(
2323
// For ElementWiseBinary ops, if the input tensor is from full op, the value
2424
// of fullop is useless, only the shape need doing broadcast
2525
bool x_from_fullop =
26-
op->operand_source(0).defining_op()->isa<paddle::dialect::FullOp>();
26+
op->operand_source(0).defining_op()
27+
? op->operand_source(0).defining_op()->isa<paddle::dialect::FullOp>()
28+
: false;
2729
if (!x_from_fullop && x_shapeordata.data().has_value()) {
2830
shape_0 = x_shapeordata.data().value();
2931
} else {
@@ -34,7 +36,9 @@ bool InferSymbolicShapeElementWiseBinary(
3436
shape_analysis->GetShapeOrDataForValue(op->operand_source(1));
3537
std::vector<symbol::DimExpr> shape_1;
3638
bool y_from_fullop =
37-
op->operand_source(1).defining_op()->isa<paddle::dialect::FullOp>();
39+
op->operand_source(1).defining_op()
40+
? op->operand_source(1).defining_op()->isa<paddle::dialect::FullOp>()
41+
: false;
3842
if (!y_from_fullop && y_shapeordata.data().has_value()) {
3943
shape_1 = y_shapeordata.data().value();
4044
} else {
@@ -224,4 +228,12 @@ bool Remainder_OpInferSymbolicShape(
224228
return InferSymbolicShapeElementWiseBinary(op, shape_analysis);
225229
}
226230

231+
bool SubtractOpInferSymbolicShape(
232+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
233+
return InferSymbolicShapeElementWiseBinary(op, shape_analysis);
234+
}
235+
bool Subtract_OpInferSymbolicShape(
236+
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
237+
return InferSymbolicShapeElementWiseBinary(op, shape_analysis);
238+
}
227239
} // namespace paddle::dialect

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(NotEqual)
5353
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NotEqual_)
5454
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Remainder)
5555
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Remainder_)
56+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Subtract)
57+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Subtract_)
5658

5759
} // namespace paddle::dialect

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,22 @@ std::vector<T> GetVectorAttr(const ::pir::Operation *op,
7575
return vec_res;
7676
}
7777

78+
inline ExprVec GetExprVecFromData(const ShapeOrData &shapeordata) {
79+
if (shapeordata.isa<TensorListExprs>()) {
80+
ExprVec result;
81+
TensorListExprs list =
82+
shapeordata.dyn_cast<symbol::TensorListShapeOrDataDimExprs>();
83+
for (size_t i = 0; i < list.size(); i++) {
84+
for (auto expr : list[i].data().value()) {
85+
result.emplace_back(expr);
86+
}
87+
}
88+
return result;
89+
} else {
90+
return shapeordata.data().value();
91+
}
92+
}
93+
7894
std::optional<std::vector<int64_t>> VecExpr2Int64(const ExprVec &expr_vec);
7995

8096
ExprVec VecInt642Expr(const std::vector<int64_t> &int_vec);

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,21 @@ bool ConcatOpInferSymbolicShape(
289289
axis = axis >= 0 ? axis : std::max(int64_t(0), int64_t(axis + rank));
290290

291291
if (shape_data_list[0].data().has_value()) {
292+
if (rank == 1) {
293+
ExprVec data = details::GetExprVecFromData(
294+
shape_analysis->GetShapeOrDataForValue(operand_source));
295+
const std::vector<symbol::DimExpr> shape{std::int64_t(data.size())};
296+
symbol::ShapeOrDataDimExprs shape_data{
297+
symbol::TensorShapeOrDataDimExprs(shape, data)};
298+
pir::Value res = op->result(0);
299+
shape_analysis->SetShapeOrDataForValue(res, shape_data);
300+
301+
return true;
302+
} else {
303+
PADDLE_THROW(phi::errors::Unimplemented(
304+
op->name() +
305+
" 's InferSymbolicShape can NOT deal with rank > 1 now."));
306+
}
292307
std::vector<symbol::DimExpr> data;
293308
data.reserve(shape_data_list.size());
294309
for (auto &data_elem : shape_data_list) {
@@ -436,9 +451,9 @@ bool SqueezeOpInferSymbolicShape(
436451
if (in_dims_sym[current] == 1) {
437452
should_squeeze[current] = true;
438453
} else if (!in_dims_sym[current].Has<std::int64_t>()) {
439-
PADDLE_THROW(
440-
phi::errors::Unimplemented("SqueezeOpInferSymbolicShape CAN NOT "
441-
"deal with symbol in axis now"));
454+
should_squeeze[current] = true;
455+
} else {
456+
should_squeeze[current] = true;
442457
}
443458
}
444459
}

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.cc

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -379,14 +379,7 @@ bool Sinh_OpInferSymbolicShape(pir::Operation *op,
379379
pir::ShapeConstraintIRAnalysis *shape_analysis) {
380380
return SameOperandsAndResultShape(op, shape_analysis);
381381
}
382-
bool SubtractOpInferSymbolicShape(
383-
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
384-
return SameOperandsAndResultShape(op, shape_analysis);
385-
}
386-
bool Subtract_OpInferSymbolicShape(
387-
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
388-
return SameOperandsAndResultShape(op, shape_analysis);
389-
}
382+
390383
bool TanOpInferSymbolicShape(pir::Operation *op,
391384
pir::ShapeConstraintIRAnalysis *shape_analysis) {
392385
return SameOperandsAndResultShape(op, shape_analysis);

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin)
105105
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_)
106106
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sinh)
107107
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sinh_)
108-
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Subtract)
109-
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Subtract_)
110108
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tan)
111109
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tan_)
112110
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tanh)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ bool Cumsum_OpInferSymbolicShape(
165165
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
166166
return CumsumOpInferSymbolicShape(op, shape_analysis);
167167
}
168+
168169
bool DiagEmbedOpInferSymbolicShape(
169170
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
170171
pir::Value operand_source = op->operand_source(0);
@@ -280,6 +281,7 @@ bool KthvalueOpInferSymbolicShape(
280281
shape_analysis->SetShapeOrDataForValue(op->result(1), shape_data);
281282
return true;
282283
}
284+
283285
bool ReshapeOpInferSymbolicShape(
284286
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
285287
pir::Value operand_source = op->operand_source(0);
@@ -329,10 +331,11 @@ bool ReshapeOpInferSymbolicShape(
329331
const auto &numel =
330332
GetProduct(original_shape, [](const auto &) { return true; });
331333

334+
ExprVec target_shape = details::GetExprVecFromData(operand_shape_or_data);
332335
const auto &product_exclude_minus_one =
333-
GetProduct(operand_shape_or_data.data().value(), IsNotMinusOne);
336+
GetProduct(target_shape, IsNotMinusOne);
334337

335-
const auto &input_dims = operand_shape_or_data.data().value();
338+
const auto &input_dims = target_shape;
336339

337340
std::vector<symbol::DimExpr> out_dims;
338341
out_dims.reserve(input_dims.size());

0 commit comments

Comments
 (0)