Skip to content

Commit 08a384f

Browse files
authored
[Other]Fix the fd tensor copy assignment (#506)
Fix the fd tensor copy assignment
1 parent 6408af2 commit 08a384f

2 files changed

Lines changed: 100 additions & 11 deletions

File tree

fastdeploy/core/fd_tensor.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ void FDTensor::Squeeze(int64_t axis) {
8989
size_t ndim = shape.size();
9090
FDASSERT(axis >= 0 && axis < ndim,
9191
"The allowed 'axis' must be in range of (0, %lu)!", ndim);
92-
FDASSERT(shape[axis]==1,
93-
"The No.%ld dimension of shape should be 1, but it is %ld!", (long)axis, (long)shape[axis]);
92+
FDASSERT(shape[axis] == 1,
93+
"The No.%ld dimension of shape should be 1, but it is %ld!",
94+
(long)axis, (long)shape[axis]);
9495
shape.erase(shape.begin() + axis);
9596
}
9697

@@ -220,9 +221,9 @@ bool FDTensor::ReallocFn(size_t nbytes) {
220221
return buffer_ != nullptr;
221222
#else
222223
FDASSERT(false,
223-
"The FastDeploy FDTensor allocator didn't compile under "
224-
"-DWITH_GPU=ON,"
225-
"so this is an unexpected problem happend.");
224+
"The FastDeploy FDTensor allocator didn't compile under "
225+
"-DWITH_GPU=ON,"
226+
"so this is an unexpected problem happend.");
226227
#endif
227228
}
228229
buffer_ = realloc(buffer_, nbytes);
@@ -316,16 +317,15 @@ FDTensor& FDTensor::operator=(const FDTensor& other) {
316317
if (other.buffer_ == nullptr) {
317318
FreeFn();
318319
buffer_ = nullptr;
320+
shape = other.shape;
321+
name = other.name;
322+
dtype = other.dtype;
323+
device = other.device;
319324
} else {
320-
Resize(other.shape);
325+
Resize(other.shape, other.dtype, other.name, other.device);
321326
size_t nbytes = Nbytes();
322327
CopyBuffer(buffer_, other.buffer_, nbytes);
323328
}
324-
325-
shape = other.shape;
326-
name = other.name;
327-
dtype = other.dtype;
328-
device = other.device;
329329
external_data_ptr = other.external_data_ptr;
330330
}
331331
return *this;

tests/core/test_fd_tensor.cc

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <array>
16+
#include <cstring>
17+
#include <vector>
18+
#include "fastdeploy/core/fd_tensor.h"
19+
#include "gtest/gtest.h"
20+
#include "gtest_utils.h"
21+
22+
namespace fastdeploy {
23+
24+
TEST(fastdeploy, fd_tensor_constructor) {
25+
CheckShape check_shape;
26+
CheckData check_data;
27+
28+
FDTensor tensor1;
29+
check_shape(tensor1.shape, {0});
30+
ASSERT_EQ(tensor1.name, "");
31+
ASSERT_EQ(tensor1.dtype, FDDataType::INT8);
32+
ASSERT_EQ(tensor1.device, Device::CPU);
33+
34+
std::vector<int> inputs = {2, 4, 3, 7, 1, 5};
35+
tensor1.SetExternalData({2, 3}, FDDataType::INT32, inputs.data());
36+
ASSERT_EQ(tensor1.dtype, FDDataType::INT32);
37+
38+
FDTensor tensor2(tensor1);
39+
check_shape(tensor1.shape, {2, 3});
40+
ASSERT_EQ(tensor2.name, "");
41+
ASSERT_EQ(tensor2.dtype, FDDataType::INT32);
42+
ASSERT_EQ(tensor2.device, Device::CPU);
43+
44+
FDTensor tensor3;
45+
tensor3.Resize({2, 3}, FDDataType::INT32, "tensor3");
46+
check_shape(tensor3.shape, {2, 3});
47+
ASSERT_EQ(tensor3.Nbytes(), 24);
48+
49+
// Copy constructor
50+
FDTensor tensor4(tensor3);
51+
check_shape(tensor4.shape, {2, 3});
52+
ASSERT_EQ(tensor3.Nbytes(), tensor4.Nbytes());
53+
check_data(reinterpret_cast<int*>(tensor3.Data()),
54+
reinterpret_cast<int*>(tensor4.Data()), tensor4.Numel());
55+
56+
// Move constructor
57+
ASSERT_NE(tensor1.external_data_ptr, nullptr);
58+
FDTensor tensor5(std::move(tensor1));
59+
ASSERT_EQ(tensor1.external_data_ptr, nullptr);
60+
ASSERT_EQ(tensor5.external_data_ptr, inputs.data());
61+
check_shape(tensor5.shape, {2, 3});
62+
}
63+
64+
TEST(fastdeploy, fd_tensor_assignment) {
65+
CheckShape check_shape;
66+
CheckData check_data;
67+
68+
FDTensor tensor1("T1");
69+
std::vector<int> inputs = {2, 4, 3, 7, 1, 5};
70+
tensor1.SetExternalData({2, 3}, FDDataType::INT32, inputs.data());
71+
72+
FDTensor tensor2;
73+
tensor2 = tensor1;
74+
ASSERT_EQ(tensor2.name, "T1");
75+
ASSERT_EQ(tensor2.dtype, FDDataType::INT32);
76+
ASSERT_EQ(tensor2.device, Device::CPU);
77+
ASSERT_EQ(tensor2.Data(), inputs.data());
78+
check_shape(tensor2.shape, {2, 3});
79+
80+
FDTensor tensor3;
81+
tensor3 = std::move(tensor1);
82+
ASSERT_EQ(tensor3.name, "T1");
83+
ASSERT_EQ(tensor3.dtype, FDDataType::INT32);
84+
ASSERT_EQ(tensor3.device, Device::CPU);
85+
ASSERT_EQ(tensor3.Data(), inputs.data());
86+
ASSERT_EQ(tensor1.Data(), nullptr);
87+
}
88+
89+
} // namespace fastdeploy

0 commit comments

Comments
 (0)