Skip to content

Commit ff8b2cb

Browse files
authored
[Kernel Selection] Simplify kernel selection process in phi, reduce search number to half (#47771)
* simplify SelectKernelOrThrowError function in phi * opt kernel_selection process * polish code, fix backend error
1 parent c2e77ba commit ff8b2cb

File tree

1 file changed

+4
-20
lines changed

1 file changed

+4
-20
lines changed

paddle/phi/core/kernel_factory.cc

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,13 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
114114
kernels_.end(),
115115
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
116116

117-
KernelKey kernel_key = const_kernel_key;
117+
KernelKey kernel_key = KernelKey(const_kernel_key.backend(),
118+
phi::DataLayout::ALL_LAYOUT,
119+
const_kernel_key.dtype());
118120
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
119121
if (kernel_key.backend() == Backend::GPUDNN) {
120122
auto kernel_iter = iter->second.find(
121-
{Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()});
122-
if (kernel_iter == iter->second.end() &&
123-
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
124-
kernel_iter = iter->second.find(
125-
{Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()});
126-
}
123+
{Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()});
127124
if (kernel_iter != iter->second.end()) {
128125
return {kernel_iter->second, false};
129126
}
@@ -132,13 +129,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
132129
}
133130
#endif
134131
auto kernel_iter = iter->second.find(kernel_key);
135-
// TODO(chenweihang): polish refind impl here
136-
if (kernel_iter == iter->second.end() &&
137-
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
138-
phi::KernelKey any_layout_kernel_key(
139-
kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
140-
kernel_iter = iter->second.find(any_layout_kernel_key);
141-
}
142132

143133
PADDLE_ENFORCE_NE(
144134
kernel_iter == iter->second.end() && kernel_key.backend() == Backend::CPU,
@@ -162,12 +152,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
162152
phi::KernelKey cpu_kernel_key(
163153
phi::Backend::CPU, kernel_key.layout(), kernel_key.dtype());
164154
kernel_iter = iter->second.find(cpu_kernel_key);
165-
if (kernel_iter == iter->second.end() &&
166-
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
167-
phi::KernelKey any_layout_kernel_key(
168-
phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
169-
kernel_iter = iter->second.find(any_layout_kernel_key);
170-
}
171155

172156
PADDLE_ENFORCE_NE(
173157
kernel_iter,

0 commit comments

Comments
 (0)