Skip to content

Commit 24cd730

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into log_softmax
2 parents 0c1aec6 + ec2ffb6 commit 24cd730

File tree

108 files changed

+4007
-285
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

108 files changed

+4007
-285
lines changed

CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,14 @@ option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
3333
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
3434
option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF)
3535
option(WITH_ASCEND "Compile PaddlePaddle with ASCEND" OFF)
36+
# NOTE(zhiqiu): WITH_ASCEND_CL can be compile on x86_64, so we can set WITH_ASCEND=OFF and WITH_ASCEND_CL=ON
37+
# to develop some acl related functionality on x86
38+
option(WITH_ASCEND_CL "Compile PaddlePaddle with ASCEND CL" ${WITH_ASCEND})
3639
option(WITH_ASCEND_CXX11 "Compile PaddlePaddle with ASCEND and CXX11 ABI" OFF)
3740
if (WITH_GPU AND WITH_XPU)
3841
message(FATAL_ERROR "Error when compile GPU and XPU at the same time")
3942
endif()
40-
if (WITH_GPU AND WITH_ASCEND)
43+
if (WITH_GPU AND WITH_ASCEND)
4144
message(FATAL_ERROR "Error when compile GPU and ASCEND at the same time")
4245
endif()
4346

cmake/configure.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ if(WITH_ASCEND)
8282
add_definitions(-DPADDLE_WITH_ASCEND)
8383
endif()
8484

85+
if(WITH_ASCEND_CL)
86+
add_definitions(-DPADDLE_WITH_ASCEND_CL)
87+
endif()
88+
8589
if(WITH_XPU)
8690
message(STATUS "Compile with XPU!")
8791
add_definitions(-DPADDLE_WITH_XPU)

cmake/external/ascend.cmake

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,60 @@ else()
2121
set(ASCEND_DIR /usr/local/Ascend)
2222
endif()
2323

24-
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
25-
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
26-
set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
27-
set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
28-
set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
29-
set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
30-
set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
31-
32-
set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
33-
set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR})
34-
set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
35-
set(ATLAS_RUNTIME_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
36-
set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64)
37-
set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64)
38-
set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})
39-
40-
set(atlas_graph_lib ${ATLAS_RUNTIME_DIR}/libgraph.so)
41-
set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so)
42-
set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so)
43-
INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR})
44-
45-
if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h)
46-
add_definitions(-DPADDLE_WITH_ASCEND_STRING)
24+
if(WITH_ASCEND)
25+
set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
26+
set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
27+
set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
28+
set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
29+
set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
30+
set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
31+
set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
32+
33+
set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
34+
set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR})
35+
set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
36+
set(ATLAS_RUNTIME_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
37+
set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64)
38+
set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64)
39+
set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})
40+
41+
set(atlas_graph_lib ${ATLAS_RUNTIME_DIR}/libgraph.so)
42+
set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so)
43+
set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so)
44+
INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR})
45+
46+
if(EXISTS ${ATLAS_RUNTIME_INC_DIR}/graph/ascend_string.h)
47+
add_definitions(-DPADDLE_WITH_ASCEND_STRING)
48+
endif()
49+
50+
ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL)
51+
SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib})
52+
53+
ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL)
54+
SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${atlas_graph_lib})
55+
56+
ADD_LIBRARY(atlas_acl SHARED IMPORTED GLOBAL)
57+
SET_PROPERTY(TARGET atlas_acl PROPERTY IMPORTED_LOCATION ${atlas_acl_lib})
58+
59+
add_custom_target(extern_ascend DEPENDS ascend_ge ascend_graph atlas_acl)
4760
endif()
4861

49-
ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL)
50-
SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib})
62+
if(WITH_ASCEND_CL)
63+
set(ASCEND_CL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
64+
65+
set(ascendcl_lib ${ASCEND_CL_DIR}/libascendcl.so)
66+
set(acl_op_compiler_lib ${ASCEND_CL_DIR}/libacl_op_compiler.so)
67+
set(ASCEND_CL_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
5168

52-
ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL)
53-
SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${atlas_graph_lib})
69+
message(STATUS "ASCEND_CL_INC_DIR ${ASCEND_CL_INC_DIR}")
70+
message(STATUS "ASCEND_CL_DIR ${ASCEND_CL_DIR}")
71+
INCLUDE_DIRECTORIES(${ASCEND_CL_INC_DIR})
5472

