@@ -21,16 +21,16 @@ namespace framework {
2121// / @cond HIDDEN
2222
2323template <int i>
24- Dim<i> make_dim (const int * d) {
24+ Dim<i> make_dim (const int64_t * d) {
2525 return Dim<i>(*d, make_dim<i - 1 >(d + 1 ));
2626}
2727
2828template <>
29- Dim<1 > make_dim<1 >(const int * d) {
29+ Dim<1 > make_dim<1 >(const int64_t * d) {
3030 return Dim<1 >(*d);
3131}
3232
33- void make_ddim (DDim& ddim, const int * dims, int n) {
33+ void make_ddim (DDim& ddim, const int64_t * dims, int n) {
3434 switch (n) {
3535 case 1 :
3636 ddim = make_dim<1 >(dims);
@@ -67,39 +67,39 @@ void make_ddim(DDim& ddim, const int* dims, int n) {
6767
6868// / @endcond
6969
70- DDim make_ddim (std::initializer_list<int > dims) {
70+ DDim make_ddim (std::initializer_list<int64_t > dims) {
7171 DDim result (make_dim (0 ));
7272 make_ddim (result, dims.begin (), dims.size ());
7373 return result;
7474}
7575
76- DDim make_ddim (const std::vector<int >& dims) {
76+ DDim make_ddim (const std::vector<int64_t >& dims) {
7777 DDim result (make_dim (0 ));
7878 make_ddim (result, &dims[0 ], dims.size ());
7979 return result;
8080}
8181
8282// / @cond HIDDEN
8383// XXX For some reason, putting this in an anonymous namespace causes errors
84- class DynamicMutableIndexer : public boost ::static_visitor<int &> {
84+ class DynamicMutableIndexer : public boost ::static_visitor<int64_t &> {
8585 public:
8686 explicit DynamicMutableIndexer (int idx) : idx_(idx) {}
8787
8888 template <int D>
89- int & operator ()(Dim<D>& dim) const {
89+ int64_t & operator ()(Dim<D>& dim) const {
9090 return dim[idx_];
9191 }
9292
9393 private:
9494 int idx_;
9595};
9696
97- class DynamicConstIndexer : public boost ::static_visitor<int > {
97+ class DynamicConstIndexer : public boost ::static_visitor<int64_t > {
9898 public:
9999 explicit DynamicConstIndexer (int idx) : idx_(idx) {}
100100
101101 template <int D>
102- int operator ()(const Dim<D>& dim) const {
102+ int64_t operator ()(const Dim<D>& dim) const {
103103 return dim[idx_];
104104 }
105105
@@ -109,22 +109,22 @@ class DynamicConstIndexer : public boost::static_visitor<int> {
109109
110110// / @endcond
111111
112- int & DDim::operator [](int idx) {
112+ int64_t & DDim::operator [](int idx) {
113113 return boost::apply_visitor (DynamicMutableIndexer (idx), var);
114114}
115115
116- int DDim::operator [](int idx) const {
116+ int64_t DDim::operator [](int idx) const {
117117 return boost::apply_visitor (DynamicConstIndexer (idx), var);
118118}
119119
120- ssize_t DDim::size () const { return arity (*this ); }
120+ int64_t DDim::size () const { return arity (*this ); }
121121
122122bool DDim::operator ==(DDim d) const {
123123 if (var.which () != d.getVar ().which ()) {
124124 return false ;
125125 } else {
126- std::vector<int > v1 = vectorize (*this );
127- std::vector<int > v2 = vectorize (d);
126+ std::vector<int64_t > v1 = vectorize (*this );
127+ std::vector<int64_t > v2 = vectorize (d);
128128
129129 for (unsigned int i = 0 ; i < v1.size (); i++) {
130130 if (v1[i] != v2[i]) {
@@ -139,10 +139,10 @@ bool DDim::operator==(DDim d) const {
139139bool DDim::operator !=(DDim d) const { return !(*this == d); }
140140
141141DDim DDim::operator +(DDim d) const {
142- std::vector<int > v1 = vectorize (*this );
143- std::vector<int > v2 = vectorize (d);
142+ std::vector<int64_t > v1 = vectorize (*this );
143+ std::vector<int64_t > v2 = vectorize (d);
144144
145- std::vector<int > v3;
145+ std::vector<int64_t > v3;
146146
147147 assert (v1.size () == v2.size ());
148148
@@ -154,10 +154,10 @@ DDim DDim::operator+(DDim d) const {
154154}
155155
156156DDim DDim::operator *(DDim d) const {
157- std::vector<int > v1 = vectorize (*this );
158- std::vector<int > v2 = vectorize (d);
157+ std::vector<int64_t > v1 = vectorize (*this );
158+ std::vector<int64_t > v2 = vectorize (d);
159159
160- std::vector<int > v3;
160+ std::vector<int64_t > v3;
161161
162162 assert (v1.size () == v2.size ());
163163
@@ -168,15 +168,15 @@ DDim DDim::operator*(DDim d) const {
168168 return make_ddim (v3);
169169}
170170
171- int get (const DDim& ddim, int idx) { return ddim[idx]; }
171+ int64_t get (const DDim& ddim, int idx) { return ddim[idx]; }
172172
173173void set (DDim& ddim, int idx, int value) { ddim[idx] = value; }
174174
175175// / @cond HIDDEN
176176struct VectorizeVisitor : public boost ::static_visitor<> {
177- std::vector<int >& vector;
177+ std::vector<int64_t >& vector;
178178
179- explicit VectorizeVisitor (std::vector<int >& v) : vector(v) {}
179+ explicit VectorizeVisitor (std::vector<int64_t >& v) : vector(v) {}
180180
181181 template <typename T>
182182 void operator ()(const T& t) {
@@ -188,31 +188,31 @@ struct VectorizeVisitor : public boost::static_visitor<> {
188188};
189189// / @endcond
190190
191- std::vector<int > vectorize (const DDim& ddim) {
192- std::vector<int > result;
191+ std::vector<int64_t > vectorize (const DDim& ddim) {
192+ std::vector<int64_t > result;
193193 VectorizeVisitor visitor (result);
194194 boost::apply_visitor (visitor, ddim);
195195 return result;
196196}
197197
198- struct ProductVisitor : public boost ::static_visitor<ssize_t > {
198+ struct ProductVisitor : public boost ::static_visitor<int64_t > {
199199 template <int D>
200- ssize_t operator ()(const Dim<D>& dim) {
200+ int64_t operator ()(const Dim<D>& dim) {
201201 return product (dim);
202202 }
203203};
204204
205- ssize_t product (const DDim& ddim) {
205+ int64_t product (const DDim& ddim) {
206206 ProductVisitor visitor;
207207 return boost::apply_visitor (visitor, ddim);
208208}
209209
210210struct SliceVectorizeVisitor : public boost ::static_visitor<> {
211- std::vector<int >& vector;
211+ std::vector<int64_t >& vector;
212212 int begin;
213213 int end;
214214
215- SliceVectorizeVisitor (std::vector<int >& v, int b, int e)
215+ SliceVectorizeVisitor (std::vector<int64_t >& v, int b, int e)
216216 : vector(v), begin(b), end(e) {
217217 PADDLE_ENFORCE (begin < end,
218218 " Begin index must be less than end index in ddim slice." );
@@ -240,7 +240,7 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
240240};
241241
242242DDim slice_ddim (const DDim& dim, int begin, int end) {
243- std::vector<int > vec;
243+ std::vector<int64_t > vec;
244244 vec.reserve (end - begin);
245245 SliceVectorizeVisitor visitor (vec, begin, end);
246246 boost::apply_visitor (visitor, dim);
@@ -280,7 +280,7 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
280280 return os;
281281}
282282
283- DDim::DDim (std::initializer_list<int > init_list) {
283+ DDim::DDim (std::initializer_list<int64_t > init_list) {
284284 *this = make_ddim (init_list);
285285}
286286} // namespace framework
0 commit comments