Skip to content

Commit c5ccff7

Browse files
authored
supports the slice of upper tensor, test=develop (#37215)
1 parent f49c2c2 commit c5ccff7

File tree

4 files changed

+35
-22
lines changed

4 files changed

+35
-22
lines changed

paddle/pten/core/compat_utils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,39 @@ class CompatibleDenseTensorUtils {
4545
static_cast<paddle::experimental::SharedStorage*>(tensor->storage_.get())
4646
->Reset();
4747
}
48+
49+
static DenseTensor Slice(DenseTensor* tensor,
50+
int64_t begin_idx,
51+
int64_t end_idx) {
52+
tensor->check_memory_size();
53+
PADDLE_ENFORCE_GE(begin_idx,
54+
0,
55+
paddle::platform::errors::OutOfRange(
56+
"The start row index must be greater than 0."
57+
"But received the start index is d%.",
58+
begin_idx));
59+
PADDLE_ENFORCE_LE(end_idx,
60+
tensor->dims()[0],
61+
paddle::platform::errors::OutOfRange(
62+
"The end row index is out of bound."));
63+
PADDLE_ENFORCE_LT(
64+
begin_idx,
65+
end_idx,
66+
paddle::platform::errors::InvalidArgument(
67+
"The start row index must be less than the end row index."
68+
"But received the start index = %d, the end index = %d.",
69+
begin_idx,
70+
end_idx));
71+
DenseTensor ret =
72+
DenseTensor(copy_intrusive(tensor->storage_), tensor->meta_);
73+
if (tensor->dims()[0] != 1) {
74+
ret.meta_.dims[0] = end_idx - begin_idx;
75+
ret.meta_.offset = tensor->meta_.offset +
76+
begin_idx * (tensor->numel() / tensor->dims()[0]) *
77+
paddle::experimental::SizeOf(tensor->data_type());
78+
}
79+
return ret;
80+
}
4881
};
4982

5083
} // namespace pten

paddle/pten/core/dense_tensor.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,6 @@ class DenseTensor : public TensorBase,
174174
/// \return The const data pointer value of raw type.
175175
const void* data() const;
176176

177-
/// \brief Get the shallow clone of current tensor.
178-
/// \return The shallow clone of current tensor.
179-
DenseTensor shallow_clone() const {
180-
return DenseTensor(copy_intrusive(storage_), meta_);
181-
}
182-
183177
private:
184178
friend class CompatibleDenseTensorUtils;
185179

paddle/pten/core/tensor_meta.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ struct DenseTensorMeta {
5757
const DataType type{DataType::UNDEFINED};
5858
const DataLayout layout{DataLayout::NCHW};
5959
LoD lod;
60+
size_t offset{0};
6061
};
6162

6263
inline DenseTensorMeta::DenseTensorMeta(DataType type, const DDim& dims)
@@ -86,7 +87,7 @@ inline bool operator==(const DenseTensorMeta& lhs, const DenseTensorMeta& rhs) {
8687
bool ret = true;
8788
return ret && (lhs.is_scalar == rhs.is_scalar) && (lhs.dims == rhs.dims) &&
8889
(lhs.type == rhs.type) && (lhs.layout == rhs.layout) &&
89-
(lhs.lod == rhs.lod);
90+
(lhs.lod == rhs.lod) && (lhs.offset == rhs.offset);
9091
}
9192

9293
} // namespace pten

paddle/pten/tests/core/test_dense_tensor.cc

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -125,20 +125,5 @@ TEST(dense_tensor, resize) {
125125
CHECK_EQ(storage->size(), 6u);
126126
}
127127

128-
TEST(dense_tensor, shallow_clone) {
129-
const DDim dims({1, 2});
130-
const DataType dtype{DataType::INT8};
131-
const DataLayout layout{DataLayout::NHWC};
132-
const std::vector<std::vector<size_t>> lod{};
133-
DenseTensorMeta meta(dtype, dims, layout, lod);
134-
135-
auto alloc = std::make_shared<FancyAllocator>();
136-
DenseTensor tensor_0(alloc, meta);
137-
138-
auto tensor_1 = tensor_0.shallow_clone();
139-
CHECK(tensor_0.meta() == tensor_1.meta());
140-
CHECK(tensor_0.release() == tensor_1.release());
141-
}
142-
143128
} // namespace tests
144129
} // namespace pten

0 commit comments

Comments
 (0)