Skip to content

Commit b7a54fc

Browse files
authored
support convert core.Tensor to paddle.Tensor (#33430)
1 parent e47c3f0 commit b7a54fc

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

paddle/fluid/pybind/imperative.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
245245
}
246246

247247
static void InitVarBaseFromTensorWithArgDefault(
248-
imperative::VarBase *self, const framework::LoDTensor &tensor) {
248+
imperative::VarBase *self, const framework::Tensor &tensor) {
249249
VLOG(4) << "Init VarBase";
250250
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
251251
new (self) imperative::VarBase(

python/paddle/fluid/tests/unittests/test_var_base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def _test_place(place):
176176

177177
x = paddle.to_tensor(1, dtype='uint8')
178178
self.assertEqual(x.item(), 1)
179-
print(type(x.item()))
180179
self.assertTrue(isinstance(x.item(), int))
181180

182181
x = paddle.to_tensor(1, dtype='int8')
@@ -203,6 +202,24 @@ def _test_place(place):
203202
self.assertEqual(x.item(), 1 + 1j)
204203
self.assertTrue(isinstance(x.item(), complex))
205204

205+
numpy_array = np.random.randn(3, 4)
206+
# covert core.LoDTensor to paddle.Tensor
207+
lod_tensor = paddle.fluid.core.LoDTensor()
208+
place = paddle.fluid.framework._current_expected_place()
209+
lod_tensor.set(numpy_array, place)
210+
x = paddle.to_tensor(lod_tensor)
211+
self.assertTrue(np.array_equal(x.numpy(), numpy_array))
212+
self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR)
213+
self.assertEqual(str(x.place), str(place))
214+
215+
# covert core.Tensor to paddle.Tensor
216+
x = paddle.to_tensor(numpy_array)
217+
dlpack = x.value().get_tensor()._to_dlpack()
218+
tensor_from_dlpack = paddle.fluid.core.from_dlpack(dlpack)
219+
x = paddle.to_tensor(tensor_from_dlpack)
220+
self.assertTrue(np.array_equal(x.numpy(), numpy_array))
221+
self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR)
222+
206223
with self.assertRaises(ValueError):
207224
paddle.randn([3, 2, 2]).item()
208225
with self.assertRaises(ValueError):

python/paddle/tensor/creation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@ def _handle_dtype(data, dtype):
136136
data = data._copy_to(place, False)
137137
ata = _handle_dtype(data, dtype)
138138
data.stop_gradient = stop_gradient
139-
elif isinstance(data, core.LoDTensor):
140-
# convert LoDTensor to VarBase first
141-
# Currenly, LoDTensor does no copy when places are same
139+
elif isinstance(data, (core.LoDTensor, core.Tensor)):
140+
# Note(zhouwei25): should't expose it to users, just for internal use.
141+
# convert core.Tensor/core.LoDTensor to VarBase first
142+
# Currenly, there is no copy when places are same
142143
data = paddle.Tensor(data)
143144
if not data.place._equals(place):
144145
data = data._copy_to(place, False)

0 commit comments

Comments
 (0)