@@ -270,7 +270,23 @@ inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
270270 }
271271}
272272
273- inline void DeserializeFromStream (std::istream& is, Tensor* tensor) {
273+ struct DeserializedDataFunctor {
274+ DeserializedDataFunctor (void ** buf, Tensor* tensor,
275+ const platform::Place& place)
276+ : buf_(buf), tensor_(tensor), place_(place) {}
277+
278+ template <typename T>
279+ void operator ()() {
280+ *buf_ = tensor_->mutable_data <T>(place_);
281+ }
282+
283+ void ** buf_;
284+ Tensor* tensor_;
285+ platform::Place place_;
286+ };
287+
288+ inline void DeserializeFromStream (std::istream& is, Tensor* tensor,
289+ const platform::DeviceContext& dev_ctx) {
274290 uint32_t version;
275291 is.read (reinterpret_cast <char *>(&version), sizeof (version));
276292 PADDLE_ENFORCE_EQ (version, 0U , " Only version 0 is supported" );
@@ -289,27 +305,28 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor) {
289305 dims.reserve (static_cast <size_t >(desc.dims ().size ()));
290306 std::copy (desc.dims ().begin (), desc.dims ().end (), std::back_inserter (dims));
291307 tensor->Resize (framework::make_ddim (dims));
292-
293308 void * buf;
294- platform::Place cpu = platform::CPUPlace ();
295- // TODO(Yancey1989): use VisiterDataType instead of DataType switch
296- switch (desc.data_type ()) {
297- case proto::FP32:
298- buf = tensor->mutable_data <float >(cpu);
299- break ;
300- case proto::FP64:
301- buf = tensor->mutable_data <double >(cpu);
302- break ;
303- case proto::INT32:
304- buf = tensor->mutable_data <int >(cpu);
305- break ;
306- case proto::INT64:
307- buf = tensor->mutable_data <int64_t >(cpu);
308- break ;
309- default :
310- PADDLE_THROW (" DataType %d not supported" , desc.data_type ());
309+ auto ctx = platform::CPUDeviceContext ();
310+ if (platform::is_gpu_place (dev_ctx.GetPlace ())) {
311+ #ifdef PADDLE_WITH_CUDA
312+ Tensor cpu_tensor;
313+ cpu_tensor.Resize (framework::make_ddim (dims));
314+ framework::VisitDataType (
315+ desc.data_type (),
316+ DeserializedDataFunctor (&buf, &cpu_tensor, ctx.GetPlace ()));
317+ is.read (static_cast <char *>(buf), cpu_tensor.memory_size ());
318+ auto cpu_place = new platform::CPUPlace ();
319+ framework::CopyFrom (cpu_tensor, *cpu_place, dev_ctx, tensor);
320+ delete cpu_place;
321+ #else
322+ PADDLE_THROW (" Unexpected branch" );
323+ #endif
324+ } else {
325+ framework::VisitDataType (
326+ desc.data_type (),
327+ DeserializedDataFunctor (&buf, tensor, ctx.GetPlace ()));
328+ is.read (static_cast <char *>(buf), tensor->memory_size ());
311329 }
312- is.read (static_cast <char *>(buf), tensor->memory_size ());
313330 }
314331}
315332
0 commit comments