@@ -277,32 +277,73 @@ std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
277277}
278278
279279void VarBase::CopyFrom (const VarBase& src, const bool blocking) {
280- if (SharedVar ()->IsEmpty ()) {
281- VLOG (3 ) << " deep copy Variable from " << src.Name () << " to " << Name ();
282- SetPersistable (src.Persistable ());
280+ if (src.SharedVar ()->IsEmpty ()) {
281+ return ;
282+ }
283+
284+ VLOG (3 ) << " Deep copy Tensor from " << src.Name () << " to " << Name ();
285+ if (Var ().IsInitialized ()) {
286+ PADDLE_ENFORCE_EQ (DataType (), src.DataType (),
287+ platform::errors::PreconditionNotMet (
288+ " Tensor %s has different data type with Tensor %s, "
289+ " Tensor Copy cannot be performed!" ,
290+ Name (), src.Name ()));
291+ PADDLE_ENFORCE_EQ (Type (), src.Type (),
292+ platform::errors::PreconditionNotMet (
293+ " Tensor %s has different type with Tensor %s, Tensor "
294+ " Copy cannot be performed!" ,
295+ Name (), src.Name ()));
296+ } else {
283297 SetDataType (src.DataType ());
284298 SetType (src.Type ());
285- SetOverridedStopGradient (src.OverridedStopGradient ());
286- if (!src.SharedVar ()->IsEmpty ()) {
287- const platform::Place& place = src.Place ();
288- if (src.Var ().IsType <framework::LoDTensor>()) {
289- auto & src_tensor = src.Var ().Get <framework::LoDTensor>();
290- auto * dst_tensor = MutableVar ()->GetMutable <framework::LoDTensor>();
291- dst_tensor->set_lod (src_tensor.lod ());
292- framework::TensorCopy (src_tensor, place, dst_tensor);
293- } else if (src.Var ().IsType <framework::SelectedRows>()) {
294- auto & src_selected_rows = src.Var ().Get <framework::SelectedRows>();
295- auto * dst_selected_rows =
296- MutableVar ()->GetMutable <framework::SelectedRows>();
297- dst_selected_rows->set_height (src_selected_rows.height ());
298- dst_selected_rows->set_rows (src_selected_rows.rows ());
299- framework::TensorCopy (src_selected_rows.value (), place,
300- dst_selected_rows->mutable_value ());
301- }
302- if (blocking) {
303- platform::DeviceContextPool::Instance ().Get (place)->Wait ();
304- }
299+ SetPersistable (src.Persistable ());
300+ InnerSetOverridedStopGradient (src.OverridedStopGradient ());
301+ }
302+
303+ platform::Place place = src.Place ();
304+ if (src.Var ().IsType <framework::LoDTensor>()) {
305+ auto & src_tensor = src.Var ().Get <framework::LoDTensor>();
306+ auto * dst_tensor = MutableVar ()->GetMutable <framework::LoDTensor>();
307+ if (dst_tensor && dst_tensor->IsInitialized ()) {
308+ PADDLE_ENFORCE_EQ (dst_tensor->dims (), src_tensor.dims (),
309+ platform::errors::PreconditionNotMet (
310+ " Tensor %s has different dims with Tensor %s, "
311+ " Tensor Copy cannot be performed!" ,
312+ Name (), src.Name ()));
313+ PADDLE_ENFORCE_EQ (dst_tensor->lod (), src_tensor.lod (),
314+ platform::errors::PreconditionNotMet (
315+ " Tensor %s has different dims with Tensor %s, "
316+ " Tensor Copy cannot be performed!" ,
317+ Name (), src.Name ()));
318+ place = Place ();
319+ } else {
320+ dst_tensor->set_lod (src_tensor.lod ());
321+ dst_tensor->Resize (src_tensor.dims ());
322+ }
323+ framework::TensorCopy (src_tensor, place, dst_tensor);
324+ } else if (src.Var ().IsType <framework::SelectedRows>()) {
325+ auto & src_selected_rows = src.Var ().Get <framework::SelectedRows>();
326+ auto * dst_selected_rows =
327+ MutableVar ()->GetMutable <framework::SelectedRows>();
328+ dst_selected_rows->set_height (src_selected_rows.height ());
329+ dst_selected_rows->set_rows (src_selected_rows.rows ());
330+
331+ auto & src_tensor = src_selected_rows.value ();
332+ auto * dst_tensor = dst_selected_rows->mutable_value ();
333+ if (dst_tensor && dst_tensor->IsInitialized ()) {
334+ PADDLE_ENFORCE_EQ (dst_tensor->dims (), src_tensor.dims (),
335+ platform::errors::PreconditionNotMet (
336+ " Tensor %s has different dims with Tensor %s, "
337+ " Tensor Copy cannot be performed!" ,
338+ Name (), src.Name ()));
339+ place = Place ();
340+ } else {
341+ dst_tensor->Resize (src_tensor.dims ());
305342 }
343+ framework::TensorCopy (src_tensor, place, dst_tensor);
344+ }
345+ if (blocking) {
346+ platform::DeviceContextPool::Instance ().Get (place)->Wait ();
306347 }
307348}
308349
0 commit comments