Skip to content

Commit ea9d783

Browse files
committed
support save load for NPU
1 parent 3c66b87 commit ea9d783

File tree

5 files changed: +140 additions, −3 deletions

paddle/fluid/framework/tensor_util.cc

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
822822
#else
823823
PADDLE_THROW(platform::errors::Unimplemented(
824824
"XPUPlace is not supported when not compiled with XPU"));
825+
#endif
826+
} else if (platform::is_npu_place(tensor.place())) {
827+
#ifdef PADDLE_WITH_ASCEND_CL
828+
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
829+
std::unique_ptr<char[]> buf(new char[kBufSize]);
830+
auto& npu_dev_ctx =
831+
static_cast<const platform::NPUDeviceContext&>(dev_ctx);
832+
platform::CPUPlace cpu;
833+
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
834+
while (size != 0) {
835+
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
836+
memory::Copy(cpu, buf.get(),
837+
BOOST_GET_CONST(platform::NPUPlace, tensor.place()),
838+
reinterpret_cast<const void*>(data), size_to_write,
839+
npu_dev_ctx.stream());
840+
npu_dev_ctx.Wait();
841+
os.write(buf.get(), size_to_write);
842+
data += size_to_write;
843+
size -= size_to_write;
844+
}
845+
#else
846+
PADDLE_THROW(platform::errors::Unimplemented(
847+
"NPUPlace is not supported when not compiled with NPU"));
825848
#endif
826849
} else {
827850
os.write(static_cast<const char*>(data_ptr),
@@ -877,8 +900,10 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
877900
auto ctx = platform::CPUDeviceContext();
878901
size_t size = tensor->numel() * framework::SizeOfType(desc.data_type());
879902
if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
880-
platform::is_xpu_place(dev_ctx.GetPlace())) {
881-
#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU
903+
platform::is_xpu_place(dev_ctx.GetPlace()) ||
904+
platform::is_npu_place(dev_ctx.GetPlace())) {
905+
#if defined PADDLE_WITH_CUDA || defined PADDLE_WITH_XPU || \
906+
defined PADDLE_WITH_ASCEND_CL
882907
Tensor cpu_tensor;
883908
cpu_tensor.Resize(framework::make_ddim(shape));
884909
framework::VisitDataType(
@@ -891,9 +916,12 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
891916
if (platform::is_gpu_place(dev_ctx.GetPlace())) {
892917
PADDLE_THROW(platform::errors::Unimplemented(
893918
"CUDAPlace is not supported when not compiled with CUDA"));
894-
} else {
919+
} else if (platform::is_xpu_place(dev_ctx.GetPlace())) {
895920
PADDLE_THROW(platform::errors::Unimplemented(
896921
"XPUPlace is not supported when not compiled with XPU"));
922+
} else {
923+
PADDLE_THROW(platform::errors::Unimplemented(
924+
"NPUPlace is not supported when not compiled with NPU"));
897925
}
898926
#endif
899927
} else {
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// NPU (Ascend) kernel registration for the load_combine op.
// The kernel body is the device-generic LoadCombineOpKernel declared in
// load_combine_op.h; this file only binds it to NPUDeviceContext for the
// dtypes supported on NPU. Compiled only when Ascend CL support is enabled.
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/load_combine_op.h"

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    load_combine,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, float>,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, double>,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int>,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
    ops::LoadCombineOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
#endif
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// NPU (Ascend) kernel registration for the load op.
// The kernel body is the device-generic LoadOpKernel declared in load_op.h;
// this file only binds it to NPUDeviceContext for the dtypes supported on
// NPU. Compiled only when Ascend CL support is enabled.
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/load_op.h"

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    load, ops::LoadOpKernel<paddle::platform::NPUDeviceContext, float>,
    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, double>,
    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int>,
    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
    ops::LoadOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
#endif
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// NPU (Ascend) kernel registration for the save_combine op.
// The kernel body is the device-generic SaveCombineOpKernel declared in
// save_combine_op.h; this file only binds it to NPUDeviceContext.
// Compiled only when Ascend CL support is enabled.
// NOTE(review): load_combine registers an int8_t kernel on NPU but
// save_combine does not — confirm whether the asymmetry is intentional
// (a tensor loadable as int8_t on NPU would not be savable from NPU).
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/save_combine_op.h"

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    save_combine,
    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, float>,
    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, double>,
    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, int>,
    ops::SaveCombineOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
#endif
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// NPU (Ascend) kernel registration for the save op.
// The kernel body is the device-generic SaveOpKernel declared in save_op.h;
// this file only binds it to NPUDeviceContext. float16.h is needed because
// a float16 kernel is registered below. Compiled only when Ascend CL
// support is enabled.
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/save_op.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    save, ops::SaveOpKernel<paddle::platform::NPUDeviceContext, float>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, double>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, uint8_t>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int8_t>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext, int64_t>,
    ops::SaveOpKernel<paddle::platform::NPUDeviceContext,
                      paddle::platform::float16>);
#endif

0 commit comments

Comments (0)