Skip to content

Commit 35c1683

Browse files
authored
"refine kernel registrar" (#6998)
* "refine kernel registrar" * "refine registrar with multikey" * "fix register" * "refine multikernel register" * "fix CI" * "fix CI" * "fix registry" * "switch GPU to CUDA" * "add register macro test case" * "fix CI"
1 parent 95862a5 commit 35c1683

6 files changed

Lines changed: 122 additions & 11 deletions

File tree

paddle/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
3737
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
3838

3939
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
40-
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
40+
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
4141

4242
py_proto_compile(framework_py_proto SRCS framework.proto)
4343
# Generate an empty __init__.py to make framework_py_proto as a valid python module.

paddle/framework/library_type.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ namespace framework {
2020
// For more details about the design of LibraryType, Please refer to
2121
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library
2222

23-
enum class LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
23+
// Identifies the library a kernel implementation is built on. It is one of the
// components of the OpKernelType key used when registering kernels.
// StringToLibraryType maps the names "PLAIN", "MKLDNN" and "CUDNN" onto these
// values.
enum class LibraryType {
24+
kPlain = 0,  // also what StringToLibraryType returns for "CPU" and "CUDA"
25+
kMKLDNN = 1,
26+
kCUDNN = 2,
27+
};
2428

2529
inline std::string LibraryTypeToString(const LibraryType& library_type) {
2630
switch (library_type) {
@@ -31,7 +35,26 @@ inline std::string LibraryTypeToString(const LibraryType& library_type) {
3135
case LibraryType::kCUDNN:
3236
return "CUDNN";
3337
default:
34-
PADDLE_THROW("unknown LibraryType %d", library_type);
38+
PADDLE_THROW("unknown LibraryType %d", static_cast<int>(library_type));
39+
}
40+
}
41+
42+
// Converts a library name into the corresponding LibraryType.
//
// Accepted names are "PLAIN", "MKLDNN" and "CUDNN". To stay compatible with
// the kernel register macro, which forwards its device-type token as the
// library name, "CPU" and "CUDA" are accepted as well and are treated as the
// same library type as "PLAIN".
//
// Raises via PADDLE_THROW when `ctype` is null or is not one of the names
// listed above.
inline LibraryType StringToLibraryType(const char* ctype) {
  // Building a std::string from a null pointer is undefined behavior, so
  // report it as an error instead.
  if (ctype == nullptr) {
    PADDLE_THROW("Unknown LibraryType: null string");
  }

  const std::string s(ctype);
  // CPU, CUDA and PLAIN are the same library type.
  if (s == "PLAIN" || s == "CPU" || s == "CUDA") {
    return LibraryType::kPlain;
  }
  if (s == "MKLDNN") {
    return LibraryType::kMKLDNN;
  }
  if (s == "CUDNN") {
    return LibraryType::kCUDNN;
  }
  PADDLE_THROW("Unknown LibraryType %s", s.c_str());
}
3760

paddle/framework/op_kernel_type_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ TEST(OpKernelType, Hash) {
4848

4949
OpKernelType::Hash hasher;
5050
ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2));
51-
}
51+
}

paddle/framework/op_registry.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,30 +79,31 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
7979
using KERNEL_TYPE =
8080
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
8181

82-
void operator()(const char* op_type) const {
82+
void operator()(const char* op_type, const char* library_type) const {
8383
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
84-
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
84+
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
85+
DataLayout::kAnyLayout, StringToLibraryType(library_type));
8586
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
8687

8788
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
8889
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
8990
func;
90-
func(op_type);
91+
func(op_type, library_type);
9192
}
9293
};
9394

9495
// Terminating specialization of the registration recursion. It is selected
// when the second template argument is true, which the recursive step computes
// as `I + 1 == size`, i.e. once every kernel type in the pack has been
// registered. Its call operator deliberately does nothing.
template <typename PlaceType, size_t I, typename... KernelType>
9596
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
96-
void operator()(const char* op_type) const {}
97+
void operator()(const char* op_type, const char* library_type) const {}
9798
};
9899

