Skip to content

Commit 09c5bc7

Browse files
vyasrabc99lr
authored andcommitted
Only use functions in the limited API (rapidsai#2282)
This PR removes usage of the only method in raft's Cython that is not part of the Python limited API. Contributes to rapidsai/build-planning#42 Authors: - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: rapidsai#2282
1 parent 723e19a commit 09c5bc7

File tree

2 files changed

+15
-33
lines changed

2 files changed

+15
-33
lines changed

python/pylibraft/pylibraft/common/mdspan.pyx

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import io
2222

2323
import numpy as np
2424

25+
from cpython.buffer cimport PyBUF_FULL_RO, PyBuffer_Release, PyObject_GetBuffer
2526
from cpython.object cimport PyObject
2627
from cython.operator cimport dereference as deref
2728
from libc.stddef cimport size_t
@@ -47,10 +48,6 @@ from pylibraft.common.optional cimport make_optional, optional
4748
from pylibraft.common import DeviceResources
4849

4950

50-
cdef extern from "Python.h":
51-
Py_buffer* PyMemoryView_GET_BUFFER(PyObject* mview)
52-
53-
5451
def run_roundtrip_test_for_mdspan(X, fortran_order=False):
5552
if not isinstance(X, np.ndarray) or len(X.shape) != 2:
5653
raise ValueError("Please call this function with a NumPy array with"
@@ -59,6 +56,9 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
5956
cdef device_resources * handle_ = \
6057
<device_resources *> <size_t> handle.getHandle()
6158
cdef ostringstream oss
59+
cdef Py_buffer buf
60+
PyObject_GetBuffer(X, &buf, PyBUF_FULL_RO)
61+
cdef uintptr_t buf_ptr = <uintptr_t>buf.buf
6262
if X.dtype == np.float32:
6363
if fortran_order:
6464
serialize_mdspan[float, matrix_extent[size_t], col_major](
@@ -67,8 +67,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
6767
<const host_mdspan[float, matrix_extent[size_t],
6868
col_major] &>
6969
make_host_matrix_view[float, size_t, col_major](
70-
<float *><uintptr_t>PyMemoryView_GET_BUFFER(
71-
<PyObject *> X.data).buf,
70+
<float *>buf_ptr,
7271
X.shape[0], X.shape[1]))
7372
else:
7473
serialize_mdspan[float, matrix_extent[size_t], row_major](
@@ -77,8 +76,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
7776
<const host_mdspan[float, matrix_extent[size_t],
7877
row_major]&>
7978
make_host_matrix_view[float, size_t, row_major](
80-
<float *><uintptr_t>PyMemoryView_GET_BUFFER(
81-
<PyObject *> X.data).buf,
79+
<float *>buf_ptr,
8280
X.shape[0], X.shape[1]))
8381
elif X.dtype == np.float64:
8482
if fortran_order:
@@ -88,8 +86,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
8886
<const host_mdspan[double, matrix_extent[size_t],
8987
col_major]&>
9088
make_host_matrix_view[double, size_t, col_major](
91-
<double *><uintptr_t>PyMemoryView_GET_BUFFER(
92-
<PyObject *> X.data).buf,
89+
<double *>buf_ptr,
9390
X.shape[0], X.shape[1]))
9491
else:
9592
serialize_mdspan[double, matrix_extent[size_t], row_major](
@@ -98,8 +95,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
9895
<const host_mdspan[double, matrix_extent[size_t],
9996
row_major]&>
10097
make_host_matrix_view[double, size_t, row_major](
101-
<double *><uintptr_t>PyMemoryView_GET_BUFFER(
102-
<PyObject *> X.data).buf,
98+
<double *>buf_ptr,
10399
X.shape[0], X.shape[1]))
104100
elif X.dtype == np.int32:
105101
if fortran_order:
@@ -109,8 +105,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
109105
<const host_mdspan[int32_t, matrix_extent[size_t],
110106
col_major]&>
111107
make_host_matrix_view[int32_t, size_t, col_major](
112-
<int32_t *><uintptr_t>PyMemoryView_GET_BUFFER(
113-
<PyObject *> X.data).buf,
108+
<int32_t *>buf_ptr,
114109
X.shape[0], X.shape[1]))
115110
else:
116111
serialize_mdspan[int32_t, matrix_extent[size_t], row_major](
@@ -119,8 +114,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
119114
<const host_mdspan[int32_t, matrix_extent[size_t],
120115
row_major]&>
121116
make_host_matrix_view[int32_t, size_t, row_major](
122-
<int32_t *><uintptr_t>PyMemoryView_GET_BUFFER(
123-
<PyObject *> X.data).buf,
117+
<int32_t *>buf_ptr,
124118
X.shape[0], X.shape[1]))
125119
elif X.dtype == np.uint32:
126120
if fortran_order:
@@ -130,8 +124,7 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
130124
<const host_mdspan[uint32_t, matrix_extent[size_t],
131125
col_major]&>
132126
make_host_matrix_view[uint32_t, size_t, col_major](
133-
<uint32_t *><uintptr_t>PyMemoryView_GET_BUFFER(
134-
<PyObject *> X.data).buf,
127+
<uint32_t *>buf_ptr,
135128
X.shape[0], X.shape[1]))
136129
else:
137130
serialize_mdspan[uint32_t, matrix_extent[size_t], row_major](
@@ -140,11 +133,12 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False):
140133
<const host_mdspan[uint32_t, matrix_extent[size_t],
141134
row_major]&>
142135
make_host_matrix_view[uint32_t, size_t, row_major](
143-
<uint32_t *><uintptr_t>PyMemoryView_GET_BUFFER(
144-
<PyObject *> X.data).buf,
136+
<uint32_t *>buf_ptr,
145137
X.shape[0], X.shape[1]))
146138
else:
139+
PyBuffer_Release(&buf)
147140
raise NotImplementedError()
141+
PyBuffer_Release(&buf)
148142
f = io.BytesIO(oss.str())
149143
X2 = np.load(f)
150144
assert np.all(X.shape == X2.shape)

python/pylibraft/pylibraft/neighbors/cpp/hnsw.pxd

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,7 @@ cdef extern from "raft_runtime/neighbors/hnsw.hpp" \
7575
host_matrix_view[uint64_t, int64_t, row_major] neighbors,
7676
host_matrix_view[float, int64_t, row_major] distances) except +
7777

78-
cdef unique_ptr[index[float]] deserialize_file[float](
79-
const device_resources& handle,
80-
const string& filename,
81-
int dim,
82-
DistanceType metric) except +
83-
84-
cdef unique_ptr[index[int8_t]] deserialize_file[int8_t](
85-
const device_resources& handle,
86-
const string& filename,
87-
int dim,
88-
DistanceType metric) except +
89-
90-
cdef unique_ptr[index[uint8_t]] deserialize_file[uint8_t](
78+
cdef unique_ptr[index[T]] deserialize_file[T](
9179
const device_resources& handle,
9280
const string& filename,
9381
int dim,

0 commit comments

Comments
 (0)