@@ -78,6 +78,7 @@ class BlockDimExprsAsserter {
auto VisitEachInputAndDimExprs = [&](const auto& Visit) {
for (int i = 0; i < op.num_operands(); ++i) {
pir::Value input = op.operand_source(i);
+if (!input || !input.type()) continue;
const auto& value_dim_exprs = GraphDimExprs4Value(input);
Visit(input, value_dim_exprs);
}
@@ -125,6 +126,7 @@ class BlockDimExprsAsserter {
return std::visit(patterns, value_dim_exprs.variant());
};
VisitEachInputAndDimExprs([&](auto value, const auto& value_dim_exprs) {
+if (!value || !value.type()) return;
const auto& new_symbol_replaced = GetNewSymbolReplaced(value_dim_exprs);
shape_analysis->SetShapeOrDataForValue(value, new_symbol_replaced);
});
@@ -155,16 +157,19 @@ class BlockDimExprsAsserter {

void AssertDimExprForOutput(pir::Operation* op) { // NOLINT
VLOG(5) << "Add assert for result of [ " << op->name() << " ]";
+if (op->num_results() == 0) return;
if (!op->HasInterface<paddle::dialect::InferSymbolicShapeInterface>()) {
LOG(INFO) << "skip the checking for [ " << op->name() << " ]";
return;
}

auto OpDimExprs4Value = MakeOpDimExprs4Value(op);
const auto& inputs = [&] {
std::vector<pir::Value> inputs;
+inputs.reserve(op->num_operands());
for (int i = 0; i < op->num_operands(); ++i) {
const auto& input = op->operand_source(i);
+if (!input || !input.type()) continue;
if (input.type().isa<pir::VectorType>()) {
return std::vector<pir::Value>{};
}
@@ -176,18 +181,20 @@ class BlockDimExprsAsserter {
builder_.SetInsertionPointAfter(op);
for (std::size_t i = 0; i < op->num_results(); ++i) {
pir::Value output = op->result(i);
+if (!output || !output.type()) continue;
const auto& shape_or_data_dim_expr = GraphDimExprs4Value(output);
if (!shape_or_data_dim_expr.isa<symbol::TensorShapeOrDataDimExprs>())
continue;
if (shape_or_data_dim_expr.data().has_value()) {
-TryAssertDimExprsForOutputData(inputs, output, OpDimExprs4Value);
+TryAssertDimExprsForOutputData(op, inputs, output, OpDimExprs4Value);
} else {
-TryAssertDimExprsForOutputShape(inputs, output, OpDimExprs4Value);
+TryAssertDimExprsForOutputShape(op, inputs, output, OpDimExprs4Value);
}
}
}

void TryAssertDimExprsForOutputShape(
+const pir::Operation* op,
const std::vector<pir::Value>& inputs,
pir::Value output,
const DimExprs4ValueT& OpDimExprs4Value) {
@@ -203,14 +210,15 @@ class BlockDimExprsAsserter {
const auto& shape_tensor_from_dim_exprs =
opt_shape_tensor_from_dim_exprs.value();
auto shape_tensor_from_infer_meta = BuildShapeTensorFromInferMeta(output);
-AddAssertEqual(shape_tensor_from_dim_exprs, shape_tensor_from_infer_meta);
+AddAssertEqual(
+    op, shape_tensor_from_dim_exprs, shape_tensor_from_infer_meta);
}

std::optional<pir::Value> BuildShapeTensorFromShapeDimExprs(
const std::vector<pir::Value>& inputs,
pir::Value output,
const DimExprs4ValueT& OpDimExprs4Value) {
-const auto& shape_or_data = GraphDimExprs4Value(output);
+const auto& shape_or_data = OpDimExprs4Value(output);
const auto& dim_exprs = shape_or_data.shape();
return BuildShapeTensorFromDimExprs(inputs, dim_exprs, OpDimExprs4Value);
}
@@ -219,7 +227,7 @@ class BlockDimExprsAsserter {
const std::vector<pir::Value>& inputs,
pir::Value output,
const DimExprs4ValueT& OpDimExprs4Value) {
-const auto& shape_or_data = GraphDimExprs4Value(output);
+const auto& shape_or_data = OpDimExprs4Value(output);
const auto& dim_exprs = shape_or_data.data();
if (!dim_exprs.has_value()) return std::nullopt;
return BuildShapeTensorFromDimExprs(
@@ -260,13 +268,14 @@ class BlockDimExprsAsserter {
return builder_.Build<paddle::dialect::ShapeOp>(output).out();
}

-void TryAssertDimExprsForOutputData(const std::vector<pir::Value>& inputs,
+void TryAssertDimExprsForOutputData(const pir::Operation* op,
+                                    const std::vector<pir::Value>& inputs,
pir::Value output,
const DimExprs4ValueT& OpDimExprs4Value) {
auto opt_shape_tensor_from_dim_exprs =
BuildShapeTensorFromDataDimExprs(inputs, output, OpDimExprs4Value);
if (!opt_shape_tensor_from_dim_exprs.has_value()) return;
-AddAssertEqual(opt_shape_tensor_from_dim_exprs.value(), output);
+AddAssertEqual(op, opt_shape_tensor_from_dim_exprs.value(), output);
}

size_t GetNumel(pir::Value value) {
Expand All @@ -281,7 +290,9 @@ class BlockDimExprsAsserter {
return numel;
}

-void AddAssertEqual(pir::Value lhs, pir::Value rhs) {
+void AddAssertEqual(const pir::Operation* op,
+                    pir::Value lhs,
+                    pir::Value rhs) {
size_t lhs_numel = GetNumel(lhs);
size_t rhs_numel = GetNumel(rhs);
PADDLE_ENFORCE_EQ(lhs_numel,
@@ -295,7 +306,16 @@ class BlockDimExprsAsserter {
builder_.Build<paddle::dialect::EqualOp>(lhs, rhs).out();
pir::Value all_eq =
builder_.Build<paddle::dialect::AllOp>(lhs_eq_rhs).out();
-builder_.Build<paddle::dialect::AssertOp>(all_eq, lhs_eq_rhs, lhs_numel);
+pir::Value assert_data =
+    builder_.Build<pir::CombineOp>(std::vector<pir::Value>{lhs, rhs}).out();
+auto assert_op = builder_.Build<paddle::dialect::AssertOp>(
+    all_eq, assert_data, lhs_numel);
+const std::string error_msg = "Check [" + op->name() + "_" +
+    std::to_string(op->id()) +
+    "] infer symbolic shape failed.";
+assert_op->set_attribute(
+    paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME,
+    pir::StrAttribute::get(pir::IrContext::Instance(), error_msg));
}

DimExprs4ValueT GraphDimExprs4Value;
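
Taken together, the hunks above thread the op under inspection through TryAssertDimExprsForOutputShape/Data into AddAssertEqual, so the emitted pd_op.assert carries an error_info attribute naming the offending op. A minimal standalone sketch of the runtime check that IR encodes (hypothetical helper, plain vectors instead of pir::Value tensors):

#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Rough analogue of the EqualOp -> AllOp -> AssertOp chain built by
// AddAssertEqual: compare the shape derived from dim exprs (lhs) with
// the InferMeta result (rhs), raising the op-specific message on mismatch.
void AssertShapesEqual(const std::string& op_name, std::size_t op_id,
                       const std::vector<int64_t>& lhs,
                       const std::vector<int64_t>& rhs) {
  bool all_eq = (lhs.size() == rhs.size());
  for (std::size_t i = 0; all_eq && i < lhs.size(); ++i) {
    all_eq = (lhs[i] == rhs[i]);
  }
  if (!all_eq) {
    std::ostringstream msg;
    msg << "Check [" << op_name << "_" << op_id
        << "] infer symbolic shape failed.";
    throw std::runtime_error(msg.str());
  }
}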
@@ -60,21 +60,19 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y);
const auto& out_shape = shape_analysis.GetShapeOrDataForValue(op->result(0));

-if (x_shape == y_shape) {
+if (x_shape.shape() == y_shape.shape()) {
return false;
}

pir::Value output_dim_tensor =
GetOutputDimTensor(rewriter, x, y, &shape_analysis);
-if (x_shape.shape() != out_shape.shape() ||
-    x_shape.data() != out_shape.data()) {
+if (x_shape.shape() != out_shape.shape()) {
pir::Value broadcasted_x =
rewriter->Build<paddle::dialect::ExpandOp>(x, output_dim_tensor).out();
op->operand(0).set_source(broadcasted_x);
shape_analysis.SetShapeOrDataForValue(broadcasted_x, out_shape);
}
-if (y_shape.shape() != out_shape.shape() ||
-    y_shape.data() != out_shape.data()) {
+if (y_shape.shape() != out_shape.shape()) {
pir::Value broadcasted_y =
rewriter->Build<paddle::dialect::ExpandOp>(y, output_dim_tensor).out();
op->operand(1).set_source(broadcasted_y);
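
The change above deliberately compares only the symbolic shapes, so a mismatch confined to the optional data() annotations no longer triggers an ExpandOp. A reduced sketch of the relaxed predicate, modeling dims as strings rather than symbol::DimExpr (names hypothetical):

#include <string>
#include <vector>

// Broadcast an operand iff its symbolic dims differ from the output's;
// constant data() attached to the shapes is intentionally ignored.
bool NeedsExplicitBroadcast(const std::vector<std::string>& operand_dims,
                            const std::vector<std::string>& out_dims) {
  return operand_dims != out_dims;
}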
@@ -193,8 +193,10 @@ struct CachedDimExprToValueConverter {
pir::Value prod = ConvertToValue(operands->at(0));
for (int i = 1; i < operands->size(); ++i) {
if (operands->at(i).isa<symbol::Reciprocal<symbol::DimExpr>>()) {
-const auto& [operand] =
-    *operands->at(i).dyn_cast<symbol::Negative<symbol::DimExpr>>();
+const auto& operand =
+    operands->at(i)
+        .dyn_cast<symbol::Reciprocal<symbol::DimExpr>>()
+        ->data;
pir::Value operand_value = ConvertToValue(operand);
prod = rewriter->Build<paddle::dialect::DivideOp>(prod, operand_value)
.out();
@@ -218,7 +220,8 @@ struct CachedDimExprToValueConverter {
pir::Value max = ConvertToValue(operands->at(0));
for (int i = 1; i < operands->size(); ++i) {
pir::Value operand_value = ConvertToValue(operands->at(i));
-max = rewriter->Build<paddle::dialect::MaxOp>(max, operand_value).out();
+max =
+    rewriter->Build<paddle::dialect::MaximumOp>(max, operand_value).out();
}
return max;
}
@@ -234,7 +237,8 @@ struct CachedDimExprToValueConverter {
pir::Value min = ConvertToValue(operands->at(0));
for (int i = 1; i < operands->size(); ++i) {
pir::Value operand_value = ConvertToValue(operands->at(i));
-min = rewriter->Build<paddle::dialect::MinOp>(min, operand_value).out();
+min =
+    rewriter->Build<paddle::dialect::MinimumOp>(min, operand_value).out();
}
return min;
}
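
The first hunk above fixes a copy-paste bug: the branch matches Reciprocal but unwrapped the value with dyn_cast to Negative, the wrong variant alternative. The other two hunks switch the converter to the elementwise MaximumOp/MinimumOp (MaxOp/MinOp name reductions in the paddle dialect). A reduced sketch of the corrected Mul fold, with std::variant standing in for symbol::DimExpr (all names hypothetical):

#include <cstddef>
#include <memory>
#include <variant>
#include <vector>

struct Reciprocal;  // stand-in for symbol::Reciprocal<symbol::DimExpr>
using Expr = std::variant<long, std::shared_ptr<Reciprocal>>;
struct Reciprocal { Expr data; };

long ToValue(const Expr& e) {
  if (const long* v = std::get_if<long>(&e)) return *v;
  return 1;  // nested reciprocal cases elided in this sketch
}

// Unwrap Reciprocal through its own alternative and divide; multiply
// for every other operand, mirroring the fixed loop.
long FoldMul(const std::vector<Expr>& operands) {
  long prod = ToValue(operands.at(0));
  for (std::size_t i = 1; i < operands.size(); ++i) {
    if (auto* r = std::get_if<std::shared_ptr<Reciprocal>>(&operands[i])) {
      prod /= ToValue((*r)->data);  // DivideOp in the real converter
    } else {
      prod *= ToValue(operands[i]);  // MultiplyOp in the real converter
    }
  }
  return prod;
}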
@@ -82,11 +82,20 @@ void AssertInstruction::Run() {
value_exe_info_->GetVarByValue(val)->Get<phi::DenseTensor>();
formatter.Print(tensor, name);
}

+const std::string& error_msg = [&]() -> std::string {
+  if (op_->HasAttribute(paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME)) {
+    return op_
+        ->attribute<pir::StrAttribute>(
+            paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME)
+        .AsString();
+  }
+  return {};
+}();
PADDLE_THROW(platform::errors::InvalidArgument(
"The condition variable '%s' of AssertOp must be "
"true, but received false",
value_exe_info_->GetVarName(cond_var_)));
"true, but received false. %s",
value_exe_info_->GetVarName(cond_var_),
error_msg));
}

} // namespace framework
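
On the interpreter side, the message is recovered with an immediately invoked lambda so error_msg stays empty when the attribute was never set. The same fallback pattern in isolation (hypothetical attribute map in place of the pir op):

#include <map>
#include <string>

std::string GetErrorInfo(const std::map<std::string, std::string>& attrs) {
  // Immediately invoked lambda: yields the attribute when present, an
  // empty string otherwise, and lets the result bind to a single const.
  const std::string error_msg = [&]() -> std::string {
    auto it = attrs.find("error_info");
    return it != attrs.end() ? it->second : std::string{};
  }();
  return error_msg;
}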
@@ -26,37 +26,17 @@ bool ArangeOpInferSymbolicShape(
const auto &step_shape_or_data =
shape_analysis->GetShapeOrDataForValue(op->operand_source(2));

-const auto start = [&] {
-  symbol::DimExpr expr;
-  if (start_shape_or_data.data().has_value()) {
-    expr = start_shape_or_data.data().value()[0];
-  } else {
-    expr = start_shape_or_data.shape()[0];
-  }
-  return expr;
-}();
-
-const auto end = [&] {
-  symbol::DimExpr expr;
-  if (end_shape_or_data.data().has_value()) {
-    expr = end_shape_or_data.data().value()[0];
-  } else {
-    expr = end_shape_or_data.shape()[0];
-  }
-  return expr;
-}();
-
-const auto step = [&] {
-  symbol::DimExpr expr;
-  if (step_shape_or_data.data().has_value()) {
-    expr = step_shape_or_data.data().value()[0];
-  } else {
-    expr = step_shape_or_data.shape()[0];
-  }
-  return expr;
-}();
-
const symbol::ShapeOrDataDimExprs &shape_data = [&] {
+  if (!start_shape_or_data.data().has_value() ||
+      !end_shape_or_data.data().has_value() ||
+      !step_shape_or_data.data().has_value()) {
+    return symbol::ShapeOrDataDimExprs{
+        symbol::TensorShapeOrDataDimExprs(std::vector<symbol::DimExpr>{
+            symbol::DimExpr(shape_analysis->GetNextSymName())})};
+  }
+  const auto &start = start_shape_or_data.data()->at(0);
+  const auto &end = end_shape_or_data.data()->at(0);
+  const auto &step = step_shape_or_data.data()->at(0);
std::vector<symbol::DimExpr> out_dims;
// TODO(lanxianghit, jiahy0825): here should be ceil((end - start) / step),
// but DimExpr doesn't support ceil and float now
@@ -135,10 +115,32 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
return sym_dims;
}();

-symbol::ShapeOrDataDimExprs shape_data{
-    symbol::TensorShapeOrDataDimExprs(sym_dims)};
+auto IsOneNumel = [&](pir::Value value) {
+  const auto &dims = value.type().dyn_cast<pir::DenseTensorType>().dims();
+  if (dims.size() == 1 && dims[0] == 1) {
+    return true;
+  }
+  return false;
+};
+
+auto IsIntType = [&](pir::Value value) {
+  const auto &dtype = value.type().dyn_cast<pir::DenseTensorType>().dtype();
+  return dtype.isa<pir::Int32Type>() || dtype.isa<pir::Int64Type>();
+};
+
+const auto &shape_or_data = [&]() {
+  if (IsOneNumel(op->result(0)) && IsIntType(op->result(0))) {
+    std::vector<symbol::DimExpr> data{
+        symbol::DimExpr(shape_analysis->GetNextSymName())};
+    return symbol::ShapeOrDataDimExprs{
+        symbol::TensorShapeOrDataDimExprs(sym_dims, data)};
+  } else {
+    return symbol::ShapeOrDataDimExprs{
+        symbol::TensorShapeOrDataDimExprs(sym_dims)};
+  }
+}();

-shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data);
+shape_analysis->SetShapeOrDataForValue(op->result(0), shape_or_data);

return true;
}
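
The DataOp hunk gives single-element integer results a symbolic data entry alongside their shape, since such tensors usually feed later shape computations. A compressed sketch of that decision (ShapeOrData here is a hypothetical stand-in, not the real symbol types):

#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct ShapeOrData {
  std::vector<std::string> shape;                // symbolic dims
  std::optional<std::vector<std::string>> data;  // symbolic element values
};

// Mirrors IsOneNumel && IsIntType: a 1-D, one-element int tensor also
// receives a fresh data symbol so downstream ops can fold its value.
ShapeOrData MakeDataOpResult(int64_t rank, int64_t dim0, bool is_int,
                             std::vector<std::string> sym_dims) {
  if (rank == 1 && dim0 == 1 && is_int) {
    return {std::move(sym_dims), std::vector<std::string>{"S_next"}};
  }
  return {std::move(sym_dims), std::nullopt};
}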
@@ -164,9 +166,17 @@ bool EmptyOpInferSymbolicShape(pir::Operation *op,
pir::Value operand_source = op->operand_source(0);
const symbol::ShapeOrDataDimExprs &operand_shape_or_data =
shape_analysis->GetShapeOrDataForValue(operand_source);

-shape_analysis->SetShapeOrDataForValue(op->result(0),
-                                       operand_shape_or_data);
+PADDLE_ENFORCE_EQ(
+    operand_shape_or_data.data().has_value(),
+    true,
+    common::errors::InvalidArgument(
+        "The data of input dim_expr shape is null. When input of empty op "
+        "is a tensor, the data of input dim_expr shape must have value."));
+
+shape_analysis->SetShapeOrDataForValue(
+    op->result(0),
+    symbol::TensorShapeOrDataDimExprs{
+        operand_shape_or_data.data().value()});
return true;
}
}
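
For empty with a tensor shape argument, the output's dims are the contents of that argument, so the pass now requires the data() to be known instead of silently reusing the operand's own shape. Roughly, reusing the hypothetical ShapeOrData stand-in from the previous sketch:

#include <stdexcept>

// empty(shape_tensor): output dims are the *values* held by shape_tensor;
// a missing data() is now a hard error, not a silent fallback.
ShapeOrData InferEmptyFromShapeTensor(const ShapeOrData& shape_tensor) {
  if (!shape_tensor.data.has_value()) {
    throw std::invalid_argument(
        "The data of input dim_expr shape is null.");
  }
  return {shape_tensor.data.value(), std::nullopt};
}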
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -852,6 +852,7 @@ void HasElementsOp::VerifySig() {
}

const char *AssertOp::attributes_name[1] = {"summarize"};
+const char AssertOp::ERROR_INFO_ATTR_NAME[] = "error_info";

void AssertOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
@@ -177,6 +177,7 @@ class AssertOp
: public pir::Op<AssertOp, OpYamlInfoInterface, pir::SideEffectTrait> {
public:
using Op::Op;
+static const char ERROR_INFO_ATTR_NAME[];
static const char *name() { return "pd_op.assert"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[1];
2 changes: 1 addition & 1 deletion test/ir/pir/cinn/symbolic/CMakeLists.txt
@@ -34,7 +34,7 @@ if(WITH_GPU)
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/:$ENV{PYTHONPATH}
FLAGS_check_infer_symbolic=1 FLAGS_enable_pir_api=1
FLAGS_cinn_bucket_compile=True FLAGS_prim_enable_dynamic=true
-FLAGS_pir_apply_shape_optimization_pass=1
+FLAGS_prim_all=True FLAGS_pir_apply_shape_optimization_pass=1
FLAGS_group_schedule_tiling_first=1 FLAGS_cinn_new_group_scheduler=1
${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_SOURCE_DIR}/${cinn_pir_test_name}.py
14 changes: 7 additions & 7 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_nullary_op.py
@@ -46,17 +46,17 @@ def forward(self, in_0, in_1, in_2):

class ArangeOpInferSymbolicShapeTest(TestBase):
def prepare_data(self):
-self.start = paddle.full([1], 0)
-self.end = paddle.full([1], 5)
-self.step = paddle.full([1], 1)
+self.start = paddle.full([1], 0, dtype='int32')
+self.end = paddle.full([1], 5, dtype='int32')
+self.step = paddle.full([1], 1, dtype='int32')
self.expected = ['shape[Mul(Add(S1, -S0), 1 / (S2))], data[NULL]']

def test_eval_symbolic(self):
net = ArangeNet()
input_spec = [
-InputSpec(shape=[None], dtype='float32'),
-InputSpec(shape=[None], dtype='float32'),
-InputSpec(shape=[None], dtype='float32'),
+InputSpec(shape=[1], dtype='int32'),
+InputSpec(shape=[1], dtype='int32'),
+InputSpec(shape=[1], dtype='int32'),
]
net = apply_to_static(net, False, input_spec)
net.eval()
@@ -100,7 +100,7 @@ def __init__(self):

def forward(self, x):
out = paddle.empty(shape=[128, 32])
-out = paddle.empty(shape=x)
+out = paddle.empty(shape=x.shape)
return out


6 changes: 3 additions & 3 deletions test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py
@@ -635,9 +635,9 @@ class SplitOpInferSymbolicShapeTest(TestBase):
def prepare_data(self):
self.cases = [np.random.rand(4, 6, 5)]
self.expected = [
-'shape[S0, S1, S2], data[NULL]',
-'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, Add(S1, -3), S2], data[NULL]',
-'shape[S0, 1, S2], data[NULL], shape[S0, Add(S1, -1), S2], data[NULL]',
+'shape[S0, 6, S2], data[NULL]',
+'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
+'shape[S0, 1, S2], data[NULL], shape[S0, 5, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',
'shape[S0, 6, S2], data[NULL]',
'shape[S0, 1, S2], data[NULL], shape[S0, 2, S2], data[NULL], shape[S0, 3, S2], data[NULL]',