@@ -822,6 +822,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
822822#else
823823 PADDLE_THROW (platform::errors::Unimplemented (
824824 " XPUPlace is not supported when not compiled with XPU" ));
825+ #endif
826+ } else if (platform::is_npu_place (tensor.place ())) {
827+ #ifdef PADDLE_WITH_ASCEND_CL
828+ constexpr size_t kBufSize = 1024 * 1024 * 64 ; // 64MB
829+ std::unique_ptr<char []> buf (new char [kBufSize ]);
830+ auto & npu_dev_ctx =
831+ static_cast <const platform::NPUDeviceContext&>(dev_ctx);
832+ platform::CPUPlace cpu;
833+ uintptr_t data = reinterpret_cast <uintptr_t >(data_ptr);
834+ while (size != 0 ) {
835+ size_t size_to_write = std::min (kBufSize , static_cast <size_t >(size));
836+ memory::Copy (cpu, buf.get (),
837+ BOOST_GET_CONST (platform::NPUPlace, tensor.place ()),
838+ reinterpret_cast <const void *>(data), size_to_write,
839+ npu_dev_ctx.stream ());
840+ npu_dev_ctx.Wait ();
841+ os.write (buf.get (), size_to_write);
842+ data += size_to_write;
843+ size -= size_to_write;
844+ }
845+ #else
846+ PADDLE_THROW (platform::errors::Unimplemented (
847+ " NPUPlace is not supported when not compiled with NPU" ));
825848#endif
826849 } else {
827850 os.write (static_cast <const char *>(data_ptr),
@@ -877,8 +900,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
877900 auto ctx = platform::CPUDeviceContext ();
878901 size_t size = tensor->numel () * framework::SizeOfType (desc.data_type ());
879902 if (platform::is_gpu_place (dev_ctx.GetPlace ()) ||
880- platform::is_xpu_place (dev_ctx.GetPlace ())) {
881- #if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU
903+ platform::is_xpu_place (dev_ctx.GetPlace ()) ||
904+ platform::is_npu_place (dev_ctx.GetPlace ())) {
905+ #if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU || \
906+ defined PADDLE_WITH_ASCEND_CL
882907 Tensor cpu_tensor;
883908 cpu_tensor.Resize (framework::make_ddim (shape));
884909 framework::VisitDataType (
@@ -887,13 +912,19 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
887912 is.read (static_cast <char *>(buf), size);
888913 auto dst_place = dev_ctx.GetPlace ();
889914 framework::TensorCopy (cpu_tensor, dst_place, dev_ctx, tensor);
915+ if (platform::is_npu_place (dev_ctx.GetPlace ())) {
916+ dev_ctx.Wait ();
917+ }
890918#else
891919 if (platform::is_gpu_place (dev_ctx.GetPlace ())) {
892920 PADDLE_THROW (platform::errors::Unimplemented (
893921 " CUDAPlace is not supported when not compiled with CUDA" ));
894- } else {
922+ } else if ( platform::is_xpu_place (dev_ctx. GetPlace ())) {
895923 PADDLE_THROW (platform::errors::Unimplemented (
896924 " XPUPlace is not supported when not compiled with XPU" ));
925+ } else {
926+ PADDLE_THROW (platform::errors::Unimplemented (
927+ " NPUPlace is not supported when not compiled with NPU" ));
897928 }
898929#endif
899930 } else {
@@ -934,8 +965,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
934965 auto ctx = platform::CPUDeviceContext ();
935966 size_t size = tensor->numel () * framework::SizeOfType (desc.data_type ());
936967 if (platform::is_gpu_place (dev_ctx.GetPlace ()) ||
937- platform::is_xpu_place (dev_ctx.GetPlace ())) {
938- #if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU
968+ platform::is_xpu_place (dev_ctx.GetPlace ()) ||
969+ platform::is_npu_place (dev_ctx.GetPlace ())) {
970+ #if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU || \
971+ defined PADDLE_WITH_ASCEND_CL
939972 Tensor cpu_tensor;
940973 cpu_tensor.Resize (framework::make_ddim (dims));
941974 framework::VisitDataType (
@@ -944,13 +977,19 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
944977 is.read (static_cast <char *>(buf), size);
945978 auto dst_place = dev_ctx.GetPlace ();
946979 framework::TensorCopy (cpu_tensor, dst_place, dev_ctx, tensor);
980+ if (platform::is_npu_place (dev_ctx.GetPlace ())) {
981+ dev_ctx.Wait ();
982+ }
947983#else
948984 if (platform::is_gpu_place (dev_ctx.GetPlace ())) {
949985 PADDLE_THROW (platform::errors::Unimplemented (
950986 " CUDAPlace is not supported when not compiled with CUDA" ));
951- } else {
987+ } else if ( platform::is_xpu_place (dev_ctx. GetPlace ())) {
952988 PADDLE_THROW (platform::errors::Unimplemented (
953989 " XPUPlace is not supported when not compiled with XPU" ));
990+ } else {
991+ PADDLE_THROW (platform::errors::Unimplemented (
992+ " NPUPlace is not supported when not compiled with NPU" ));
954993 }
955994#endif
956995 } else {
0 commit comments