@@ -22,6 +22,7 @@ import io
2222
2323import numpy as np
2424
25+ from cpython.buffer cimport PyBUF_FULL_RO, PyBuffer_Release, PyObject_GetBuffer
2526from cpython.object cimport PyObject
2627from cython.operator cimport dereference as deref
2728from libc.stddef cimport size_t
@@ -47,10 +48,6 @@ from pylibraft.common.optional cimport make_optional, optional
4748from pylibraft.common import DeviceResources
4849
4950
50- cdef extern from " Python.h" :
51- Py_buffer* PyMemoryView_GET_BUFFER(PyObject* mview)
52-
53-
5451def run_roundtrip_test_for_mdspan (X , fortran_order = False ):
5552 if not isinstance (X, np.ndarray) or len (X.shape) != 2 :
5653 raise ValueError (" Please call this function with a NumPy array with"
@@ -59,6 +56,9 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
5956 cdef device_resources * handle_ = \
6057 < device_resources * > < size_t> handle.getHandle()
6158 cdef ostringstream oss
59+ cdef Py_buffer buf
60+ PyObject_GetBuffer(X, & buf, PyBUF_FULL_RO)
61+ cdef uintptr_t buf_ptr = < uintptr_t> buf.buf
6262 if X.dtype == np.float32:
6363 if fortran_order:
6464 serialize_mdspan[float , matrix_extent[size_t], col_major](
@@ -67,8 +67,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
6767 < const host_mdspan[float , matrix_extent[size_t],
6868 col_major] & >
6969 make_host_matrix_view[float , size_t, col_major](
70- < float * >< uintptr_t> PyMemoryView_GET_BUFFER(
71- < PyObject * > X.data).buf,
70+ < float * > buf_ptr,
7271 X.shape[0 ], X.shape[1 ]))
7372 else :
7473 serialize_mdspan[float , matrix_extent[size_t], row_major](
@@ -77,8 +76,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
7776 < const host_mdspan[float , matrix_extent[size_t],
7877 row_major]& >
7978 make_host_matrix_view[float , size_t, row_major](
80- < float * >< uintptr_t> PyMemoryView_GET_BUFFER(
81- < PyObject * > X.data).buf,
79+ < float * > buf_ptr,
8280 X.shape[0 ], X.shape[1 ]))
8381 elif X.dtype == np.float64:
8482 if fortran_order:
@@ -88,8 +86,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
8886 < const host_mdspan[double , matrix_extent[size_t],
8987 col_major]& >
9088 make_host_matrix_view[double , size_t, col_major](
91- < double * >< uintptr_t> PyMemoryView_GET_BUFFER(
92- < PyObject * > X.data).buf,
89+ < double * > buf_ptr,
9390 X.shape[0 ], X.shape[1 ]))
9491 else :
9592 serialize_mdspan[double , matrix_extent[size_t], row_major](
@@ -98,8 +95,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
9895 < const host_mdspan[double , matrix_extent[size_t],
9996 row_major]& >
10097 make_host_matrix_view[double , size_t, row_major](
101- < double * >< uintptr_t> PyMemoryView_GET_BUFFER(
102- < PyObject * > X.data).buf,
98+ < double * > buf_ptr,
10399 X.shape[0 ], X.shape[1 ]))
104100 elif X.dtype == np.int32:
105101 if fortran_order:
@@ -109,8 +105,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
109105 < const host_mdspan[int32_t, matrix_extent[size_t],
110106 col_major]& >
111107 make_host_matrix_view[int32_t, size_t, col_major](
112- < int32_t * >< uintptr_t> PyMemoryView_GET_BUFFER(
113- < PyObject * > X.data).buf,
108+ < int32_t * > buf_ptr,
114109 X.shape[0 ], X.shape[1 ]))
115110 else :
116111 serialize_mdspan[int32_t, matrix_extent[size_t], row_major](
@@ -119,8 +114,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
119114 < const host_mdspan[int32_t, matrix_extent[size_t],
120115 row_major]& >
121116 make_host_matrix_view[int32_t, size_t, row_major](
122- < int32_t * >< uintptr_t> PyMemoryView_GET_BUFFER(
123- < PyObject * > X.data).buf,
117+ < int32_t * > buf_ptr,
124118 X.shape[0 ], X.shape[1 ]))
125119 elif X.dtype == np.uint32:
126120 if fortran_order:
@@ -130,8 +124,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
130124 < const host_mdspan[uint32_t, matrix_extent[size_t],
131125 col_major]& >
132126 make_host_matrix_view[uint32_t, size_t, col_major](
133- < uint32_t * >< uintptr_t> PyMemoryView_GET_BUFFER(
134- < PyObject * > X.data).buf,
127+ < uint32_t * > buf_ptr,
135128 X.shape[0 ], X.shape[1 ]))
136129 else :
137130 serialize_mdspan[uint32_t, matrix_extent[size_t], row_major](
@@ -140,11 +133,12 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
140133 < const host_mdspan[uint32_t, matrix_extent[size_t],
141134 row_major]& >
142135 make_host_matrix_view[uint32_t, size_t, row_major](
143- < uint32_t * >< uintptr_t> PyMemoryView_GET_BUFFER(
144- < PyObject * > X.data).buf,
136+ < uint32_t * > buf_ptr,
145137 X.shape[0 ], X.shape[1 ]))
146138 else :
139+ PyBuffer_Release(& buf)
147140 raise NotImplementedError ()
141+ PyBuffer_Release(& buf)
148142 f = io.BytesIO(oss.str())
149143 X2 = np.load(f)
150144 assert np.all(X.shape == X2.shape)
0 commit comments