Skip to content

Commit 8633603

Browse files
fix ReMakePtenDenseTensor place mismatch bug
1 parent 90b05d9 commit 8633603

File tree

1 file changed

+7
-47
lines changed

1 file changed

+7
-47
lines changed

paddle/pten/api/lib/utils/tensor_utils.cc

Lines changed: 7 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -110,19 +110,8 @@ std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
110110
} else {
111111
return MakePtenDenseTensor(tensor);
112112
}
113-
} else if (variable.IsType<framework::SelectedRows>()) {
114-
// TODO(chenweihang): now we don't deal with row and height
115-
// by xiaowei's advice
116-
const auto& tensor = variable.Get<framework::SelectedRows>();
117-
if (!platform::is_same_place(tensor.value().place(), expected_place)) {
118-
framework::Tensor tmp_tensor;
119-
TensorCopySync(tensor.value(), expected_place, &tmp_tensor);
120-
// TODO(chenweihang): adapt SelectedRows by xiaowei's design
121-
return MakePtenDenseTensor(tmp_tensor);
122-
} else {
123-
return MakePtenDenseTensor(tensor.value());
124-
}
125113
} else {
114+
// TODO(chentianyu03): support SelectedRows later
126115
PADDLE_THROW(platform::errors::Unimplemented(
127116
"Unsupported shared input `%s` type now when call pt kernel.",
128117
framework::ToTypeName(variable.Type())));
@@ -137,12 +126,8 @@ std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
137126
if (variable->template IsType<framework::LoDTensor>()) {
138127
auto* tensor = variable->template GetMutable<framework::LoDTensor>();
139128
return MakePtenDenseTensor(*tensor, arg_def);
140-
} else if (variable->template IsType<framework::SelectedRows>()) {
141-
auto* tensor = variable->template GetMutable<framework::SelectedRows>();
142-
// TODO(chenweihang): adapt SelectedRows by xiaowei's design,
143-
// here the row and height will lost in output!
144-
return MakePtenDenseTensor(tensor->value(), arg_def);
145129
} else {
130+
// TODO(chentianyu03): support SelectedRows later
146131
PADDLE_THROW(platform::errors::Unimplemented(
147132
"Unsupported shared output `%s` type now when call pt kernel.",
148133
framework::ToTypeName(variable->Type())));
@@ -220,7 +205,8 @@ void ReMakePtenDenseTensor(const paddle::framework::LoDTensor& src,
220205
shared_storage,
221206
platform::errors::NotFound(
222207
"Target DenseTensor's shared storage is nullptr."));
223-
if (src.IsInitialized()) {
208+
if (src.IsInitialized() &&
209+
src.place() == pten::TransToFluidPlace(arg_def.backend)) {
224210
shared_storage->ResetAllocation(src.Holder(), src.offset());
225211
} else {
226212
shared_storage->ResetAllocationPlace(
@@ -242,19 +228,8 @@ void ReMakePtenDenseTensorFromVar(const framework::Variable& variable,
242228
} else {
243229
ReMakePtenDenseTensor(tensor, arg_def, dst);
244230
}
245-
} else if (variable.IsType<framework::SelectedRows>()) {
246-
// TODO(chenweihang): now we don't deal with row and height
247-
// by xiaowei's advice
248-
const auto& tensor = variable.Get<framework::SelectedRows>();
249-
if (!platform::is_same_place(tensor.value().place(), expected_place)) {
250-
framework::Tensor tmp_tensor;
251-
TensorCopySync(tensor.value(), expected_place, &tmp_tensor);
252-
// TODO(chenweihang): adapt SelectedRows by xiaowei's design
253-
ReMakePtenDenseTensor(tmp_tensor, arg_def, dst);
254-
} else {
255-
ReMakePtenDenseTensor(tensor.value(), arg_def, dst);
256-
}
257231
} else {
232+
// TODO(chentianyu03): support SelectedRows later
258233
PADDLE_THROW(platform::errors::Unimplemented(
259234
"Unsupported shared input `%s` type now when call pt kernel.",
260235
framework::ToTypeName(variable.Type())));
@@ -269,12 +244,8 @@ void ReMakePtenDenseTensorFromVar(framework::Variable* variable,
269244
if (variable->template IsType<framework::LoDTensor>()) {
270245
auto* tensor = variable->template GetMutable<framework::LoDTensor>();
271246
ReMakePtenDenseTensor(*tensor, arg_def, dst);
272-
} else if (variable->template IsType<framework::SelectedRows>()) {
273-
auto* tensor = variable->template GetMutable<framework::SelectedRows>();
274-
// TODO(chenweihang): adapt SelectedRows by xiaowei's design,
275-
// here the row and height will lost in output!
276-
ReMakePtenDenseTensor(tensor->value(), arg_def, dst);
277247
} else {
248+
// TODO(chentianyu03): support SelectedRows later
278249
PADDLE_THROW(platform::errors::Unimplemented(
279250
"Unsupported shared output `%s` type now when call pt kernel.",
280251
framework::ToTypeName(variable->Type())));
@@ -311,19 +282,8 @@ void MakeVariableFromPtenTensor(pten::DenseTensor* src,
311282
// so, here we set the variable's type with the pten tensor dtype.
312283
tensor->setType(dtype);
313284
}
314-
315-
} else if (variable->IsType<framework::SelectedRows>()) {
316-
auto* tensor = variable->GetMutable<framework::SelectedRows>();
317-
auto dtype = pten::TransToProtoVarType(src->dtype());
318-
319-
if (tensor->value().IsInitialized()) {
320-
} else {
321-
auto storage = dynamic_cast<SharedStorage*>(
322-
pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage(src));
323-
tensor->mutable_value()->ResetHolderWithType(
324-
std::move(storage->GetAllocation()), dtype);
325-
}
326285
} else {
286+
// TODO(chentianyu03): support SelectedRows later
327287
PADDLE_THROW(platform::errors::Unimplemented(
328288
"Unsupported shared input `%s` type now when call pt kernel.",
329289
framework::ToTypeName(variable->Type())));

0 commit comments

Comments
 (0)