99100
// Users can register many kernels in one place. Their data types may differ.
100101
// Registrar whose constructor registers every kernel type in `KernelType...`
// for one operator by running OpKernelRegistrarFunctor starting at index 0.
//
// op_type:      name of the operator the kernels belong to.
// library_type: library name string; the functor converts it with
//               StringToLibraryType when it builds the OpKernelType key.
template <typename PlaceType, typename... KernelType>
101102
class OpKernelRegistrar : public Registrar {
102103
public:
103-
explicit OpKernelRegistrar(const char* op_type) {
104+
explicit OpKernelRegistrar(const char* op_type, const char* library_type) {
104105
// `false, 0`: not finished yet, begin with the first kernel type.
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
105-
func(op_type);
106+
func(op_type, library_type);
106107
}
107108
};
108109

@@ -181,7 +182,8 @@ class OpKernelRegistrar : public Registrar {
181182
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
182183
"REGISTER_OP_KERNEL must be called in global namespace"); \
183184
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
184-
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
185+
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type, \
186+
#DEVICE_TYPE); \
185187
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
186188
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
187189
return 0; \

paddle/framework/op_registry_test.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
115
#include "paddle/framework/op_registry.h"
216
#include <gtest/gtest.h>
317

@@ -182,3 +196,71 @@ TEST(OperatorRegistrar, Test) {
182196
using namespace paddle::framework;
183197
OperatorRegistrar<CosineOpComplete, CosineOpProtoAndCheckerMaker> reg("cos");
184198
}
199+
200+
namespace paddle {
201+
namespace framework {
202+
203+
class OpKernelTestMaker : public OpProtoAndCheckerMaker {
204+
public:
205+
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
206+
: OpProtoAndCheckerMaker(proto, op_checker) {
207+
AddComment("NoGradOp, same input output. no Grad");
208+
}
209+
};
210+
211+
class OpWithKernelTest : public OperatorWithKernel {
212+
public:
213+
using OperatorWithKernel::OperatorWithKernel;
214+
215+
protected:
216+
void InferShape(InferShapeContext* ctx) const override {}
217+
218+
framework::OpKernelType GetActualKernelType(
219+
const framework::ExecutionContext& ctx) const override {
220+
return framework::OpKernelType(proto::DataType::FP32, ctx.device_context());
221+
}
222+
};
223+
224+
template <typename DeviceContext, typename T>
225+
class OpKernelTest : public paddle::framework::OpKernel<T> {
226+
public:
227+
void Compute(const paddle::framework::ExecutionContext& ctx) const {}
228+
};
229+
230+
} // namespace framework
231+
} // namespace paddle
232+
233+
// Register the test operator "op_with_kernel" together with its proto maker.
REGISTER_OP_WITHOUT_GRADIENT(op_with_kernel,
234+
paddle::framework::OpWithKernelTest,
235+
paddle::framework::OpKernelTestMaker);
236+
// Register a float kernel for the CPU device context.
REGISTER_OP_CPU_KERNEL(
237+
op_with_kernel,
238+
paddle::framework::OpKernelTest<paddle::platform::CPUDeviceContext, float>);
239+
240+
// Register a float kernel for the CUDA device context.
// NOTE(review): unlike the CUDA test case in this file, this registration is
// not wrapped in #ifdef PADDLE_WITH_CUDA -- confirm this target is only built
// when CUDA is enabled.
REGISTER_OP_CUDA_KERNEL(op_with_kernel,
241+
paddle::framework::OpKernelTest<
242+
paddle::platform::CUDADeviceContext, float>);
243+
244+
// Builds the "op_with_kernel" operator from an OpDesc and runs it on a
// CPUPlace, exercising the kernel registered through REGISTER_OP_CPU_KERNEL.
TEST(OperatorRegistrar, CPU) {
  namespace fw = paddle::framework;

  fw::proto::OpDesc desc;
  desc.set_type("op_with_kernel");

  fw::Scope run_scope;
  paddle::platform::CPUPlace place;

  auto kernel_op = fw::OpRegistry::CreateOp(desc);
  kernel_op->Run(run_scope, place);
}
254+
255+
#ifdef PADDLE_WITH_CUDA
// Builds the "op_with_kernel" operator from an OpDesc and runs it on CUDA
// device 0, exercising the kernel registered through REGISTER_OP_CUDA_KERNEL.
TEST(OperatorRegistrar, CUDA) {
  namespace fw = paddle::framework;

  fw::proto::OpDesc desc;
  desc.set_type("op_with_kernel");

  fw::Scope run_scope;
  paddle::platform::CUDAPlace place(0);

  auto kernel_op = fw::OpRegistry::CreateOp(desc);
  kernel_op->Run(run_scope, place);
}
#endif

paddle/operators/conv_cudnn_op.cu.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
315315
} // namespace operators
316316
} // namespace paddle
317317

318+
// Register the cuDNN convolution kernels (float and double) for the op type
// conv2d on CUDAPlace, tagged CUDNN.
// NOTE(review): presumably the CUDNN token is stringified by the macro and
// ends up as LibraryType::kCUDNN in the kernel key -- confirm against the
// REGISTER_OP_KERNEL definition.
REGISTER_OP_KERNEL(conv2d, CUDNN, paddle::platform::CUDAPlace,
319+
paddle::operators::CudnnConvOpKernel<float>,
320+
paddle::operators::CudnnConvOpKernel<double>);
321+
318322
// Registration of the same kernels under the separate op type conv2d_cudnn.
REGISTER_OP_CUDA_KERNEL(conv2d_cudnn,
319323
paddle::operators::CudnnConvOpKernel<float>,
320324
paddle::operators::CudnnConvOpKernel<double>);

0 commit comments

Comments
 (0)