Commit e6bc358

Authored by frankwhzhang, xymyeah, zhiqiu, liym27, and YuxiangLu
【NPU】Cherry-pick ascendrc ops code by 0325 to develop (#32197)
* merge 31065
* Fix typo of selected_npus (#31230)
* merge 31249
* [NPU] Support npu op pow and pow grad (#31247)
* [NPU] Support npu op: (1) pow (2) pow_grad
* Support fp16
* Fix pow npu fp16 test (#31256)
* support list of list attribute for NPU (#31299)
* support list of list attribute for NPU
* fix compile problem
* fix reference
* [NPU] Support npu op: (1) slice (2) slice_grad (#31275)
* fix reading flags from env (#31329)
* merge 31347
* [NPU] Support npu op layer_norm and layer_norm_grad (#31310)
* init commit, add layer_norm npu kernel
* fix typo
* add unittest
* add unittest
* fix bug
* fix bug
* refine ut
* [NPU] add npu kernel for equal op (#31393)
* add npu kernel for equal op
* refine code
* add more ut
* update year
* [NPU] Support npu kernel for shape op (#31427)
* add shape npu
* fix
* fix
* fix endif (#31431)
* Fix pow, use fillD instead of broadcast (#31433)
* Fix pow, refine code (#31440)
* fix cmake of cryptopp to avoid downloading every time (#31451)
* [NPU] squeeze and unsqueeze op for ascend (#31452) Co-authored-by: root <[email protected]>
* Support npu kernel for gather op (#31458)
* add gather npu op
* code review done
* update python new line
* precommit
* fix review
* del commit
* 【NPU】add scale op for npu (#31499)
* add scale npu
* fix
* fix
* Support TensorFormVector, TensorToVector of bool type (#31518)
* support TensorFormVector, TensorToVector of bool type
* add ut
* fix compile problem
* 【NPU】support npu kernel for fill_constant op (#31521)
* add fill_constant npu
* add fill_constant npu
* fix
* cherry-pick 31422, solve conflict
* 【NPU】Support npu kernel for matmul op (#31544)
* add matmulv2_npu
* add matmul
* add matmul
* [NPU] Support npu op elementwise_mul and elementwise_mul_grad (#31571)
* [NPU] Support npu op elementwise_max (#31574)
* 【NPU】add relu op for npu (#31515)
* add relu npu
* fixed
* fix
* 【NPU】Suppert npu kernel for reshape2 op (#31524)
* add reshape2 npu
* add reshpe2
* [NPU] Support npu kernel for gather op fix bug (#31541)
* add gather npu op
* code review done
* update python new line
* precommit
* fix review
* del commit
* update gather_grad
* fix bug
* fix bug
* [NPU] Support npu kernel for amp_check_finite_and_unscale_npu op (#31457)
* Support npu kernel for amp_check_finite_and_unscale_npu op
* support EnforceNotMet exception
* fix exception bug
* modify python unittest
* precommit
* update c++ unittest
* fix review
* fix review
* [NPU] accuracy op (#31492)
* accuracy op
* fix license
* fix
* add test and fix bug
* [NPU] add Assign OP (#31561)
* add assign op
* add test assign npu test
* dele if def Co-authored-by: oyjxer <[email protected]>
* [NPU] fix npu op elementwise_mul_grad (#31592)
* 【NPU】Support npu op gelu and gelu_grad (#31530)
* Support npu op gelu and gelu_grad
* Support npu op gelu and gelu_grad
* [NPU] fix assgin cmake (#31595)
* fix gather_grad bug (#31607)
* [NPU] add range op (#31560)
* add range op
* fix codestyle; call GetSize directly Co-authored-by: oyjxer <[email protected]>
* 【NPU】Support npu op elementwise_div and elementwise_div_grad (#31573)
* Support npu op elementwise_div and elementwise_div_grad
* Support npu op elementwise_div and elementwise_div_grad
* Support npu op elementwise_div and elementwise_div_grad
* [NPU] Support npu op log, log_grad, sqrt, sqrt_grad, square, tanh and tanh_grad (#31600)
* [NPU] Support npu op logicalnot_op (#31534)
* [NPU] Support npu op elementwise_min (#31575)
* [NPU] Support npu op elementwise_pow (#31576)
* [NPU] Support npu op table_lookup_v2 and table_lookup_v2_grad (#31399)
* [npu] support npu kernel `table_lookup_v2`
* clean up
* +python test
* +cmake
* clean up
* remove int8 kernel + python unitest for fp16
* clean up
* [NPU] support npu kernel for `less_than` (#31327)
* [npu] support npu kernel for `less than`
* remove int* kernel
* cleanup
* [NPU] Support npu kernel scatter op (#31624)
* Support npu kernel scatter op
* Add more test
* [NPU] fix allocator min chunk size (#31632)
* [NPU] Support NPU kernel cast op (#31635) Co-authored-by: frankwhzhang <[email protected]>
* [NPU] add npu kernel for sgd (#31639)
* 【NPU】Support NPU kernel for reduce_sum op v2 (#31620)
* add reduce_sum
* fix broadcastd
* fix test
* fix
* add unsqueeze in reduce_sum
* add template
* add unittest for keep_dim
* test reduce_all Co-authored-by: frankwhzhang <[email protected]>
* [NPU] add npu kernel for adam (#31644)
* add npu kernel for adam
* refine code
* disable test
* modify atol
* 【NPU】Support npu kernel for mul op (#31584)
* add mul
* add test mul
* [NPU] add npu kernel for softmax_with_cross_entropy (#31656)
* init
* fix bugs
* [NPU] add npu kernel for mean Op (#31562)
* update mean op
* update mean op
* give a better test activation Co-authored-by: oyjxer <[email protected]>
* Revert "[NPU] add npu kernel for mean Op (#31562)" (#31665) This reverts commit 468ac69.
* 【NPU】Add TensorCopy to NPU kernel for reduce_sum op (#31667)
* update unittest
* add TensorCopy in npu grad kernel
* [NPU] Support npu op `expand` (#31405)
* [npu] support npu kernel for `expand`
* [NPU] fix shape of dx in mul_grad (#31675)
* fix shape of dx
* refine code
* [NPU] add Increment op (#31563)
* add increment
* fix
* update test increment op inplace
* update increment op
* increment b = 2 Co-authored-by: oyjxer <[email protected]>
* [NPU] add NPU add topk (#31596)
* add topk op
* add cmake
* update topk npu op
* refactor func
* fix test not go npu TopKD bug
* NPUPlace(4) to NPUPlace(0)
* update comment Co-authored-by: oyjxer <[email protected]>
* [NPU] Support NPU kernel sum op (#31671)
* [NPU] npu support `transpose` (#31486)
* cherry-pick 31564, solve conflict
* [NPU] Fix bug: Fix calculation errors of pow grad npu kernel (#31699)
* [NPU] Support testing grad of NPU ops in OpTest (#31697)
* [NPU] Support NPU kernel of stack op (#31711)
* [NPU] Remove redundant ctest of top_k_op_npu_test (#31718)
* [NPU] fix reshape npu op kernel (#31726)
* rename npu op file
* fix reshape
* [NPU] change transpose to transpose2 (#31734)
* change transpose to transpose2
* fix bug
* [NPU] Support mean npu kernel (#31729)
* [NPU] fix some bugs of npu op (#31739)
* fix softmax
* fix mean
* fix lookup_table_v2
* 【NPU】Fix npu kernel elementwise_div_grad (#31753)
* [NPU] fix the grad kernel diff bug of gather op (#31757)
* fix gather grad kernel diff
* fix gather grad kernel diff
* fix gather review bug
* 【NPU】Fix reshape test & add grad test (#31776)
* fix
* fix
* [NPU] support fp16 for npu accuracy op (#31797)
* [NPU] support list of tensor input (#31801)
* support list of tensor as npu input
* add comment
* fix typo
* fix typo
* [NPU] add npu kernel for concat op (#31695)
* add npu kernel for concat op
* add npu kernel for concat op
* refine code
* update
* refine concat_grad
* [NPU] Support npu kernel for op elementwise_floordiv (#31822)
* [NPU] fix bug of lookup_table_v2_grad (#31834)
* [NPU] support default stream (#31510)
* [NPU] support mixed precision input for npu layer norm (#31847)
* support mixed precision input for npu layer norm
* fix layer_norm npu kernel Co-authored-by: zhiqiu <[email protected]>
* 【NPU】Support npu kernel for update_loss_scaling op (#31830)
* add update_loss_scaling_npu NPU kernel
* change TensorFromVec to Memset
* fix compile problem (#31850)
* [NPU] support npu for conditional_block op (#31854)
* 【NPU】Add int dtype kernel for reshape2 op (#31864)
* fix
* fix
* [NPU] fix some op bugs (#31855)
* fix some op bugs
* fix some bugs
* follow comments
* fix log level
* add ut
* [NPU] support fp16 of input for api pow (#31871)
* [NPU] add npu kernel for truncated_gaussian_random op (#31654)
* init
* add todo
* add npu kernel for truncated_gaussian_random
* add sync
* fix concat_grad
* fix typo
* fix compile
* fix compile
* fix compile
* fix compile
* fix compile
* fix compile
* fix code style
* fix code style
* fix code
* Fix op test (#32231)
* fix conditional block (#32243)
* fix style code

Co-authored-by: xiayanming <[email protected]>
Co-authored-by: Leo Chen <[email protected]>
Co-authored-by: liym27 <[email protected]>
Co-authored-by: Reventon_L <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: oyjxer <[email protected]>
Co-authored-by: yinhaofeng <[email protected]>
Co-authored-by: OleNet <[email protected]>
Co-authored-by: Meiyim <[email protected]>
Co-authored-by: oyxuan-11 <[email protected]>
Co-authored-by: pangyoki <[email protected]>
1 parent 69d8027 commit e6bc358

File tree: 138 files changed, +13659 −314 lines


cmake/external/gloo.cmake

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ cache_third_party(extern_gloo
     TAG       ${GLOO_TAG}
     DIR       GLOO_SOURCE_DIR)

-if(WITH_ASCEND)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
   ExternalProject_Add(
     extern_gloo
     ${EXTERNAL_PROJECT_LOG_ARGS}

cmake/external/protobuf.cmake

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ endif()
 )
 ENDFUNCTION()

-if(WITH_ASCEND)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
   SET(PROTOBUF_VERSION 3.8.0)
 else()
   SET(PROTOBUF_VERSION 3.1.0)

cmake/external/threadpool.cmake

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ INCLUDE(ExternalProject)

 SET(THREADPOOL_PREFIX_DIR ${THIRD_PARTY_PATH}/threadpool)
 SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool/src/extern_threadpool)
-if(WITH_ASCEND)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
   SET(THREADPOOL_REPOSITORY https://gitee.com/tianjianhe/ThreadPool.git)
 else()
   SET(THREADPOOL_REPOSITORY ${GIT_URL}/progschj/ThreadPool.git)

cmake/external/warpctc.cmake

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ cache_third_party(extern_warpctc
     TAG       ${WARPCTC_TAG}
     DIR       WARPCTC_SOURCE_DIR)

-if(WITH_ASCEND)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
   ExternalProject_Add(
     extern_warpctc
     ${EXTERNAL_PROJECT_LOG_ARGS}

paddle/fluid/framework/tensor_util.h

Lines changed: 127 additions & 0 deletions
@@ -135,6 +135,7 @@ void TensorFromArray(const T* src, const size_t& array_size,
   }
 #endif
 }
+
 template <typename T>
 void TensorFromVector(const std::vector<T>& src,
                       const platform::DeviceContext& ctx, Tensor* dst) {
@@ -167,6 +168,49 @@ void TensorFromVector(const std::vector<T>& src,
 #endif
 }

+// The fully specialized function should be inline to avoid
+// multi-definition.
+template <>
+inline void TensorFromVector(const std::vector<bool>& src,
+                             const platform::DeviceContext& ctx, Tensor* dst) {
+  // vector<bool> has no data() member, use array instead.
+  // See details:
+  // https://stackoverflow.com/questions/46115669/why-does-stdvectorbool-have-no-data/46115714
+  bool* array = new bool[src.size()];
+  for (unsigned int i = 0; i < src.size(); i++) {
+    array[i] = static_cast<bool>(src[i]);
+  }
+
+  auto dst_place = ctx.GetPlace();
+  auto src_ptr = static_cast<const void*>(array);
+  platform::CPUPlace src_place;
+  dst->Resize({static_cast<int64_t>(src.size())});
+  auto dst_ptr = static_cast<void*>(dst->mutable_data<bool>(dst_place));
+  auto size = src.size() * sizeof(bool);
+
+  if (platform::is_cpu_place(dst_place)) {
+    memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
+                 src_place, src_ptr, size);
+  }
+#ifdef PADDLE_WITH_CUDA
+  else if (platform::is_gpu_place(dst_place)) {  // NOLINT
+    memory::Copy(
+        BOOST_GET_CONST(platform::CUDAPlace, dst_place), dst_ptr, src_place,
+        src_ptr, size,
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
+  }
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  else if (platform::is_npu_place(dst_place)) {  // NOLINT
+    memory::Copy(
+        BOOST_GET_CONST(platform::NPUPlace, dst_place), dst_ptr, src_place,
+        src_ptr, size,
+        reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
+  }
+#endif
+  delete[] array;
+}
+
 template <typename T>
 void TensorFromVector(const std::vector<T>& src, Tensor* dst) {
   platform::CPUPlace dst_place = platform::CPUPlace();
@@ -179,6 +223,23 @@ void TensorFromVector(const std::vector<T>& src, Tensor* dst) {
   memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
 }

+template <>
+inline void TensorFromVector(const std::vector<bool>& src, Tensor* dst) {
+  bool* array = new bool[src.size()];
+  for (unsigned int i = 0; i < src.size(); i++) {
+    array[i] = static_cast<bool>(src[i]);
+  }
+  platform::CPUPlace dst_place = platform::CPUPlace();
+  auto src_ptr = static_cast<const void*>(array);
+  platform::CPUPlace src_place;
+  dst->Resize({static_cast<int64_t>(src.size())});
+  auto dst_ptr = static_cast<void*>(dst->mutable_data<bool>(dst_place));
+  auto size = src.size() * sizeof(bool);
+
+  memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  delete[] array;
+}
+
 template <typename T>
 void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
                     std::vector<T>* dst) {
@@ -212,6 +273,46 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
 #endif
 }

+template <>
+inline void TensorToVector(const Tensor& src,
+                           const platform::DeviceContext& ctx,
+                           std::vector<bool>* dst) {
+  auto src_ptr = static_cast<const void*>(src.data<bool>());
+  auto size = src.numel() * sizeof(bool);
+
+  bool* array = new bool[src.numel()];
+
+  platform::CPUPlace dst_place;
+  dst->resize(src.numel());
+  auto dst_ptr = static_cast<void*>(array);
+
+  if (platform::is_cpu_place(src.place())) {
+    memory::Copy(dst_place, dst_ptr,
+                 BOOST_GET_CONST(platform::CPUPlace, src.place()), src_ptr,
+                 size);
+  }
+#ifdef PADDLE_WITH_CUDA
+  else if (platform::is_gpu_place(src.place())) {  // NOLINT
+    memory::Copy(
+        dst_place, dst_ptr, BOOST_GET_CONST(platform::CUDAPlace, src.place()),
+        src_ptr, size,
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
+  }
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  else if (platform::is_npu_place(src.place())) {  // NOLINT
+    memory::Copy(
+        dst_place, dst_ptr, BOOST_GET_CONST(platform::NPUPlace, src.place()),
+        src_ptr, size,
+        reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
+  }
+#endif
+  for (unsigned int i = 0; i < src.numel(); i++) {
+    (*dst)[i] = static_cast<bool>(array[i]);
+  }
+  delete[] array;
+}
+
 template <typename T>
 void TensorToVector(const Tensor& src, std::vector<T>* dst) {
   auto src_ptr = static_cast<const void*>(src.data<T>());
@@ -231,6 +332,32 @@ void TensorToVector(const Tensor& src, std::vector<T>* dst) {
                BOOST_GET_CONST(platform::CPUPlace, src.place()), src_ptr, size);
 }

+template <>
+inline void TensorToVector(const Tensor& src, std::vector<bool>* dst) {
+  auto src_ptr = static_cast<const void*>(src.data<bool>());
+  auto size = src.numel() * sizeof(bool);
+
+  bool* array = new bool[src.numel()];
+
+  platform::CPUPlace dst_place;
+  dst->resize(src.numel());
+  auto dst_ptr = static_cast<void*>(array);
+
+  PADDLE_ENFORCE_EQ(
+      platform::is_cpu_place(src.place()), true,
+      platform::errors::InvalidArgument(
+          "The input tensor should be CPU device, but actually it is in %s.",
+          src.place()));
+
+  memory::Copy(dst_place, dst_ptr,
+               BOOST_GET_CONST(platform::CPUPlace, src.place()), src_ptr, size);
+
+  for (unsigned int i = 0; i < src.numel(); i++) {
+    (*dst)[i] = static_cast<bool>(array[i]);
+  }
+  delete[] array;
+}
+
 std::ostream& operator<<(std::ostream& os, const Tensor& t);
 }  // namespace framework
 }  // namespace paddle
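The new `TensorFromVector`/`TensorToVector` specializations above exist because `std::vector<bool>` is bit-packed and exposes no `data()` member, so a raw memory copy must go through a temporary plain-`bool` array. A minimal standalone sketch of that unpack/copy/repack pattern (the function name here is hypothetical, not part of Paddle):

```cpp
#include <cassert>
#include <cstring>
#include <vector>

// Round-trip a std::vector<bool> through a raw bool buffer, mirroring the
// pattern in the tensor_util.h specializations. std::vector<bool> is
// bit-packed with no data(), so elements are unpacked one by one first.
inline std::vector<bool> RoundTripThroughRawBuffer(const std::vector<bool>& src) {
  bool* array = new bool[src.size()];
  for (size_t i = 0; i < src.size(); i++) {
    array[i] = static_cast<bool>(src[i]);  // unpack bit -> plain bool
  }

  // A byte-wise copy is now valid, as memory::Copy is in the real code.
  bool* copy = new bool[src.size()];
  std::memcpy(copy, array, src.size() * sizeof(bool));

  std::vector<bool> dst(src.size());
  for (size_t i = 0; i < src.size(); i++) {
    dst[i] = copy[i];  // repack plain bool -> bit
  }
  delete[] array;
  delete[] copy;
  return dst;
}
```

The same unpack step is why the specializations allocate with `new bool[...]` and `delete[]` after the device copy completes.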

paddle/fluid/framework/tensor_util_test.cc

Lines changed: 55 additions & 0 deletions
@@ -242,6 +242,61 @@ TEST(TensorToVector, Tensor) {
 #endif
 }

+TEST(TensorToVector, Tensor_bool) {
+  {
+    paddle::framework::Tensor src;
+    bool* src_ptr =
+        src.mutable_data<bool>({3, 3}, paddle::platform::CPUPlace());
+    for (int i = 0; i < 3 * 3; ++i) {
+      src_ptr[i] = static_cast<bool>(i % 2);
+    }
+
+    paddle::platform::CPUPlace place;
+    std::vector<bool> dst;
+    paddle::framework::TensorToVector<bool>(src, &dst);
+
+    for (int i = 0; i < 3 * 3; ++i) {
+      EXPECT_EQ(src_ptr[i], dst[i]);
+    }
+  }
+#ifdef PADDLE_WITH_CUDA
+  {
+    std::vector<bool> src_vec = {
+        false, true, false, true, false, true, false, true, false,
+    };
+    paddle::framework::Tensor gpu_tensor;
+    paddle::platform::CUDAPlace place;
+    paddle::platform::CUDADeviceContext gpu_ctx(place);
+    paddle::framework::TensorFromVector<bool>(src_vec, gpu_ctx, &gpu_tensor);
+
+    std::vector<bool> dst;
+    paddle::framework::TensorToVector<bool>(gpu_tensor, gpu_ctx, &dst);
+
+    for (int i = 0; i < 3 * 3; ++i) {
+      EXPECT_EQ(src_vec[i], dst[i]);
+    }
+  }
+#endif
+#ifdef PADDLE_WITH_ASCEND_CL
+  {
+    std::vector<bool> src_vec = {
+        false, true, false, true, false, true, false, true, false,
+    };
+    paddle::framework::Tensor npu_tensor;
+    paddle::platform::NPUPlace place(0);
+    paddle::platform::NPUDeviceContext npu_ctx(place);
+    paddle::framework::TensorFromVector<bool>(src_vec, npu_ctx, &npu_tensor);
+
+    std::vector<bool> dst;
+    paddle::framework::TensorToVector<bool>(npu_tensor, npu_ctx, &dst);
+
+    for (int i = 0; i < 3 * 3; ++i) {
+      EXPECT_EQ(src_vec[i], dst[i]);
+    }
+  }
+#endif
+}
+
 TEST(TensorFromDLPack, Tensor) {
   {
     std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};

paddle/fluid/framework/type_defs.h

Lines changed: 11 additions & 0 deletions
@@ -45,6 +45,17 @@ using Attribute = boost::variant<

 using AttributeMap = std::unordered_map<std::string, Attribute>;

+#ifdef PADDLE_WITH_ASCEND_CL
+using NPUAttribute =
+    boost::variant<boost::blank, int, float, std::string, std::vector<int>,
+                   std::vector<float>, std::vector<std::string>, bool,
+                   std::vector<bool>, BlockDesc*, int64_t,
+                   std::vector<BlockDesc*>, std::vector<int64_t>,
+                   std::vector<double>, std::vector<std::vector<int64_t>>>;
+
+using NPUAttributeMap = std::unordered_map<std::string, NPUAttribute>;
+#endif
+
 using OpCreator = std::function<OperatorBase*(
     const std::string& /*type*/, const VariableNameMap& /*inputs*/,
     const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
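The point of `NPUAttribute` is the extra `std::vector<std::vector<int64_t>>` alternative, which carries the "list of list" attributes mentioned in the commit message (#31299). A compilable analogue using C++17 `std::variant` in place of `boost::variant` (the `Demo*` names are illustrative, not Paddle's):

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

// Cut-down stand-in for NPUAttribute/NPUAttributeMap: a tagged union of
// attribute value types keyed by attribute name. The nested-vector
// alternative is what the NPU variant adds over the regular Attribute.
using DemoNPUAttribute =
    std::variant<int, float, std::string, std::vector<int>,
                 std::vector<std::vector<int64_t>>>;
using DemoNPUAttributeMap = std::unordered_map<std::string, DemoNPUAttribute>;

// Read element [0][0] of a list-of-list attribute; std::get throws
// std::bad_variant_access if the stored alternative does not match.
inline int64_t FirstOfFirstList(const DemoNPUAttributeMap& attrs,
                                const std::string& key) {
  const auto& lol =
      std::get<std::vector<std::vector<int64_t>>>(attrs.at(key));
  return lol.at(0).at(0);
}
```

`boost::variant` in the real header behaves the same way for this purpose; access goes through `boost::get` rather than `std::get`.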

paddle/fluid/memory/memcpy.cc

Lines changed: 16 additions & 0 deletions
@@ -206,8 +206,16 @@ void Copy<platform::NPUPlace, platform::CPUPlace>(platform::NPUPlace dst_place,
   if (UNLIKELY(num == 0)) return;

   platform::SetNPUDeviceId(dst_place.device);
+
+  // NOTE(ascendrc): NPU memcpy async from host to device is a "real" async,
+  // which is different from CUDA. In Paddle, when async is called, "sync"
+  // is run actually, which means Paddle doesn't fully supported async.
+  // TODO(ascendrc): Support NPU memcpy async for better performance.
+  stream = nullptr;
+
   VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
           << dst_place << " by thream(" << stream << ")";
+
   if (stream) {
     platform::RecordEvent record_event("NpuMemcpyAsync:CPU->NPU");
     platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_HOST_TO_DEVICE, stream);
@@ -226,8 +234,16 @@ void Copy<platform::CPUPlace, platform::NPUPlace>(platform::CPUPlace dst_place,
   if (UNLIKELY(num == 0)) return;

   platform::SetNPUDeviceId(src_place.device);
+
+  // NOTE(ascendrc): NPU memcpy async from device to host is a "real" async,
+  // which is different from CUDA. In Paddle, when async is called, "sync"
+  // is run actually, which means Paddle doesn't fully supported async.
+  // TODO(ascendrc): Support NPU memcpy async for better performance.
+  stream = nullptr;
+
   VLOG(4) << "memory::Copy " << num << " Bytes from " << src_place << " to "
           << dst_place << " by thream(" << stream << ")";
+
   if (stream) {
     platform::RecordEvent record_event("NpuMemcpyAsync:NPU->CPU");
     platform::NPUMemcpyAsync(dst, src, num, ACL_MEMCPY_DEVICE_TO_HOST, stream);
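The change above nulls the caller-supplied stream so that every host↔device copy takes the synchronous branch: NPU async memcpy is genuinely asynchronous, while Paddle's callers assume CUDA-style "async" that effectively synchronizes. A self-contained sketch of that control flow, with stand-in types instead of the real ACL API (`NpuStream`, the copy routine, and `CopyPath` are all hypothetical):

```cpp
#include <cassert>
#include <cstring>

using NpuStream = void*;  // stand-in for the real aclrtStream

enum class CopyPath { kSync, kAsync };

// Mirrors the memcpy.cc pattern: force stream to nullptr so the copy is
// routed through the synchronous branch regardless of what was passed in.
inline CopyPath NpuHostToDeviceCopy(void* dst, const void* src, size_t num,
                                    NpuStream stream) {
  // NOTE (paraphrasing the commit): NPU async memcpy is truly async,
  // unlike CUDA where Paddle's "async" effectively runs synchronously,
  // so the sync path is forced until async is fully supported.
  stream = nullptr;

  if (stream) {
    // Would be platform::NPUMemcpyAsync(dst, src, num, ..., stream).
    std::memcpy(dst, src, num);
    return CopyPath::kAsync;
  }
  // Would be the synchronous NPU copy; plain memcpy stands in here.
  std::memcpy(dst, src, num);
  return CopyPath::kSync;
}
```

With the real code, this is why the `VLOG(4)` line always prints a null stream and the `RecordEvent("NpuMemcpyAsync:...")` branch is currently dead.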

paddle/fluid/operators/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
@@ -124,6 +124,7 @@ if (WITH_ASCEND)
 endif()

 if (WITH_ASCEND_CL)
+    cc_test(assign_op_npu_test SRCS assign_op_npu_test.cc DEPS assign_op)
     cc_library(npu_op_runner SRCS npu_op_runner.cc DEPS operator npu_info)
     set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner)
 endif()
@@ -141,8 +142,8 @@ set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS})
 set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")

 cc_test(test_common_infer_shape_functions SRCS test_common_infer_shape_functions.cc DEPS common_infer_shape_functions ${COMMON_OP_DEPS} activation_op elementwise_add_op softmax_op softmax)
-cc_test(assign_op_test SRCS assign_op_test.cc DEPS assign_op)
 cc_test(gather_test SRCS gather_test.cc DEPS tensor)
+cc_test(assign_op_test SRCS assign_op_test.cc DEPS assign_op)
 cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function)
 cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
 cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
@@ -163,10 +164,19 @@ if (WITH_PYTHON)
     cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
 endif()

+if (WITH_ASCEND_CL)
+    cc_test(range_op_npu_test SRCS range_op_npu_test.cc DEPS op_registry range_op scope device_context enforce executor)
+    cc_test(lookup_table_v2_op_npu_test SRCS lookup_table_v2_op_npu_test.cc DEPS op_registry lookup_table_v2_op scope device_context enforce executor compare_op)
+endif()
+
 set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
 add_subdirectory(benchmark)

 cc_test(op_debug_string_test SRCS op_debug_string_test.cc DEPS elementwise_add_op)
+if (WITH_ASCEND_CL)
+    cc_test(transpose_op_npu_test SRCS transpose_op_npu_test.cc DEPS op_registry transpose_op scope device_context enforce executor)
+endif()
+

 if(WITH_MKLDNN)
     include(mkldnn/inplace_op_tests.cmake)
@@ -180,3 +190,7 @@ if(WITH_UNITY_BUILD)
     # The specified link dependency needs to be displayed here.
     target_link_libraries(paddle_operators_unity ${OP_HEADER_DEPS} ${COMMON_OP_DEPS})
 endif()
+
+if(WITH_ASCEND_CL)
+    cc_test(gelu_op_npu_test SRCS gelu_op_npu_test.cc DEPS op_registry gelu_op scope device_context enforce executor)
+endif()
