@@ -110,19 +110,8 @@ std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
110110 } else {
111111 return MakePtenDenseTensor (tensor);
112112 }
113- } else if (variable.IsType <framework::SelectedRows>()) {
114- // TODO(chenweihang): now we don't deal with row and height
115- // by xiaowei's advice
116- const auto & tensor = variable.Get <framework::SelectedRows>();
117- if (!platform::is_same_place (tensor.value ().place (), expected_place)) {
118- framework::Tensor tmp_tensor;
119- TensorCopySync (tensor.value (), expected_place, &tmp_tensor);
120- // TODO(chenweihang): adapt SelectedRows by xiaowei's design
121- return MakePtenDenseTensor (tmp_tensor);
122- } else {
123- return MakePtenDenseTensor (tensor.value ());
124- }
125113 } else {
114+ // TODO(chentianyu03): support SelectedRows later
126115 PADDLE_THROW (platform::errors::Unimplemented (
127116 " Unsupported shared input `%s` type now when call pt kernel." ,
128117 framework::ToTypeName (variable.Type ())));
@@ -137,12 +126,8 @@ std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
137126 if (variable->template IsType <framework::LoDTensor>()) {
138127 auto * tensor = variable->template GetMutable <framework::LoDTensor>();
139128 return MakePtenDenseTensor (*tensor, arg_def);
140- } else if (variable->template IsType <framework::SelectedRows>()) {
141- auto * tensor = variable->template GetMutable <framework::SelectedRows>();
142- // TODO(chenweihang): adapt SelectedRows by xiaowei's design,
143- // here the row and height will lost in output!
144- return MakePtenDenseTensor (tensor->value (), arg_def);
145129 } else {
130+ // TODO(chentianyu03): support SelectedRows later
146131 PADDLE_THROW (platform::errors::Unimplemented (
147132 " Unsupported shared output `%s` type now when call pt kernel." ,
148133 framework::ToTypeName (variable->Type ())));
@@ -220,7 +205,8 @@ void ReMakePtenDenseTensor(const paddle::framework::LoDTensor& src,
220205 shared_storage,
221206 platform::errors::NotFound (
222207 " Target DenseTensor's shared storage is nullptr." ));
223- if (src.IsInitialized ()) {
208+ if (src.IsInitialized () &&
209+ src.place () == pten::TransToFluidPlace (arg_def.backend )) {
224210 shared_storage->ResetAllocation (src.Holder (), src.offset ());
225211 } else {
226212 shared_storage->ResetAllocationPlace (
@@ -242,19 +228,8 @@ void ReMakePtenDenseTensorFromVar(const framework::Variable& variable,
242228 } else {
243229 ReMakePtenDenseTensor (tensor, arg_def, dst);
244230 }
245- } else if (variable.IsType <framework::SelectedRows>()) {
246- // TODO(chenweihang): now we don't deal with row and height
247- // by xiaowei's advice
248- const auto & tensor = variable.Get <framework::SelectedRows>();
249- if (!platform::is_same_place (tensor.value ().place (), expected_place)) {
250- framework::Tensor tmp_tensor;
251- TensorCopySync (tensor.value (), expected_place, &tmp_tensor);
252- // TODO(chenweihang): adapt SelectedRows by xiaowei's design
253- ReMakePtenDenseTensor (tmp_tensor, arg_def, dst);
254- } else {
255- ReMakePtenDenseTensor (tensor.value (), arg_def, dst);
256- }
257231 } else {
232+ // TODO(chentianyu03): support SelectedRows later
258233 PADDLE_THROW (platform::errors::Unimplemented (
259234 " Unsupported shared input `%s` type now when call pt kernel." ,
260235 framework::ToTypeName (variable.Type ())));
@@ -269,12 +244,8 @@ void ReMakePtenDenseTensorFromVar(framework::Variable* variable,
269244 if (variable->template IsType <framework::LoDTensor>()) {
270245 auto * tensor = variable->template GetMutable <framework::LoDTensor>();
271246 ReMakePtenDenseTensor (*tensor, arg_def, dst);
272- } else if (variable->template IsType <framework::SelectedRows>()) {
273- auto * tensor = variable->template GetMutable <framework::SelectedRows>();
274- // TODO(chenweihang): adapt SelectedRows by xiaowei's design,
275- // here the row and height will lost in output!
276- ReMakePtenDenseTensor (tensor->value (), arg_def, dst);
277247 } else {
248+ // TODO(chentianyu03): support SelectedRows later
278249 PADDLE_THROW (platform::errors::Unimplemented (
279250 " Unsupported shared output `%s` type now when call pt kernel." ,
280251 framework::ToTypeName (variable->Type ())));
@@ -311,19 +282,8 @@ void MakeVariableFromPtenTensor(pten::DenseTensor* src,
311282 // so, here we set the variable's type with the pten tensor dtype.
312283 tensor->setType (dtype);
313284 }
314-
315- } else if (variable->IsType <framework::SelectedRows>()) {
316- auto * tensor = variable->GetMutable <framework::SelectedRows>();
317- auto dtype = pten::TransToProtoVarType (src->dtype ());
318-
319- if (tensor->value ().IsInitialized ()) {
320- } else {
321- auto storage = dynamic_cast <SharedStorage*>(
322- pten::CompatibleDenseTensorUtils::UnsafeGetMutableStorage (src));
323- tensor->mutable_value ()->ResetHolderWithType (
324- std::move (storage->GetAllocation ()), dtype);
325- }
326285 } else {
286+ // TODO(chentianyu03): support SelectedRows later
327287 PADDLE_THROW (platform::errors::Unimplemented (
328288 " Unsupported shared input `%s` type now when call pt kernel." ,
329289 framework::ToTypeName (variable->Type ())));
0 commit comments