Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion paddle/framework/library_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,20 @@ inline std::string LibraryTypeToString(const LibraryType& library_type) {
case kCUDNN:
return "CUDNN";
default:
PADDLE_THROW("unknown LibraryType %d", library_type);
PADDLE_THROW("unknown LibraryType %d", static_cast<int>(library_type));
}
}

inline LibraryType StringToLibraryType(const char* ctype) {
std::string s(ctype);
if (s == std::string("PLAIN")) {
return LibraryType::kPlain;
} else if (s == std::string("MKLDNN")) {
return LibraryType::kMKLDNN;
} else if (s == std::string("CUDNN")) {
return LibraryType::kCUDNN;
} else {
PADDLE_THROW("Unknown LibraryType %s", s.c_str());
}
}

Expand Down
16 changes: 9 additions & 7 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,30 +79,31 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;

void operator()(const char* op_type) const {
void operator()(const char* op_type, const char* library_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
DataLayout::kAnyLayout, StringToLibraryType(library_type));
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);

constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func;
func(op_type);
func(op_type, library_type);
}
};

template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type) const {}
void operator()(const char* op_type, const char* library_type) const {}
};

// User can register many kernel in one place. The data type could be different.
template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar {
public:
explicit OpKernelRegistrar(const char* op_type) {
explicit OpKernelRegistrar(const char* op_type, const char* library_type) {
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
func(op_type);
func(op_type, library_type);
}
};

Expand Down Expand Up @@ -181,7 +182,8 @@ class OpKernelRegistrar : public Registrar {
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type, \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DEVICE_TYPE here is actually LIBRARY_TYPE. We can fix it later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is.

#DEVICE_TYPE); \
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
return 0; \
Expand Down
4 changes: 4 additions & 0 deletions paddle/operators/conv_cudnn_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,10 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle

REGISTER_OP_KERNEL(conv2d, CUDNN, paddle::platform::CUDAPlace,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);

REGISTER_OP_CUDA_KERNEL(conv2d_cudnn,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
Expand Down