55-
ADD_LIBRARY(atlas_acl SHARED IMPORTED GLOBAL)
56-
SET_PROPERTY(TARGET atlas_acl PROPERTY IMPORTED_LOCATION ${atlas_acl_lib})
73+
ADD_LIBRARY(ascendcl SHARED IMPORTED GLOBAL)
74+
SET_PROPERTY(TARGET ascendcl PROPERTY IMPORTED_LOCATION ${ascendcl_lib})
5775

58-
add_custom_target(extern_ascend DEPENDS ascend_ge ascend_graph atlas_acl)
76+
ADD_LIBRARY(acl_op_compiler SHARED IMPORTED GLOBAL)
77+
SET_PROPERTY(TARGET acl_op_compiler PROPERTY IMPORTED_LOCATION ${acl_op_compiler_lib})
78+
add_custom_target(extern_ascend_cl DEPENDS ascendcl acl_op_compiler)
79+
80+
endif()

cmake/external/protobuf.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
201201
if(WITH_ASCEND AND NOT WITH_ASCEND_CXX11)
202202
SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git)
203203
SET(PROTOBUF_TAG v3.8.0)
204+
elseif(WITH_ASCEND_CL AND NOT WITH_ASCEND_CXX11)
205+
SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git)
206+
SET(PROTOBUF_TAG v3.8.0)
204207
else()
205208
SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
206209
SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546)

cmake/operators.cmake

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function(op_library TARGET)
1111
set(cu_cc_srcs)
1212
set(hip_cc_srcs)
1313
set(xpu_cc_srcs)
14+
set(npu_cc_srcs)
1415
set(cudnn_cu_cc_srcs)
1516
set(miopen_cu_cc_srcs)
1617
set(cudnn_cu_srcs)
@@ -20,6 +21,9 @@ function(op_library TARGET)
2021
set(mkldnn_cc_srcs)
2122
set(MKLDNN_FILE)
2223
set(op_common_deps operator op_registry math_function layer common_infer_shape_functions)
24+
if (WITH_ASCEND_CL)
25+
set(op_common_deps ${op_common_deps} npu_op_runner)
26+
endif()
2327
# Option `UNITY` is used to specify that operator `TARGET` will compiles with Unity Build.
2428
set(options UNITY)
2529
set(oneValueArgs "")
@@ -85,6 +89,12 @@ function(op_library TARGET)
8589
list(APPEND xpu_cc_srcs ${XPU_FILE}.cc)
8690
endif()
8791
endif()
92+
if(WITH_ASCEND_CL)
93+
string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
94+
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${NPU_FILE}.cc)
95+
list(APPEND npu_cc_srcs ${NPU_FILE}.cc)
96+
endif()
97+
endif()
8898
else()
8999
foreach(src ${op_library_SRCS})
90100
if(WITH_ROCM AND ${src} MATCHES ".*_cudnn_op.cu$")
@@ -107,6 +117,8 @@ function(op_library TARGET)
107117
list(APPEND cu_cc_srcs ${src})
108118
elseif(WITH_XPU AND ${src} MATCHES ".*_op_xpu.cc$")
109119
list(APPEND xpu_cc_srcs ${src})
120+
elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$")
121+
list(APPEND npu_cc_srcs ${src})
110122
elseif(${src} MATCHES ".*\\.cc$")
111123
list(APPEND cc_srcs ${src})
112124
else()
@@ -176,7 +188,7 @@ function(op_library TARGET)
176188
# Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
177189
if(WITH_UNITY_BUILD AND op_library_UNITY)
178190
# Combine the cc source files.
179-
compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs})
191+
compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs})
180192
if(TARGET ${UNITY_TARGET})
181193
# If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
182194
target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources})
@@ -187,7 +199,7 @@ function(op_library TARGET)
187199
# Add alias library to handle dependencies.
188200
add_library(${TARGET} ALIAS ${UNITY_TARGET})
189201
else()
190-
cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} DEPS ${op_library_DEPS}
202+
cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} DEPS ${op_library_DEPS}
191203
${op_common_deps})
192204
endif()
193205
endif()
@@ -207,6 +219,7 @@ function(op_library TARGET)
207219
# The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
208220
# Note that it's enough to just adding one operator to pybind in a *_op.cc file.
209221
# And for detail pybind information, please see generated paddle/pybind/pybind.h.
222+
set(ORIGINAL_TARGET ${TARGET})
210223
file(READ ${TARGET}.cc TARGET_CONTENT)
211224
string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
212225
# [ \t\r\n]* is used for blank characters
@@ -239,8 +252,9 @@ function(op_library TARGET)
239252
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
240253
list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
241254
list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len)
255+
list(LENGTH npu_cc_srcs npu_cc_srcs_len)
242256
if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
243-
${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0)
257+
${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND ${npu_cc_srcs_len} EQUAL 0)
244258
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
245259
set(pybind_flag 1)
246260
endif()
@@ -280,6 +294,26 @@ function(op_library TARGET)
280294
if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0)
281295
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n")
282296
endif()
297+
298+
if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0)
299+
file(READ ${ORIGINAL_TARGET}_npu.cc TARGET_NPU_CONTENT)
300+
# It is different from the logic above, becareful
301+
string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\(.*" multi_npu_register "${TARGET_NPU_CONTENT}")
302+
# [ \t\r\n]* is used for blank characters
303+
string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_npu_register "${multi_npu_register}")
304+
305+
if (one_npu_register STREQUAL "")
306+
string(REPLACE "_op" "" NPU_TARGET "${TARGET}")
307+
else ()
308+
string(REPLACE "REGISTER_OP_NPU_KERNEL(" "" NPU_TARGET "${one_npu_register}")
309+
string(REPLACE "," "" NPU_TARGET "${NPU_TARGET}")
310+
# [ \t\r\n]+ is used for blank characters.
311+
# Here we use '+' instead of '*' since it is a REPLACE operation.
312+
string(REGEX REPLACE "[ \t\r\n]+" "" NPU_TARGET "${NPU_TARGET}")
313+
endif()
314+
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${NPU_TARGET}, NPU);\n")
315+
endif()
316+
283317
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
284318
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
285319
# Append first implemented MKLDNN activation operator
@@ -330,6 +364,7 @@ function(register_operators)
330364
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
331365
string(REPLACE "_mkldnn" "" OPS "${OPS}")
332366
string(REPLACE "_xpu" "" OPS "${OPS}")
367+
string(REPLACE "_npu" "" OPS "${OPS}")
333368
string(REPLACE ".cc" "" OPS "${OPS}")
334369
list(REMOVE_DUPLICATES OPS)
335370
list(LENGTH register_operators_DEPS register_operators_DEPS_len)

