@@ -72,6 +72,19 @@ inline std::vector<int64_t> PD_TensorGetDims(PD_Tensor* tensor,
7272 return std::vector<int64_t >();
7373}
7474
75+ inline std::vector<int64_t > PD_TensorGetStrides (PD_Tensor* tensor,
76+ PD_Status* status) {
77+ int64_t nstrides = PD_TensorGetNumStrides (tensor, status);
78+ if (nstrides > 0 ) {
79+ std::vector<int64_t > shape (nstrides);
80+ for (int64_t i = 0 ; i < nstrides; ++i) {
81+ shape[i] = PD_TensorGetStride (tensor, i, status);
82+ }
83+ return shape;
84+ }
85+ return std::vector<int64_t >();
86+ }
87+
7588inline std::vector<int64_t > PD_MetaTensorGetDims (PD_MetaTensor* tensor,
7689 PD_Status* status) {
7790 int64_t ndims = PD_MetaTensorGetNumDims (tensor, status);
@@ -85,6 +98,19 @@ inline std::vector<int64_t> PD_MetaTensorGetDims(PD_MetaTensor* tensor,
8598 return std::vector<int64_t >();
8699}
87100
101+ inline std::vector<int64_t > PD_MetaTensorGetStrides (PD_MetaTensor* tensor,
102+ PD_Status* status) {
103+ int64_t nstrides = PD_MetaTensorGetNumStrides (tensor, status);
104+ if (nstrides > 0 ) {
105+ std::vector<int64_t > shape (nstrides);
106+ for (int64_t i = 0 ; i < nstrides; ++i) {
107+ shape[i] = PD_MetaTensorGetStride (tensor, i, status);
108+ }
109+ return shape;
110+ }
111+ return std::vector<int64_t >();
112+ }
113+
88114template <typename T>
89115class WrapperBase {
90116 public:
@@ -134,13 +160,27 @@ class DenseTensor : public WrapperBase<PD_Tensor> {
134160 return holder;
135161 }
136162
163+ size_t offset () const {
164+ C_Status status;
165+ auto offset = PD_TensorGetOffset (raw_data (), &status);
166+ PD_CHECK_STATUS (status);
167+ return offset;
168+ }
169+
137170 std::vector<int64_t > dims () const {
138171 C_Status status;
139172 auto dimension = PD_TensorGetDims (raw_data (), &status);
140173 PD_CHECK_STATUS (status);
141174 return dimension;
142175 }
143176
177+ std::vector<int64_t > strides () const {
178+ C_Status status;
179+ auto strides = PD_TensorGetStrides (raw_data (), &status);
180+ PD_CHECK_STATUS (status);
181+ return strides;
182+ }
183+
144184 PD_DataType dtype () const {
145185 C_Status status;
146186 auto data_type = PD_TensorGetPDDataType (raw_data (), &status);
@@ -207,6 +247,18 @@ class DenseTensor : public WrapperBase<PD_Tensor> {
207247 PD_CHECK_STATUS (status);
208248 }
209249
250+ void set_offset (const int64_t & offset) {
251+ C_Status status;
252+ PD_TensorSetOffset (raw_data (), offset, &status);
253+ PD_CHECK_STATUS (status);
254+ }
255+
256+ void set_strides (const std::vector<int64_t >& strides) {
257+ C_Status status;
258+ PD_TensorSetStrides (raw_data (), strides.size (), strides.data (), &status);
259+ PD_CHECK_STATUS (status);
260+ }
261+
210262 void set_dtype (PD_DataType data_type) {
211263 C_Status status;
212264 PD_TensorSetDataType (raw_data (), data_type, &status);
@@ -513,6 +565,13 @@ class MetaTensor : WrapperBase<PD_MetaTensor> {
513565 return dimension;
514566 }
515567
568+ std::vector<int64_t > strides () const {
569+ C_Status status;
570+ auto strides = PD_MetaTensorGetStrides (raw_data (), &status);
571+ PD_CHECK_STATUS (status);
572+ return strides;
573+ }
574+
516575 PD_DataType dtype () const {
517576 C_Status status;
518577 auto data_type = PD_MetaTensorGetPDDataType (raw_data (), &status);
@@ -540,6 +599,13 @@ class MetaTensor : WrapperBase<PD_MetaTensor> {
540599 PD_CHECK_STATUS (status);
541600 }
542601
602+ void set_strides (const std::vector<int64_t >& strides) {
603+ C_Status status;
604+ PD_MetaTensorSetStrides (
605+ raw_data (), strides.size (), strides.data (), &status);
606+ PD_CHECK_STATUS (status);
607+ }
608+
543609 void set_dtype (PD_DataType data_type) {
544610 C_Status status;
545611 PD_MetaTensorSetDataType (raw_data (), data_type, &status);
0 commit comments