Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
"murmurhash>=1.0.2,<1.1.0",
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"blis>=0.4.0,<0.8.0",
"blis>=0.9.0,<0.10.0",
"numpy>=1.15.0",
]
build-backend = "setuptools.build_meta"
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
murmurhash>=1.0.2,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
blis>=0.4.0,<0.8.0
blis>=0.9.0,<0.10.0
srsly>=2.4.0,<3.0.0
wasabi>=0.8.1,<1.1.0
catalogue>=2.0.4,<2.1.0
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=1.0.2,<1.1.0
blis>=0.4.0,<0.8.0
blis>=0.9.0,<0.10.0
install_requires =
# Explosion-provided dependencies
blis>=0.4.0,<0.8.0
blis>=0.9.0,<0.10.0
murmurhash>=1.0.2,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

PACKAGES = find_packages()
MOD_NAMES = [
"thinc.backends.cblas",
"thinc.backends.linalg",
"thinc.backends.numpy_ops",
"thinc.extra.search",
Expand Down
24 changes: 24 additions & 0 deletions thinc/backends/cblas.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from libcpp.memory cimport shared_ptr


ctypedef void (*sgemm_ptr)(bint transA, bint transB, int M, int N, int K,
float alpha, const float* A, int lda, const float *B,
int ldb, float beta, float* C, int ldc) nogil


ctypedef void (*saxpy_ptr)(int N, float alpha, const float* X, int incX,
float *Y, int incY) nogil


# Forward-declaration of the BlasFuncs struct. This struct must be opaque, so
# that consumers of the CBlas class cannot become dependent on its size or
# ordering.
cdef struct BlasFuncs


cdef class CBlas:
cdef shared_ptr[BlasFuncs] ptr
cdef saxpy_ptr saxpy(self) nogil
cdef sgemm_ptr sgemm(self) nogil
cdef void set_saxpy(self, saxpy_ptr saxpy) nogil
cdef void set_sgemm(self, sgemm_ptr sgemm) nogil
32 changes: 32 additions & 0 deletions thinc/backends/cblas.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
cimport blis.cy
from cython.operator cimport dereference as deref
from libcpp.memory cimport make_shared


cdef struct BlasFuncs:
saxpy_ptr saxpy
sgemm_ptr sgemm


cdef class CBlas:
__slots__ = []

def __init__(self):
"""Construct a CBlas instance set to use BLIS implementations of the
supported BLAS functions."""
cdef BlasFuncs funcs
funcs.saxpy = blis.cy.saxpy
funcs.sgemm = blis.cy.sgemm
self.ptr = make_shared[BlasFuncs](funcs)

cdef saxpy_ptr saxpy(self) nogil:
return deref(self.ptr).saxpy

cdef sgemm_ptr sgemm(self) nogil:
return deref(self.ptr).sgemm

cdef void set_saxpy(self, saxpy_ptr saxpy) nogil:
deref(self.ptr).saxpy = saxpy

cdef void set_sgemm(self, sgemm_ptr sgemm) nogil:
deref(self.ptr).sgemm = sgemm
7 changes: 7 additions & 0 deletions thinc/backends/numpy_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ cimport blis.cy
from .. import registry
from ..util import copy_array, get_array_module
from ..types import DeviceTypes, DTypes, Shape, ArrayXd
from .cblas cimport CBlas
from .linalg cimport VecVec, Vec
from .ops import Ops

Expand All @@ -30,6 +31,9 @@ except ImportError:
has_blis = False


cblas = CBlas()


ctypedef float weight_t


Expand Down Expand Up @@ -82,6 +86,9 @@ class NumpyOps(Ops):
else:
return self.xp.empty(shape, dtype=dtype)

def cblas(self) -> CBlas:
return cblas

def gemm(self, np.ndarray x, np.ndarray y, *, np.ndarray out=None, trans1=False, trans2=False):
if x.ndim != 2:
raise ValueError(f"Provided 'x' array should be 2-dimensional, but found {x.ndim} dimension(s).")
Expand Down
6 changes: 6 additions & 0 deletions thinc/backends/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..types import DeviceTypes, Generator, Padded, Batchable, SizedGenerator
from ..util import get_array_module, is_xp_array, to_numpy

from .cblas import CBlas

ArrayT = TypeVar("ArrayT", bound=ArrayXd)
FloatsT = TypeVar("FloatsT", bound=_Floats)
Expand All @@ -31,6 +32,11 @@ def __init__(
self.device_type = device_type
self.device_id = device_id

def cblas(self) -> CBlas:
"""Return C BLAS function table."""
err = f"{type(self).__name__} does not provide C BLAS functions"
raise NotImplementedError(err)

def to_numpy(self, data, *, byte_order=None): # pragma: no cover
if isinstance(data, numpy.ndarray):
if byte_order:
Expand Down
20 changes: 20 additions & 0 deletions website/docs/api-backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,26 @@ the inputs and outputs.
| `zeros` | <tt>bool</tt> | Fill the array with zeros (default: `True`). |
| **RETURNS** | <tt>ArrayXd</tt> | An array of the correct shape and data type. |

### Ops.cblas {#cblas tag="method"}

<inline-list>

- **default:** <i name="no"></i>
- **numpy:** <i name="yes"></i>
- **cupy:** <i name="no"></i>

</inline-list>

Get a table of C BLAS functions usable in Cython `cdef nogil` functions. This
method does not take any arguments.

<infobox variant="warning">

This method is only supported by `NumpyOps`. A `NotImplementedError` exception
is raised when calling this method on `Ops` or `CupyOps`.

</infobox>

### Ops.to_numpy {#to_numpy tag="method"}

<inline-list>
Expand Down