cmake/third_party.cmake

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,15 @@ if(WITH_BOX_PS)
274274
list(APPEND third_party_deps extern_box_ps)
275275
endif(WITH_BOX_PS)
276276

277-
if(WITH_ASCEND)
277+
if(WITH_ASCEND OR WITH_ASCEND_CL)
278278
include(external/ascend)
279-
list(APPEND third_party_deps extern_ascend)
280-
endif (WITH_ASCEND)
279+
if(WITH_ASCEND)
280+
list(APPEND third_party_deps extern_ascend)
281+
endif()
282+
if(WITH_ASCEND_CL)
283+
list(APPEND third_party_deps extern_ascend_cl)
284+
endif()
285+
endif ()
281286

282287
if (WITH_PSCORE)
283288
include(external/snappy)

paddle/fluid/framework/CMakeLists.txt

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -367,29 +367,23 @@ endif()
367367
##### 2.0 New custom op extension mechanism related #####
368368

369369
# if not deps `layer`, will cause: undefined symbol: _ZN6paddle10imperative7VarBase9name_set_
370-
set(PADDLE_CUSTOM_OP_MODULES custom_tensor op_meta_info custom_operator layer)
371-
372-
set(PADDLE_CUSTOM_OP_SRCS
373-
${CMAKE_CURRENT_SOURCE_DIR}/custom_operator.cc
374-
${CMAKE_CURRENT_SOURCE_DIR}/../extension/src/ext_tensor.cc
375-
${CMAKE_CURRENT_SOURCE_DIR}/../extension/src/ext_op_meta_info.cc
376-
${CMAKE_SOURCE_DIR}/paddle/fluid/imperative/layer.cc)
377-
set(PADDLE_CUSTOM_OP_SRCS ${PADDLE_CUSTOM_OP_SRCS} PARENT_SCOPE)
370+
if (WIN32)
371+
set(PADDLE_CUSTOM_OP_MODULES custom_tensor op_meta_info custom_operator layer)
378372

