namespace py = pybind11;
37+ namespace pybind11 {
38+ namespace detail {
39+
40+ // Note: use same enum number of float16 in numpy.
41+ // import numpy as np
42+ // print np.dtype(np.float16).num # 23
43+ constexpr int NPY_FLOAT16_ = 23 ;
44+ constexpr int NPY_UINT16_ = 4 ;
45+
46+ // Note: Since float16 is not a builtin type in C++, we register
47+ // paddle::platform::float16 as numpy.float16.
48+ // Ref: https://github.com/pybind/pybind11/issues/1776
49+ template <>
50+ struct npy_format_descriptor <paddle_infer::float16> {
51+ static py::dtype dtype () {
52+ handle ptr = npy_api::get ().PyArray_DescrFromType_ (NPY_FLOAT16_);
53+ return reinterpret_borrow<py::dtype>(ptr);
54+ }
55+ static std::string format () {
56+ // Note: "e" represents float16.
57+ // Details at:
58+ // https://docs.python.org/3/library/struct.html#format-characters.
59+ return " e" ;
60+ }
61+ static constexpr auto name = _(" float16" );
62+ };
63+
64+ } // namespace detail
65+ } // namespace pybind11
66+
namespace paddle {
namespace pybind {
using paddle::AnalysisPredictor;
@@ -126,6 +156,9 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
126156 case PaddleDType::UINT8:
127157 dt = py::dtype::of<uint8_t >();
128158 break ;
159+ case PaddleDType::FLOAT16:
160+ dt = py::dtype::of<paddle_infer::float16>();
161+ break ;
129162 default :
130163 PADDLE_THROW (platform::errors::Unimplemented (
131164 " Unsupported data type. Now only supports INT32, INT64, UINT8 and "
@@ -196,6 +229,10 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
196229 case PaddleDType::FLOAT32:
197230 tensor.copy_to_cpu <float >(static_cast <float *>(array.mutable_data ()));
198231 break ;
232+ case PaddleDType::FLOAT16:
233+ tensor.copy_to_cpu <paddle::platform::float16>(
234+ static_cast <paddle::platform::float16 *>(array.mutable_data ()));
235+ break ;
199236 case PaddleDType::UINT8:
200237 tensor.copy_to_cpu <uint8_t >(static_cast <uint8_t *>(array.mutable_data ()));
201238 break ;
@@ -226,6 +263,10 @@ py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) { // NOLINT
226263 case PaddleDType::FLOAT32:
227264 tensor.CopyToCpu <float >(static_cast <float *>(array.mutable_data ()));
228265 break ;
266+ case PaddleDType::FLOAT16:
267+ tensor.CopyToCpu <paddle::platform::float16>(
268+ static_cast <paddle::platform::float16 *>(array.mutable_data ()));
269+ break ;
229270 case PaddleDType::UINT8:
230271 tensor.CopyToCpu (static_cast <uint8_t *>(array.mutable_data ()));
231272 break ;
@@ -642,6 +683,7 @@ void BindZeroCopyTensor(py::module *m) {
642683 .def (" copy_from_cpu" , &ZeroCopyTensorCreate<int32_t >)
643684 .def (" copy_from_cpu" , &ZeroCopyTensorCreate<int64_t >)
644685 .def (" copy_from_cpu" , &ZeroCopyTensorCreate<float >)
686+ .def (" copy_from_cpu" , &ZeroCopyTensorCreate<paddle_infer::float16>)
645687 .def (" copy_to_cpu" , &ZeroCopyTensorToNumpy)
646688 .def (" shape" , &ZeroCopyTensor::shape)
647689 .def (" set_lod" , &ZeroCopyTensor::SetLoD)
@@ -655,6 +697,7 @@ void BindPaddleInferTensor(py::module *m) {
655697 .def (" copy_from_cpu" , &PaddleInferTensorCreate<int32_t >)
656698 .def (" copy_from_cpu" , &PaddleInferTensorCreate<int64_t >)
657699 .def (" copy_from_cpu" , &PaddleInferTensorCreate<float >)
700+ .def (" copy_from_cpu" , &PaddleInferTensorCreate<paddle_infer::float16>)
658701 .def (" copy_to_cpu" , &PaddleInferTensorToNumpy)
659702 .def (" shape" , &paddle_infer::Tensor::shape)
660703 .def (" set_lod" , &paddle_infer::Tensor::SetLoD)