Skip to content

Commit 7b15a98

Browse files
authored
Merge pull request PaddlePaddle#21 from Superjomn/fea/init-tensor
init tensor
2 parents b20f50f + e60dcf8 commit 7b15a98

File tree

16 files changed

+212
-22
lines changed

16 files changed

+212
-22
lines changed

cinn/common/shared.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ struct Shared {
5050
inline bool operator<(const Shared& other) const { return p_ < other.p_; }
5151
inline bool operator==(const Shared& other) const { return p_ == other.p_; }
5252

53-
~Shared() { DesRef(p_); }
53+
~Shared() {
54+
DesRef(p_);
55+
p_ = nullptr;
56+
}
5457

5558
private:
5659
//! Increase the share count.

cinn/ir/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ cc_library(ir SRCS
66
function_definition.cc
77
ir_operators.cc
88
buffer.cc
9-
tensor.cc
9+
#tensor.cc
1010
function_base.cc
11-
operation.cc
11+
#operation.cc
1212
DEPS common boost
1313
)
1414

1515
cc_test(test_ir SRCS ir_test.cc DEPS ir)
1616
cc_test(test_ir_printer SRCS ir_printer_test.cc DEPS ir)
1717
cc_test(test_ir_operators SRCS ir_operators_test.cc DEPS ir)
18-
cc_test(test_tensor SRCS tensor_test.cc DEPS ir)
18+
#cc_test(test_tensor SRCS tensor_test.cc DEPS ir)
1919
cc_test(test_function SRCS function_test.cc DEPS ir)

cinn/ir/ir.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "cinn/ir/ir.h"
22
#include "cinn/common/pod_value.h"
33
#include "cinn/ir/ir_visitor.h"
4+
#include "cinn/lang/tensor.h"
45

56
namespace cinn {
67

@@ -219,6 +220,22 @@ Expr Call::Make(Type type,
219220
node->set_type(type);
220221
return Expr(node);
221222
}
223+
224+
void _Tensor_::Accept(IrVisitor *v) const { v->Visit(this); }
225+
226+
lang::Tensor _Tensor_::Make(const std::vector<Expr> &shape,
227+
const std::vector<Var> &iterators,
228+
Type dtype,
229+
ir::Expr expr) {
230+
CHECK_EQ(shape.size(), iterators.size()) << "dimension of the shape and the iterators should match";
231+
auto n = common::make_shared<_Tensor_>();
232+
n->dtype = dtype;
233+
n->shape = shape;
234+
n->expr = expr;
235+
n->iterators = iterators;
236+
return lang::Tensor(n);
237+
}
238+
222239
} // namespace ir
223240

224241
namespace common {
@@ -247,5 +264,4 @@ Value ToValue<ir::Var>(ir::Var v) {
247264
}
248265

249266
} // namespace common
250-
251267
} // namespace cinn

cinn/ir/ir.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
#include "cinn/ir/node.h"
1515

1616
namespace cinn {
17+
18+
namespace poly {
19+
class Element;
20+
} // namespace poly
21+
22+
namespace lang {
23+
class Tensor;
24+
} // namespace lang
25+
1726
namespace ir {
1827

1928
using common::Object;
@@ -505,6 +514,27 @@ class _IterVar_ : public IrNode {
505514
static const IrNodeTy _node_type_ = IrNodeTy::_Range_;
506515
};
507516

517+
class _Tensor_ : public IrNode {
518+
public:
519+
//! Shape of this tensor.
520+
std::vector<Expr> shape;
521+
//! Data type of this tensor.
522+
Type dtype;
523+
//! The expression that generate this tensor.
524+
ir::Expr expr;
525+
//! The iterators, we store the iterators to name the dimensions for better readability.
526+
std::vector<Var> iterators;
527+
//! Polyhedral element for analysis and schedule.
528+
std::unique_ptr<poly::Element> poly_element;
529+
530+
static lang::Tensor Make(const std::vector<Expr>& shape,
531+
const std::vector<Var>& iterators,
532+
Type dtype,
533+
ir::Expr expr);
534+
535+
void Accept(ir::IrVisitor* v) const override;
536+
};
537+
508538
static IterVar thread_axis(Range dom, const std::string& tag) {
509539
return _IterVar_::Make(dom, Var(tag), IterVarType::kThreadIndex, tag);
510540
}

cinn/ir/tensor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class Operation : public FunctionRef {
105105
std::string name;
106106
};
107107

108+
/*
108109
class _Tensor_ : public IrNode {
109110
public:
110111
//! The shape of the tensor.
@@ -121,6 +122,7 @@ class _Tensor_ : public IrNode {
121122
122123
static const IrNodeTy _node_type_ = IrNodeTy::_Tensor_;
123124
};
125+
*/
124126

125127
class _Operation_ : public ir::FunctionBase {
126128
public:

cinn/ir/tensor_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ TEST(Tensor, func) {
3939
}
4040

4141
} // namespace ir
42-
} // namespace cinn
42+
} // namespace cinn

cinn/lang/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
1-
cc_library(lang SRCS buffer.cc tensor.cc DEPS common ir)
1+
cc_library(lang SRCS
2+
buffer.cc
3+
tensor.cc
4+
compute.cc
5+
DEPS ir poly)
6+
7+
cc_test(test_compute SRCS compute_test.cc DEPS lang)
8+
cc_test(test_tensor2 SRCS tensor_test.cc DEPS lang)

cinn/lang/compute.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include "cinn/lang/compute.h"
2+
#include "cinn/poly/dim.h"
3+
#include "cinn/utils/functional.h"
4+
5+
namespace cinn {
6+
namespace lang {
7+
8+
using ir::Expr;
9+
10+
template <>
11+
Tensor Compute<compute_handle_1_t>(const std::vector<int>& dims, compute_handle_1_t handle) {
12+
CHECK_EQ(dims.size(), 1);
13+
14+
poly::Dim dim("i", 0, dims[0] - 1);
15+
16+
Var i("i", Int(32));
17+
auto expr = handle(i);
18+
std::vector<Expr> shape;
19+
for (int v : dims) shape.emplace_back(v);
20+
21+
Tensor tensor(shape, {i}, expr.type(), expr);
22+
return std::move(tensor);
23+
}
24+
25+
} // namespace lang
26+
} // namespace cinn

cinn/lang/compute.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#pragma once
2+
#include <functional>
3+
#include <utility>
4+
#include <vector>
5+
#include "cinn/ir/ir.h"
6+
#include "cinn/lang/tensor.h"
7+
8+
namespace cinn {
9+
namespace lang {
10+
11+
using ir::Var;
12+
using compute_handle_1_t = std::function<ir::Expr(Var i)>;
13+
using compute_handle_2_t = std::function<ir::Expr(Var i0, Var i1)>;
14+
using compute_handle_3_t = std::function<ir::Expr(Var i0, Var i1, Var i2)>;
15+
using compute_handle_4_t = std::function<ir::Expr(Var i0, Var i1, Var i2, Var i3)>;
16+
17+
/**
18+
* Compute a Tensor.
19+
* @param dims Dimensions.
20+
* @param iterators
21+
* @param handle
22+
*/
23+
template <typename Fn>
24+
Tensor Compute(const std::vector<int>& dims, Fn handle);
25+
26+
} // namespace lang
27+
} // namespace cinn

cinn/lang/compute_test.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#include "cinn/lang/compute.h"
2+
#include <gtest/gtest.h>
3+
#include "cinn/lang/tensor.h"
4+
5+
namespace cinn {
6+
namespace lang {} // namespace lang
7+
} // namespace cinn

0 commit comments

Comments
 (0)