379-
cc_library(paddle_custom_op_shared
380-
SHARED SRCS ${PADDLE_CUSTOM_OP_SRCS} DEPS ${PADDLE_CUSTOM_OP_MODULES})
373+
set(PADDLE_CUSTOM_OP_SRCS
374+
${CMAKE_CURRENT_SOURCE_DIR}/custom_operator.cc
375+
${CMAKE_CURRENT_SOURCE_DIR}/../extension/src/ext_tensor.cc
376+
${CMAKE_CURRENT_SOURCE_DIR}/../extension/src/ext_op_meta_info.cc
377+
${CMAKE_SOURCE_DIR}/paddle/fluid/imperative/layer.cc)
378+
set(PADDLE_CUSTOM_OP_SRCS ${PADDLE_CUSTOM_OP_SRCS} PARENT_SCOPE)
381379

382-
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
383-
set_target_properties(paddle_custom_op_shared PROPERTIES OUTPUT_NAME paddle_custom_op)
384-
target_link_libraries(paddle_custom_op_shared ${os_dependency_modules})
380+
cc_library(paddle_custom_op_shared
381+
SHARED SRCS ${PADDLE_CUSTOM_OP_SRCS} DEPS ${PADDLE_CUSTOM_OP_MODULES})
385382

386-
if (LINUX)
387-
set(PADDLE_CUSTOM_OP_SHARED_LIB
388-
${PADDLE_BINARY_DIR}/paddle/fluid/framework/libpaddle_custom_op.so
389-
CACHE INTERNAL "Paddle custom op lib")
390-
endif()
383+
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
384+
set_target_properties(paddle_custom_op_shared PROPERTIES OUTPUT_NAME paddle_custom_op)
385+
target_link_libraries(paddle_custom_op_shared ${os_dependency_modules})
391386

392-
if (WIN32)
393387
if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
394388
set(paddle_custom_op_lib_path ${CMAKE_CURRENT_BINARY_DIR})
395389
else()
@@ -402,9 +396,3 @@ if (WIN32)
402396
${paddle_custom_op_lib_path}/paddle_custom_op.dll
403397
CACHE INTERNAL "Paddle custom op dll")
404398
endif()
405-
406-
if(APPLE)
407-
set(PADDLE_CUSTOM_OP_SHARED_LIB
408-
${PADDLE_BINARY_DIR}/paddle/fluid/framework/paddle_custom_op.dylib
409-
CACHE INTERNAL "Paddle custom op lib")
410-
endif()

paddle/fluid/framework/dlpack_tensor.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
8282
platform::errors::Unimplemented("platform::XPUPlace is not supported"));
8383
}
8484

85+
inline ::DLContext operator()(const platform::NPUPlace &place) const {
86+
PADDLE_THROW(
87+
platform::errors::Unimplemented("platform::NPUPlace is not supported"));
88+
}
89+
8590
inline ::DLContext operator()(const platform::CUDAPlace &place) const {
8691
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
8792
::DLContext ctx;

paddle/fluid/framework/executor.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,14 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
453453
#else
454454
PADDLE_THROW(
455455
platform::errors::Unimplemented("No XPU gc found in CPU/GPU paddle"));
456+
#endif
457+
} else if (platform::is_npu_place(place_)) {
458+
#ifdef PADDLE_WITH_ASCEND_CL
459+
// TODO(ascendrc): Support garbage collector on NPUPlace
460+
VLOG(4) << "Skip NPU gc because it is not implemented now.";
461+
#else
462+
PADDLE_THROW(platform::errors::Unimplemented(
463+
"No NPU gc found in CPU/GPU/XPU paddle"));
456464
#endif
457465
}
458466
}

paddle/fluid/framework/garbage_collector.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place,
8686
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamCreate(&stream_));
8787
#else
8888
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
89+
callback_manager_.reset(
90+
new platform::StreamCallbackManager<gpuStream_t>(stream_));
8991
#endif
90-
callback_manager_.reset(new platform::StreamCallbackManager(stream_));
9192
}
9293

9394
StreamGarbageCollector::~StreamGarbageCollector() {

0 commit comments

Comments
 (0)