@@ -114,16 +114,13 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
114114 kernels_.end (),
115115 phi::errors::NotFound (" The kernel `%s` is not registered." , kernel_name));
116116
117- KernelKey kernel_key = const_kernel_key;
117+ KernelKey kernel_key = KernelKey (const_kernel_key.backend (),
118+ phi::DataLayout::ALL_LAYOUT,
119+ const_kernel_key.dtype ());
118120#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
119121 if (kernel_key.backend () == Backend::GPUDNN) {
120122 auto kernel_iter = iter->second .find (
121- {Backend::GPUDNN, kernel_key.layout (), kernel_key.dtype ()});
122- if (kernel_iter == iter->second .end () &&
123- kernel_key.layout () != phi::DataLayout::ALL_LAYOUT) {
124- kernel_iter = iter->second .find (
125- {Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype ()});
126- }
123+ {Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype ()});
127124 if (kernel_iter != iter->second .end ()) {
128125 return {kernel_iter->second , false };
129126 }
@@ -132,13 +129,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
132129 }
133130#endif
134131 auto kernel_iter = iter->second .find (kernel_key);
135- // TODO(chenweihang): polish refind impl here
136- if (kernel_iter == iter->second .end () &&
137- kernel_key.layout () != phi::DataLayout::ALL_LAYOUT) {
138- phi::KernelKey any_layout_kernel_key (
139- kernel_key.backend (), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype ());
140- kernel_iter = iter->second .find (any_layout_kernel_key);
141- }
142132
143133 PADDLE_ENFORCE_NE (
144134 kernel_iter == iter->second .end () && kernel_key.backend () == Backend::CPU,
@@ -162,12 +152,6 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
162152 phi::KernelKey cpu_kernel_key (
163153 phi::Backend::CPU, kernel_key.layout (), kernel_key.dtype ());
164154 kernel_iter = iter->second .find (cpu_kernel_key);
165- if (kernel_iter == iter->second .end () &&
166- kernel_key.layout () != phi::DataLayout::ALL_LAYOUT) {
167- phi::KernelKey any_layout_kernel_key (
168- phi::Backend::CPU, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype ());
169- kernel_iter = iter->second .find (any_layout_kernel_key);
170- }
171155
172156 PADDLE_ENFORCE_NE (
173157 kernel_iter,
0 commit comments