33#include < cstring>
44
55#include " cinn/common/common.h"
6+ #include " cinn/ir/ir_operators.h"
67#include " cinn/ir/ir_visitor.h"
78#include " cinn/ir/operation.h"
89#include " cinn/poly/stage.h"
910
1011namespace cinn {
1112namespace ir {
1213
14+ namespace detail {
15+
16+ Expr ExpandTo1DIndice (const std::vector<Expr> &shape, const std::vector<Expr> &indices) {
17+ CHECK_EQ (shape.size (), indices.size ());
18+ Expr res = indices.front () * shape[1 ];
19+ for (int i = 1 ; i < shape.size () - 1 ; i++) {
20+ res = res + indices[i] * shape[i + 1 ];
21+ }
22+ if (shape.size () > 1 ) res = res + indices.back ();
23+ return res;
24+ }
25+
26+ Expr ExpandTo1DIndice (const std::vector<int > &shape, const std::vector<Expr> &indices) {
27+ std::vector<Expr> shape_;
28+ for (int v : shape) shape_.push_back (Expr (v));
29+ return ExpandTo1DIndice (shape, indices);
30+ }
31+
32+ } // namespace detail
33+
1334Tensor _Tensor_::Make (const std::string &name, const std::vector<Expr> &shape, FunctionRef fn) {
35+ CHECK (!shape.empty ()) << " Tensor shape is set empty" ;
36+ CHECK (!name.empty ()) << " Tensor name is set empty" ;
1437 auto n = make_shared<_Tensor_>();
1538 n->name = name;
1639 n->shape = shape;
1740 n->operaion = fn;
1841 n->InitStage ();
42+ n->InitAxis ();
43+ n->SetDefaultBindedBuffer ();
1944 return Tensor (n);
2045}
2146
@@ -26,8 +51,13 @@ Tensor _Tensor_::Make(const std::string &name,
2651 Type dtype,
2752 const std::map<std::string, IrNodeRef> &attrs,
2853 const std::vector<Expr> &body) {
54+ CHECK (!shape.empty ()) << " Tensor shape is set empty" ;
55+ CHECK (!name.empty ()) << " Tensor name is set empty" ;
56+
2957 auto op = ComputeOp::Make (name, tag, attrs, axis, body, shape);
3058 auto *compute_op = const_cast <ComputeOp *>(op->As <ComputeOp>());
59+
60+ CHECK_EQ (axis.size (), shape.size ()) << " axis not match the dimension in shape" ;
3161 compute_op->axis = axis;
3262
3363 auto n = make_shared<_Tensor_>();
@@ -36,6 +66,7 @@ Tensor _Tensor_::Make(const std::string &name,
3666 n->shape = shape;
3767 n->set_type (dtype);
3868 n->InitStage ();
69+ n->SetDefaultBindedBuffer ();
3970 return Tensor (n);
4071}
4172
@@ -76,6 +107,12 @@ void _Tensor_::InitStage() {
76107 }
77108}
78109
110+ void _Tensor_::InitAxis () {
111+ CHECK (!shape.empty ());
112+ CHECK (axis.empty ()) << " duplicate init axis" ;
113+ axis = common::GenDefaultAxis (shape.size ());
114+ }
115+
79116isl::set _Tensor_::GenerateIslDomain () {
80117 CHECK (!shape.empty ()) << " shape should be set" ;
81118 std::vector<poly::Dim> dims;
@@ -138,5 +175,12 @@ Expr _Tensor_::body() const {
138175 NOT_IMPLEMENTED;
139176}
140177
178+ Expr _Tensor_::tensor_store_expanded_body () const {
179+ CHECK (!is_placeholder_node ()) << " placeholder should not expand store" ;
180+ std::vector<Expr> axis_;
181+ for (auto &a : axis) axis_.push_back (Expr (a));
182+ return ir::Store::Make (buffer_var, body (), detail::ExpandTo1DIndice (shape, axis_));
183+ }
184+
141185} // namespace ir
142186} // namespace cinn
0 commit comments