Skip to content

Commit 1f33737

Browse files
committed
add xpu_support op function
*test=kunlun
1 parent 2af8219 commit 1f33737

File tree

12 files changed

+746
-469
lines changed

12 files changed

+746
-469
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,9 +1319,10 @@ bool OperatorWithKernel::SupportXPU() const {
13191319
op_kernels.end(),
13201320
[this](OpKernelMap::const_reference kern_pair) {
13211321
return platform::is_xpu_place(kern_pair.first.place_) &&
1322-
paddle::platform::is_xpu_support_op(type_,
1323-
kern_pair.first) &&
1324-
!paddle::platform::is_in_xpu_black_list(type_);
1322+
paddle::platform::is_xpu_support_op(
1323+
type_,
1324+
framework::TransToPhiDataType(
1325+
kern_pair.first.data_type_));
13251326
});
13261327
}
13271328
}
@@ -1409,16 +1410,17 @@ bool OperatorWithKernel::SupportsKernelType(
14091410
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
14101411
if (paddle::platform::is_xpu_place(kernel_type.place_)) {
14111412
return kernel_iter != kernels.end() &&
1412-
paddle::platform::is_xpu_support_op(type_, kernel_type) &&
1413-
!paddle::platform::is_in_xpu_black_list(type_);
1413+
paddle::platform::is_xpu_support_op(
1414+
type_, framework::TransToPhiDataType(kernel_type.data_type_));
14141415
}
14151416
#endif
14161417

14171418
#ifdef PADDLE_WITH_XPU_KP
14181419
if (paddle::platform::is_xpu_place(kernel_type.place_)) {
14191420
bool use_xpu_kp_kernel_rt =
14201421
FLAGS_run_kp_kernel &&
1421-
paddle::platform::is_xpu_kp_support_op(type_, kernel_type);
1422+
paddle::platform::is_xpu_support_op(
1423+
type_, framework::TransToPhiDataType(kernel_type.data_type_));
14221424
bool use_xpu_kp_kernel_debug =
14231425
paddle::platform::is_in_xpu_kpwhite_list(type_);
14241426
bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
@@ -1428,8 +1430,8 @@ bool OperatorWithKernel::SupportsKernelType(
14281430
return kernels.find(tmp_kernel_type) != kernels.end();
14291431
}
14301432
return kernel_iter != kernels.end() &&
1431-
paddle::platform::is_xpu_support_op(type_, kernel_type) &&
1432-
!paddle::platform::is_in_xpu_black_list(type_);
1433+
paddle::platform::is_xpu_support_op(
1434+
type_, framework::TransToPhiDataType(kernel_type.data_type_));
14331435
}
14341436
#endif
14351437

@@ -1591,7 +1593,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
15911593
if (paddle::platform::is_xpu_place(kernel_type_->place_)) {
15921594
bool use_xpu_kp_kernel_rt =
15931595
FLAGS_run_kp_kernel &&
1594-
paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
1596+
paddle::platform::is_xpu_support_op(
1597+
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
15951598
bool use_xpu_kp_kernel_debug =
15961599
paddle::platform::is_in_xpu_kpwhite_list(type_);
15971600
if (use_xpu_kp_kernel_rt) {
@@ -1668,7 +1671,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
16681671
if (paddle::platform::is_xpu_place(kernel_type_->place_)) {
16691672
bool use_xpu_kp_kernel_rt =
16701673
FLAGS_run_kp_kernel &&
1671-
paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
1674+
paddle::platform::is_xpu_support_op(
1675+
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
16721676
bool use_xpu_kp_kernel_debug =
16731677
paddle::platform::is_in_xpu_kpwhite_list(type_);
16741678
if (use_xpu_kp_kernel_rt) {
@@ -1709,14 +1713,15 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
17091713
#if defined(PADDLE_WITH_XPU)
17101714
bool is_xpu_unsupport =
17111715
paddle::platform::is_xpu_place(kernel_type_->place_) &&
1712-
!paddle::platform::is_xpu_support_op(type_, *kernel_type_.get()) ||
1713-
paddle::platform::is_in_xpu_black_list(type_);
1716+
!paddle::platform::is_xpu_support_op(
1717+
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
17141718
#endif
17151719
#ifdef PADDLE_WITH_XPU_KP
17161720
bool use_xpu_kp_kernel_rt =
17171721
paddle::platform::is_xpu_place(kernel_type_->place_) &&
17181722
FLAGS_run_kp_kernel &&
1719-
paddle::platform::is_xpu_kp_support_op(type_, *kernel_type_);
1723+
paddle::platform::is_xpu_support_op(
1724+
type_, framework::TransToPhiDataType(kernel_type_->data_type_));
17201725
bool use_xpu_kp_kernel_debug =
17211726
paddle::platform::is_xpu_place(kernel_type_->place_) &&
17221727
paddle::platform::is_in_xpu_kpwhite_list(type_);
@@ -2051,8 +2056,9 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
20512056
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
20522057
if (platform::is_xpu_place(expected_kernel_key.place_) &&
20532058
(kernel_iter == kernels.end() ||
2054-
!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
2055-
paddle::platform::is_in_xpu_black_list(type_))) {
2059+
!paddle::platform::is_xpu_support_op(
2060+
type_,
2061+
framework::TransToPhiDataType(expected_kernel_key.data_type_)))) {
20562062
VLOG(3) << "fluid missing XPU kernel: " << type_
20572063
<< ", expected_kernel_key:" << expected_kernel_key
20582064
<< ", fallbacking to CPU one!";
@@ -2065,7 +2071,9 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
20652071
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
20662072
bool use_xpu_kp_kernel_rt =
20672073
FLAGS_run_kp_kernel &&
2068-
paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key);
2074+
paddle::platform::is_xpu_support_op(
2075+
type_,
2076+
framework::TransToPhiDataType(expected_kernel_key.data_type_));
20692077
bool use_xpu_kp_kernel_debug =
20702078
paddle::platform::is_in_xpu_kpwhite_list(type_);
20712079
if (use_xpu_kp_kernel_rt) {
@@ -2093,9 +2101,8 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
20932101
<< ", using_kernel_key:" << expected_kernel_key;
20942102
}
20952103
}
2096-
bool is_xpu_unsupport =
2097-
(!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
2098-
paddle::platform::is_in_xpu_black_list(type_));
2104+
bool is_xpu_unsupport = (!paddle::platform::is_xpu_support_op(
2105+
type_, framework::TransToPhiDataType(expected_kernel_key.data_type_)));
20992106
if (!is_xpu_kp_support &&
21002107
(kernel_iter == kernels.end() || is_xpu_unsupport)) {
21012108
VLOG(3) << "fluid missing XPU kernel: " << type_

paddle/fluid/imperative/prepared_operator.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,9 +255,9 @@ PreparedOp PrepareImpl(
255255
#if defined(PADDLE_WITH_XPU)
256256
bool is_xpu_unsupport =
257257
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
258-
!paddle::platform::is_xpu_support_op(op.Type(),
259-
expected_kernel_key) ||
260-
paddle::platform::is_in_xpu_black_list(op.Type());
258+
!paddle::platform::is_xpu_support_op(
259+
op.Type(),
260+
framework::TransToPhiDataType(expected_kernel_key.data_type_));
261261
#endif
262262

263263
#ifdef PADDLE_WITH_MLU
@@ -292,8 +292,10 @@ PreparedOp PrepareImpl(
292292
#ifdef PADDLE_WITH_XPU_KP
293293
if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
294294
bool use_xpu_kp_kernel_rt =
295-
FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op(
296-
op.Type(), expected_kernel_key);
295+
FLAGS_run_kp_kernel &&
296+
paddle::platform::is_xpu_support_op(
297+
op.Type(),
298+
framework::TransToPhiDataType(expected_kernel_key.data_type_));
297299
bool use_xpu_kp_kernel_debug =
298300
paddle::platform::is_in_xpu_kpwhite_list(op.Type());
299301
if (use_xpu_kp_kernel_rt) {
@@ -368,7 +370,9 @@ PreparedOp PrepareImpl(
368370
bool use_xpu_kp_kernel_rt =
369371
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
370372
FLAGS_run_kp_kernel &&
371-
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
373+
paddle::platform::is_xpu_support_op(
374+
op.Type(),
375+
framework::TransToPhiDataType(expected_kernel_key.data_type_));
372376
bool use_xpu_kp_kernel_debug =
373377
paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
374378
paddle::platform::is_in_xpu_kpwhite_list(op.Type());

0 commit comments

Comments
 (0)