@@ -822,6 +822,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
822822#else
823823 PADDLE_THROW (platform::errors::Unimplemented (
824824 " XPUPlace is not supported when not compiled with XPU" ));
825+ #endif
826+ } else if (platform::is_npu_place (tensor.place ())) {
827+ #ifdef PADDLE_WITH_ASCEND_CL
828+ constexpr size_t kBufSize = 1024 * 1024 * 64 ; // 64MB
829+ std::unique_ptr<char []> buf (new char [kBufSize ]);
830+ auto & npu_dev_ctx =
831+ static_cast <const platform::NPUDeviceContext&>(dev_ctx);
832+ platform::CPUPlace cpu;
833+ uintptr_t data = reinterpret_cast <uintptr_t >(data_ptr);
834+ while (size != 0 ) {
835+ size_t size_to_write = std::min (kBufSize , static_cast <size_t >(size));
836+ memory::Copy (cpu, buf.get (),
837+ BOOST_GET_CONST (platform::NPUPlace, tensor.place ()),
838+ reinterpret_cast <const void *>(data), size_to_write,
839+ npu_dev_ctx.stream ());
840+ npu_dev_ctx.Wait ();
841+ os.write (buf.get (), size_to_write);
842+ data += size_to_write;
843+ size -= size_to_write;
844+ }
845+ #else
846+ PADDLE_THROW (platform::errors::Unimplemented (
847+ " NPUPlace is not supported when not compiled with NPU" ));
825848#endif
826849 } else {
827850 os.write (static_cast <const char *>(data_ptr),
@@ -877,8 +900,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
877900 auto ctx = platform::CPUDeviceContext ();
878901 size_t size = tensor->numel () * framework::SizeOfType (desc.data_type ());
879902 if (platform::is_gpu_place (dev_ctx.GetPlace ()) ||
880- platform::is_xpu_place (dev_ctx.GetPlace ())) {
881- #if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU
903+ platform::is_xpu_place (dev_ctx.GetPlace ()) ||
904+ platform::is_npu_place (dev_ctx.GetPlace ())) {
905+ #if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU || \
906+ defined PADDLE_WITH_ASCEND_CL
882907 Tensor cpu_tensor;
883908 cpu_tensor.Resize (framework::make_ddim (shape));
884909 framework::VisitDataType (
@@ -891,9 +916,12 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
891916 if (platform::is_gpu_place (dev_ctx.GetPlace ())) {
892917 PADDLE_THROW (platform::errors::Unimplemented (
893918 " CUDAPlace is not supported when not compiled with CUDA" ));
894- } else {
919+ } else if ( platform::is_xpu_place (dev_ctx. GetPlace ())) {
895920 PADDLE_THROW (platform::errors::Unimplemented (
896921 " XPUPlace is not supported when not compiled with XPU" ));
922+ } else {
923+ PADDLE_THROW (platform::errors::Unimplemented (
924+ " NPUPlace is not supported when not compiled with NPU" ));
897925 }
898926#endif
899927 } else {
0 commit comments