Skip to content

Commit 511b6e2

Browse files
authored
Merge pull request #3900 from QiJune/dim_int64
make dim int to int64_t
2 parents b3afe30 + 52f2bc1 commit 511b6e2

File tree

12 files changed

+105
-92
lines changed

12 files changed

+105
-92
lines changed

paddle/framework/ddim.cc

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ namespace framework {
2121
/// @cond HIDDEN
2222

2323
template <int i>
24-
Dim<i> make_dim(const int* d) {
24+
Dim<i> make_dim(const int64_t* d) {
2525
return Dim<i>(*d, make_dim<i - 1>(d + 1));
2626
}
2727

2828
template <>
29-
Dim<1> make_dim<1>(const int* d) {
29+
Dim<1> make_dim<1>(const int64_t* d) {
3030
return Dim<1>(*d);
3131
}
3232

33-
void make_ddim(DDim& ddim, const int* dims, int n) {
33+
void make_ddim(DDim& ddim, const int64_t* dims, int n) {
3434
switch (n) {
3535
case 1:
3636
ddim = make_dim<1>(dims);
@@ -67,39 +67,39 @@ void make_ddim(DDim& ddim, const int* dims, int n) {
6767

6868
/// @endcond
6969

70-
DDim make_ddim(std::initializer_list<int> dims) {
70+
DDim make_ddim(std::initializer_list<int64_t> dims) {
7171
DDim result(make_dim(0));
7272
make_ddim(result, dims.begin(), dims.size());
7373
return result;
7474
}
7575

76-
DDim make_ddim(const std::vector<int>& dims) {
76+
DDim make_ddim(const std::vector<int64_t>& dims) {
7777
DDim result(make_dim(0));
7878
make_ddim(result, &dims[0], dims.size());
7979
return result;
8080
}
8181

8282
/// @cond HIDDEN
8383
// XXX For some reason, putting this in an anonymous namespace causes errors
84-
class DynamicMutableIndexer : public boost::static_visitor<int&> {
84+
class DynamicMutableIndexer : public boost::static_visitor<int64_t&> {
8585
public:
8686
explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
8787

8888
template <int D>
89-
int& operator()(Dim<D>& dim) const {
89+
int64_t& operator()(Dim<D>& dim) const {
9090
return dim[idx_];
9191
}
9292

9393
private:
9494
int idx_;
9595
};
9696

97-
class DynamicConstIndexer : public boost::static_visitor<int> {
97+
class DynamicConstIndexer : public boost::static_visitor<int64_t> {
9898
public:
9999
explicit DynamicConstIndexer(int idx) : idx_(idx) {}
100100

101101
template <int D>
102-
int operator()(const Dim<D>& dim) const {
102+
int64_t operator()(const Dim<D>& dim) const {
103103
return dim[idx_];
104104
}
105105

@@ -109,22 +109,22 @@ class DynamicConstIndexer : public boost::static_visitor<int> {
109109

110110
/// @endcond
111111

112-
int& DDim::operator[](int idx) {
112+
int64_t& DDim::operator[](int idx) {
113113
return boost::apply_visitor(DynamicMutableIndexer(idx), var);
114114
}
115115

116-
int DDim::operator[](int idx) const {
116+
int64_t DDim::operator[](int idx) const {
117117
return boost::apply_visitor(DynamicConstIndexer(idx), var);
118118
}
119119

120-
ssize_t DDim::size() const { return arity(*this); }
120+
int64_t DDim::size() const { return arity(*this); }
121121

122122
bool DDim::operator==(DDim d) const {
123123
if (var.which() != d.getVar().which()) {
124124
return false;
125125
} else {
126-
std::vector<int> v1 = vectorize(*this);
127-
std::vector<int> v2 = vectorize(d);
126+
std::vector<int64_t> v1 = vectorize(*this);
127+
std::vector<int64_t> v2 = vectorize(d);
128128

129129
for (unsigned int i = 0; i < v1.size(); i++) {
130130
if (v1[i] != v2[i]) {
@@ -139,10 +139,10 @@ bool DDim::operator==(DDim d) const {
139139
bool DDim::operator!=(DDim d) const { return !(*this == d); }
140140

141141
DDim DDim::operator+(DDim d) const {
142-
std::vector<int> v1 = vectorize(*this);
143-
std::vector<int> v2 = vectorize(d);
142+
std::vector<int64_t> v1 = vectorize(*this);
143+
std::vector<int64_t> v2 = vectorize(d);
144144

145-
std::vector<int> v3;
145+
std::vector<int64_t> v3;
146146

147147
assert(v1.size() == v2.size());
148148

@@ -154,10 +154,10 @@ DDim DDim::operator+(DDim d) const {
154154
}
155155

156156
DDim DDim::operator*(DDim d) const {
157-
std::vector<int> v1 = vectorize(*this);
158-
std::vector<int> v2 = vectorize(d);
157+
std::vector<int64_t> v1 = vectorize(*this);
158+
std::vector<int64_t> v2 = vectorize(d);
159159

160-
std::vector<int> v3;
160+
std::vector<int64_t> v3;
161161

162162
assert(v1.size() == v2.size());
163163

@@ -168,15 +168,15 @@ DDim DDim::operator*(DDim d) const {
168168
return make_ddim(v3);
169169
}
170170

171-
int get(const DDim& ddim, int idx) { return ddim[idx]; }
171+
int64_t get(const DDim& ddim, int idx) { return ddim[idx]; }
172172

173173
void set(DDim& ddim, int idx, int value) { ddim[idx] = value; }
174174

175175
/// @cond HIDDEN
176176
struct VectorizeVisitor : public boost::static_visitor<> {
177-
std::vector<int>& vector;
177+
std::vector<int64_t>& vector;
178178

179-
explicit VectorizeVisitor(std::vector<int>& v) : vector(v) {}
179+
explicit VectorizeVisitor(std::vector<int64_t>& v) : vector(v) {}
180180

181181
template <typename T>
182182
void operator()(const T& t) {
@@ -188,31 +188,31 @@ struct VectorizeVisitor : public boost::static_visitor<> {
188188
};
189189
/// @endcond
190190

191-
std::vector<int> vectorize(const DDim& ddim) {
192-
std::vector<int> result;
191+
std::vector<int64_t> vectorize(const DDim& ddim) {
192+
std::vector<int64_t> result;
193193
VectorizeVisitor visitor(result);
194194
boost::apply_visitor(visitor, ddim);
195195
return result;
196196
}
197197

198-
struct ProductVisitor : public boost::static_visitor<ssize_t> {
198+
struct ProductVisitor : public boost::static_visitor<int64_t> {
199199
template <int D>
200-
ssize_t operator()(const Dim<D>& dim) {
200+
int64_t operator()(const Dim<D>& dim) {
201201
return product(dim);
202202
}
203203
};
204204

205-
ssize_t product(const DDim& ddim) {
205+
int64_t product(const DDim& ddim) {
206206
ProductVisitor visitor;
207207
return boost::apply_visitor(visitor, ddim);
208208
}
209209

210210
struct SliceVectorizeVisitor : public boost::static_visitor<> {
211-
std::vector<int>& vector;
211+
std::vector<int64_t>& vector;
212212
int begin;
213213
int end;
214214

215-
SliceVectorizeVisitor(std::vector<int>& v, int b, int e)
215+
SliceVectorizeVisitor(std::vector<int64_t>& v, int b, int e)
216216
: vector(v), begin(b), end(e) {
217217
PADDLE_ENFORCE(begin < end,
218218
"Begin index must be less than end index in ddim slice.");
@@ -240,7 +240,7 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
240240
};
241241

242242
DDim slice_ddim(const DDim& dim, int begin, int end) {
243-
std::vector<int> vec;
243+
std::vector<int64_t> vec;
244244
vec.reserve(end - begin);
245245
SliceVectorizeVisitor visitor(vec, begin, end);
246246
boost::apply_visitor(visitor, dim);
@@ -280,7 +280,7 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
280280
return os;
281281
}
282282

283-
DDim::DDim(std::initializer_list<int> init_list) {
283+
DDim::DDim(std::initializer_list<int64_t> init_list) {
284284
*this = make_ddim(init_list);
285285
}
286286
} // namespace framework

paddle/framework/ddim.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@ struct DDim {
4040
template <int D>
4141
explicit DDim(const Dim<D>& in) : var(in) {}
4242

43-
/*implicit*/ DDim(std::initializer_list<int> init_list);
43+
/*implicit*/ DDim(std::initializer_list<int64_t> init_list);
4444

4545
template <int D>
4646
DDim& operator=(const Dim<D>& in) {
4747
var = in;
4848
return *this;
4949
}
5050

51-
int& operator[](int idx);
52-
int operator[](int idx) const;
51+
int64_t& operator[](int idx);
52+
int64_t operator[](int idx) const;
5353

5454
template <typename Visitor>
5555
typename Visitor::result_type apply_visitor(Visitor& visitor) {
@@ -71,30 +71,30 @@ struct DDim {
7171

7272
DDim operator*(DDim d) const;
7373

74-
ssize_t size() const;
74+
int64_t size() const;
7575
};
7676

7777
/**
78-
* \brief Make a DDim from std::vector<int>
78+
* \brief Make a DDim from std::vector<int64_t>
7979
*
8080
* \param dims An vector of ints. Must be sized between [1, 9]
8181
*/
82-
DDim make_ddim(const std::vector<int>& dims);
82+
DDim make_ddim(const std::vector<int64_t>& dims);
8383

8484
/**
8585
* \brief Make a DDim from an initializer list
8686
*
8787
* \param dims An initializer list of ints. Must be sized between [1, 9]
8888
*
8989
*/
90-
DDim make_ddim(std::initializer_list<int> dims);
90+
DDim make_ddim(std::initializer_list<int64_t> dims);
9191

92-
int get(const DDim& dim, int idx);
92+
int64_t get(const DDim& dim, int idx);
9393
void set(DDim& dim, int idx, int val);
9494

95-
std::vector<int> vectorize(const DDim& ddim);
95+
std::vector<int64_t> vectorize(const DDim& ddim);
9696

97-
ssize_t product(const DDim& ddim);
97+
int64_t product(const DDim& ddim);
9898

9999
/**
100100
* \brief Slice a ddim

paddle/framework/ddim_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ TEST(DDim, Equality) {
1212
EXPECT_EQ(ddim[2], 5);
1313

1414
// construct a DDim from a vector
15-
std::vector<int> vec({9, 1, 5});
15+
std::vector<int64_t> vec({9, 1, 5});
1616
paddle::framework::DDim vddim = paddle::framework::make_ddim(vec);
1717
EXPECT_EQ(ddim[0], 9);
1818
EXPECT_EQ(ddim[1], 1);
@@ -25,7 +25,7 @@ TEST(DDim, Equality) {
2525
EXPECT_EQ(paddle::framework::get(ddim, 0), 6);
2626

2727
// vectorize a DDim
28-
std::vector<int> res_vec = paddle::framework::vectorize(vddim);
28+
std::vector<int64_t> res_vec = paddle::framework::vectorize(vddim);
2929
EXPECT_EQ(res_vec[0], 9);
3030
EXPECT_EQ(res_vec[1], 1);
3131
EXPECT_EQ(res_vec[2], 5);

0 commit comments

Comments
 (0)