Skip to content

Commit 9cb3aa8

Browse files
committed
simplify SelectKernelOrThrowError function in phi
1 parent 908a381 commit 9cb3aa8

File tree

2 files changed

+7
-22
lines changed

2 files changed

+7
-22
lines changed

paddle/phi/api/lib/kernel_dispatch.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,13 @@ struct KernelKeySet {
5656

5757
// TODO(chenweihang): iterate all kernelkey for kernel selection
5858
phi::KernelKey GetHighestPriorityKernelKey() {
59-
return phi::KernelKey(static_cast<Backend>(32 - detail::CountLeadingZeros(
60-
backend_set.bitset())),
61-
layout,
62-
dtype);
59+
Backend backend_key = static_cast<Backend>(
60+
32 - detail::CountLeadingZeros(backend_set.bitset()));
61+
DataLayout layout_key = layout;
62+
if (backend_key != Backend::ONEDNN && layout_key != DataLayout::ONEDNN) {
63+
layout_key = DataLayout::ALL_LAYOUT;
64+
}
65+
return phi::KernelKey(backend_key, layout_key, dtype);
6366
}
6467
};
6568

paddle/phi/core/kernel_factory.cc

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
116116
if (use_gpudnn && kernel_key.backend() == Backend::GPU) {
117117
auto kernel_iter = iter->second.find(
118118
{Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()});
119-
if (kernel_iter == iter->second.end() &&
120-
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
121-
kernel_iter = iter->second.find(
122-
{Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()});
123-
}
124119
if (kernel_iter != iter->second.end()) {
125120
return {kernel_iter->second, false};
126121
}
@@ -129,13 +124,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
129124
}
130125
#endif
131126
auto kernel_iter = iter->second.find(kernel_key);
132-
// TODO(chenweihang): polish refind impl here
133-
if (kernel_iter == iter->second.end() &&
134-
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
135-
phi::KernelKey any_layout_kernel_key(
136-
kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
137-
kernel_iter = iter->second.find(any_layout_kernel_key);
138-
}
139127

140128
PADDLE_ENFORCE_NE(
141129
kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU,
@@ -157,12 +145,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
157145
phi::KernelKey cpu_kernel_key(
158146
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
159147
kernel_iter = iter->second.find(cpu_kernel_key);
160-
if (kernel_iter == iter->second.end() &&
161-
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
162-
phi::KernelKey any_layout_kernel_key(
163-
phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
164-
kernel_iter = iter->second.find(any_layout_kernel_key);
165-
}
166148

167149
PADDLE_ENFORCE_NE(
168150
kernel_iter,

0 commit comments

Comments
 (0)