diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc index c7402d56a74cb9..7157fcf5c15ce3 100644 --- a/paddle/cinn/ir/ir.cc +++ b/paddle/cinn/ir/ir.cc @@ -407,6 +407,7 @@ const std::string &Store::name() const { } Type Store::type() const { return value.type(); } + std::vector Store::expr_fields() { std::vector exprs({&tensor, &value}); for (auto &idx : indices) exprs.push_back(&idx); @@ -776,6 +777,9 @@ Expr Reduce::Make(Reduce::ReduceType reduce_type, n->set_type(body.type()); return Expr(n); } + +Type Reduce::type() const { return body.type().ElementOf(); } + std::vector Reduce::expr_fields() { std::vector res; if (init.defined()) { @@ -802,6 +806,12 @@ void Reduce::Verify() const { CHECK_EQ(init.type(), body.type()); } +Type Select::type() const { + PADDLE_ENFORCE_EQ( + true_value.type(), false_value.type(), "Type of Select must be same"); + return true_value.type(); +} + void Select::Verify() const { CHECK(condition.defined()); CHECK(true_value.defined()); @@ -862,12 +872,16 @@ void MultiOperandVerify(llvm::ArrayRef operands) { } } +Type Product::type() const { return operands().front().type(); } + void Product::Verify() const { CHECK_GT(operands().size(), 1UL) << "Product node should have more than 1 operands"; MultiOperandVerify(operands()); } +Type Sum::type() const { return operands().front().type(); } + void Sum::Verify() const { CHECK_GT(operands().size(), 1UL) << "Sum node should have more than 1 operands"; diff --git a/paddle/cinn/ir/ir.h b/paddle/cinn/ir/ir.h index d711e93ce61abb..72429f6f0556a1 100644 --- a/paddle/cinn/ir/ir.h +++ b/paddle/cinn/ir/ir.h @@ -476,7 +476,7 @@ struct Reduce : public ExprNode { Expr body, const std::vector& reduce_axis); - Type type() const override { return body.type().ElementOf(); } + Type type() const override; std::vector expr_fields() override; std::vector expr_fields() const override; @@ -509,10 +509,7 @@ struct Select : public ExprNode