From 005901ebcd62d3ad07c59f97c3ad87f4a3ff40ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 2 Sep 2022 09:58:06 +0200 Subject: [PATCH 1/8] Introduce `SparseLinear_v2` to fix indexing issues `SparseLinear` does not correctly index the gradient/weight matrix (#752). This change fixes the indexing, so that the full matrix is used. To retain compatibility with existing models that use `SparseLinear`, which works relatively well if there are not too many hash collisions, the fixed version is renamed to `SparseLinear_v2`. Thanks to @sriram7797 for reporting this issue! --- thinc/api.py | 2 +- thinc/layers/__init__.py | 2 +- thinc/layers/sparselinear.pyx | 47 ++++++++++++++++++++------- thinc/tests/layers/test_layers_api.py | 1 + website/docs/api-layers.md | 36 ++++++++++++++++++++ 5 files changed, 75 insertions(+), 13 deletions(-) diff --git a/thinc/api.py b/thinc/api.py index 8c5807347..c2509eb3f 100644 --- a/thinc/api.py +++ b/thinc/api.py @@ -25,7 +25,7 @@ from .layers import Dropout, Embed, expand_window, HashEmbed, LayerNorm, Linear from .layers import Maxout, Mish, MultiSoftmax, Relu, softmax_activation, Softmax, LSTM from .layers import CauchySimilarity, ParametricAttention, Logistic -from .layers import resizable, sigmoid_activation, Sigmoid, SparseLinear +from .layers import resizable, sigmoid_activation, Sigmoid, SparseLinear, SparseLinear_v2 from .layers import ClippedLinear, ReluK, HardTanh, HardSigmoid from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM diff --git a/thinc/layers/__init__.py b/thinc/layers/__init__.py index 19fee1329..7f14228db 100644 --- a/thinc/layers/__init__.py +++ b/thinc/layers/__init__.py @@ -26,7 +26,7 @@ from .sigmoid import Sigmoid from .softmax_activation import softmax_activation from .softmax import Softmax, Softmax_v2 -from .sparselinear import SparseLinear +from .sparselinear import SparseLinear, SparseLinear_v2 from .tensorflowwrapper import TensorFlowWrapper, keras_subclass from .mxnetwrapper import MXNetWrapper diff --git a/thinc/layers/sparselinear.pyx b/thinc/layers/sparselinear.pyx index 92c324bd5..ecc057b60 100644 --- a/thinc/layers/sparselinear.pyx +++ b/thinc/layers/sparselinear.pyx @@ -27,6 +27,22 @@ def SparseLinear(nO: Optional[int] = None, length: int = 2 ** 18): init=init, params={"W": None, "b": None}, dims={"nO": nO, "length": length}, + attrs={"invalid_indexing": True}, + ) + + +@cython.binding(True) +@registry.layers("SparseLinear.v2") +def SparseLinear_v2(nO: Optional[int] = None, length: int = 2 ** 18): + # NB: We can't have generic return type annotation if we want function to + # be bound (and inspectable): https://github.com/cython/cython/issues/2753 + return Model( + "sparse_linear", + forward, + init=init, + params={"W": None, "b": None}, + dims={"nO": nO, "length": length}, + attrs={"invalid_indexing": False}, ) @@ -70,11 +86,12 @@ def _begin_cpu_update(model, np.ndarray keys, np.ndarray values, np.ndarray leng cdef np.ndarray W = model.get_param("W") cdef np.ndarray b = model.get_param("b") cdef np.ndarray scores = model.ops.alloc((len(lengths), nO)) + cdef bint invalid_indexing = model.attrs["invalid_indexing"] scores += b set_scoresC(scores.data, keys.data, values.data, lengths.data, lengths.shape[0], nO, - W.data, length) + W.data, length, invalid_indexing) return scores, _finish_linear_update(model, keys, values, lengths) @@ -95,10 +112,10 @@ class _finish_linear_update: cdef np.ndarray keys 
= self.keys cdef np.ndarray values = self.values cdef np.ndarray lengths = self.lengths + cdef bint invalid_indexing = self.model.attrs["invalid_indexing"] set_gradientC(d_weights.data, keys.data, values.data, lengths.data, - lengths.shape[0], nO, - &d_scores[0,0], length) + lengths.shape[0], nO, &d_scores[0,0], length, invalid_indexing) cdef int i, j for i in range(d_scores.shape[0]): for j in range(d_scores.shape[1]): @@ -110,8 +127,8 @@ class _finish_linear_update: cdef void set_scoresC(float* scores, const uint64_t* keys, const float* values, const int32_t* lengths, - int batch_size, int nr_out, - const float* weights, int nr_weight) nogil: + int batch_size, int nr_out, const float* weights, int nr_weight, + bint invalid_indexing) nogil: cdef uint32_t idx1, idx2 cdef uint32_t hash1, hash2 for length in lengths[:batch_size]: @@ -122,8 +139,12 @@ cdef void set_scoresC(float* scores, idx2 = hash2 & (nr_weight-1) value = values[i] for clas in range(nr_out): - scores[clas] += weights[idx1 + clas] * value - scores[clas] += weights[idx2 + clas] * value + if invalid_indexing: + scores[clas] += weights[idx1 + clas] * value + scores[clas] += weights[idx2 + clas] * value + else: + scores[clas] += weights[(clas * nr_weight) + idx1] * value + scores[clas] += weights[(clas * nr_weight) + idx2] * value scores += nr_out keys += length values += length @@ -131,8 +152,8 @@ cdef void set_scoresC(float* scores, cdef void set_gradientC(float* d_weights, const uint64_t* keys, const float* values, const int32_t* lengths, - int batch_size, int nr_out, - const float* d_scores, int nr_weight) nogil: + int batch_size, int nr_out, const float* d_scores, int nr_weight, + bint invalid_indexing) nogil: cdef uint32_t idx1, idx2 cdef uint32_t hash1, hash2 for length in lengths[:batch_size]: @@ -143,8 +164,12 @@ cdef void set_gradientC(float* d_weights, idx2 = hash2 & (nr_weight-1) value = values[i] for clas in range(nr_out): - d_weights[idx1 + clas] += d_scores[clas] * value - d_weights[idx2 + clas] += d_scores[clas] * value + if invalid_indexing: + d_weights[idx1 + clas] += d_scores[clas] * value + d_weights[idx2 + clas] += d_scores[clas] * value + else: + d_weights[(clas * nr_weight) + idx1] += d_scores[clas] * value + d_weights[(clas * nr_weight) + idx2] += d_scores[clas] * value d_scores += nr_out keys += length values += length diff --git a/thinc/tests/layers/test_layers_api.py b/thinc/tests/layers/test_layers_api.py index 3ebeb470a..799cad009 100644 --- a/thinc/tests/layers/test_layers_api.py +++ b/thinc/tests/layers/test_layers_api.py @@ -128,6 +128,7 @@ def assert_data_match(Y, out_data): # ("CauchySimilarity.v1", {}, (array2d, array2d), array1d), ("ParametricAttention.v1", {}, ragged, ragged), ("SparseLinear.v1", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d), + ("SparseLinear.v2", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d), ("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint) # fmt: on ] diff --git a/website/docs/api-layers.md b/website/docs/api-layers.md index b1e72f7d7..80406f90b 100644 --- a/website/docs/api-layers.md +++ b/website/docs/api-layers.md @@ -802,6 +802,42 @@ length, describing the concatenated batch of input features and their values. The `lengths` array should have one entry per sequence in the batch, and the sum of the lengths should equal the length of the keys and values array. + + +`SparseLinear` should not be used for new models because it contains an indexing +bug. 
As a result, only a subset of the weights is used. Use +[`SparseLinear_v2`](#sparselinear_v2) instead. + + + +| Argument | Type | Description | +| ----------- | --------------------------------------------------------- | -------------------------------------------------------- | +| `nO` | Optional[int] | The size of the output vectors. | +| `length` | int | The size of the weights vector, to be tuned empirically. | +| **RETURNS** | Model[Tuple[ArrayXd, ArrayXd, ArrayXd], ArrayXd] | The created layer. | + +```python +https://github.com/explosion/thinc/blob/master/thinc/layers/sparselinear.pyx +``` + +### SparseLinear_v2 {#sparselinear_v2 tag="function"} + + + +- **Input:** Tuple[ArrayXd, ArrayXd, ArrayXd] +- **Output:** ArrayXd +- **Parameters:** W, + b, `length` int + + + +A sparse linear layer using the "hashing trick". Useful for tasks such as text +classification. Inputs to the layer should be a tuple of arrays +`(keys, values, lengths)`, where the `keys` and `values` are arrays of the same +length, describing the concatenated batch of input features and their values. +The `lengths` array should have one entry per sequence in the batch, and the sum +of the lengths should equal the length of the keys and values array. + | Argument | Type | Description | | ----------- | --------------------------------------------------------- | -------------------------------------------------------- | | `nO` | Optional[int] | The size of the output vectors. | | `length` | int | The size of the weights vector, to be tuned empirically. | | **RETURNS** | Model[Tuple[ArrayXd, ArrayXd, ArrayXd], ArrayXd] | The created layer. | From 1bb78b132a97e322cd5244a9c93dac5a4650abaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 2 Sep 2022 10:03:45 +0200 Subject: [PATCH 2/8] SparseLinear_v2: fix issue mapping murmur hashes to array The output of the MurmurHash function was mapped to array indices as follows: ``` idx = hash & (nr_weight-1) ``` This works well when `nr_weight` is a power of two. For instance, if we have 16 buckets: ``` idx = hash & 15 idx = hash & 0b1111 ``` However, this breaks down when the bucket count is not a power of two. For instance, if we have 15 buckets: ``` idx = hash & 14 idx = hash & 0b1110 ``` This would mask out all odd indices. We fix this by using the modulus instead. To preserve compatibility with existing models, this change is only added to `SparseLinear_v2`. 
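For illustration only (this snippet is not part of the patch), the effect can be reproduced with plain Python integers standing in for the hash values:

```
# Buckets reachable with the old mask vs. the modulus for a
# non-power-of-two bucket count.
nr_weight = 15
hashes = range(100_000)
masked = {h & (nr_weight - 1) for h in hashes}  # {0, 2, 4, ..., 14} -> 8 buckets
modulo = {h % nr_weight for h in hashes}        # {0, 1, 2, ..., 14} -> 15 buckets
print(len(masked), len(modulo))                 # 8 15
```

With the mask, almost half of the buckets can never be hit, which quietly shrinks the effective table and increases collisions.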
--- thinc/layers/sparselinear.pyx | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/thinc/layers/sparselinear.pyx b/thinc/layers/sparselinear.pyx index ecc057b60..a1dfa4d13 100644 --- a/thinc/layers/sparselinear.pyx +++ b/thinc/layers/sparselinear.pyx @@ -135,8 +135,12 @@ cdef void set_scoresC(float* scores, for i in range(length): hash1 = MurmurHash3_x86_32_uint64(keys[i], 0) hash2 = MurmurHash3_x86_32_uint64(keys[i], 1) - idx1 = hash1 & (nr_weight-1) - idx2 = hash2 & (nr_weight-1) + if invalid_indexing: + idx1 = hash1 & (nr_weight-1) + idx2 = hash2 & (nr_weight-1) + else: + idx1 = hash1 % nr_weight + idx2 = hash2 % nr_weight value = values[i] for clas in range(nr_out): if invalid_indexing: @@ -160,8 +164,12 @@ cdef void set_gradientC(float* d_weights, for i in range(length): hash1 = MurmurHash3_x86_32_uint64(keys[i], 0) hash2 = MurmurHash3_x86_32_uint64(keys[i], 1) - idx1 = hash1 & (nr_weight-1) - idx2 = hash2 & (nr_weight-1) + if invalid_indexing: + idx1 = hash1 & (nr_weight-1) + idx2 = hash2 & (nr_weight-1) + else: + idx1 = hash1 % nr_weight + idx2 = hash2 % nr_weight value = values[i] for clas in range(nr_out): if invalid_indexing: From 7aba3653fae3a9e004ec47bedf8f5a39539a486c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 2 Sep 2022 18:53:07 +0200 Subject: [PATCH 3/8] Rename `invalid_indexing` to `v1_indexing` --- thinc/layers/sparselinear.pyx | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/thinc/layers/sparselinear.pyx b/thinc/layers/sparselinear.pyx index a1dfa4d13..3082744c8 100644 --- a/thinc/layers/sparselinear.pyx +++ b/thinc/layers/sparselinear.pyx @@ -27,7 +27,7 @@ def SparseLinear(nO: Optional[int] = None, length: int = 2 ** 18): init=init, params={"W": None, "b": None}, dims={"nO": nO, "length": length}, - attrs={"invalid_indexing": True}, + attrs={"v1_indexing": True}, ) @@ -42,7 +42,7 @@ def SparseLinear_v2(nO: Optional[int] = None, length: int = 2 ** 18): init=init, params={"W": None, "b": None}, dims={"nO": nO, "length": length}, - attrs={"invalid_indexing": False}, + attrs={"v1_indexing": False}, ) @@ -86,12 +86,12 @@ def _begin_cpu_update(model, np.ndarray keys, np.ndarray values, np.ndarray leng cdef np.ndarray W = model.get_param("W") cdef np.ndarray b = model.get_param("b") cdef np.ndarray scores = model.ops.alloc((len(lengths), nO)) - cdef bint invalid_indexing = model.attrs["invalid_indexing"] + cdef bint v1_indexing = model.attrs["v1_indexing"] scores += b set_scoresC(scores.data, keys.data, values.data, lengths.data, lengths.shape[0], nO, - W.data, length, invalid_indexing) + W.data, length, v1_indexing) return scores, _finish_linear_update(model, keys, values, lengths) @@ -112,10 +112,10 @@ class _finish_linear_update: cdef np.ndarray keys = self.keys cdef np.ndarray values = self.values cdef np.ndarray lengths = self.lengths - cdef bint invalid_indexing = self.model.attrs["invalid_indexing"] + cdef bint v1_indexing = self.model.attrs["v1_indexing"] set_gradientC(d_weights.data, keys.data, values.data, lengths.data, - lengths.shape[0], nO, &d_scores[0,0], length, invalid_indexing) + lengths.shape[0], nO, &d_scores[0,0], length, v1_indexing) cdef int i, j for i in range(d_scores.shape[0]): for j in range(d_scores.shape[1]): @@ -128,14 +128,14 @@ class _finish_linear_update: cdef void set_scoresC(float* scores, const uint64_t* keys, const float* values, const int32_t* lengths, int batch_size, int nr_out, const float* weights, int nr_weight, - bint 
invalid_indexing) nogil: + bint v1_indexing) nogil: cdef uint32_t idx1, idx2 cdef uint32_t hash1, hash2 for length in lengths[:batch_size]: for i in range(length): hash1 = MurmurHash3_x86_32_uint64(keys[i], 0) hash2 = MurmurHash3_x86_32_uint64(keys[i], 1) - if invalid_indexing: + if v1_indexing: idx1 = hash1 & (nr_weight-1) idx2 = hash2 & (nr_weight-1) else: @@ -143,7 +143,7 @@ cdef void set_scoresC(float* scores, idx2 = hash2 % nr_weight value = values[i] for clas in range(nr_out): - if invalid_indexing: + if v1_indexing: scores[clas] += weights[idx1 + clas] * value scores[clas] += weights[idx2 + clas] * value else: @@ -157,14 +157,14 @@ cdef void set_scoresC(float* scores, cdef void set_gradientC(float* d_weights, const uint64_t* keys, const float* values, const int32_t* lengths, int batch_size, int nr_out, const float* d_scores, int nr_weight, - bint invalid_indexing) nogil: + bint v1_indexing) nogil: cdef uint32_t idx1, idx2 cdef uint32_t hash1, hash2 for length in lengths[:batch_size]: for i in range(length): hash1 = MurmurHash3_x86_32_uint64(keys[i], 0) hash2 = MurmurHash3_x86_32_uint64(keys[i], 1) - if invalid_indexing: + if v1_indexing: idx1 = hash1 & (nr_weight-1) idx2 = hash2 & (nr_weight-1) else: @@ -172,7 +172,7 @@ cdef void set_gradientC(float* d_weights, idx2 = hash2 % nr_weight value = values[i] for clas in range(nr_out): - if invalid_indexing: + if v1_indexing: d_weights[idx1 + clas] += d_scores[clas] * value d_weights[idx2 + clas] += d_scores[clas] * value else: From 1da488e1c56bc3d48eff6faa59d71a4c3b3e1692 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 2 Sep 2022 18:55:30 +0200 Subject: [PATCH 4/8] Add comment about v1 indexing --- thinc/layers/sparselinear.pyx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/thinc/layers/sparselinear.pyx b/thinc/layers/sparselinear.pyx index 3082744c8..b9a982f4b 100644 --- a/thinc/layers/sparselinear.pyx +++ b/thinc/layers/sparselinear.pyx @@ -125,6 +125,8 @@ class _finish_linear_update: return (self.keys, self.values, self.lengths) +# v1_indexing is invalid and only uses a subset of the weight matrix, v1 +# indexing is provided here for compatibility. See #752 for more information. cdef void set_scoresC(float* scores, const uint64_t* keys, const float* values, const int32_t* lengths, int batch_size, int nr_out, const float* weights, int nr_weight, @@ -154,6 +156,8 @@ cdef void set_scoresC(float* scores, values += length +# v1_indexing is invalid and only uses a subset of the weight matrix, v1 +# indexing is provided here for compatibility. See #752 for more information. 
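+# With v2 indexing, a bucket is addressed as (clas * nr_weight) + idx, so the
+# full nO-by-nr_weight weight/gradient matrix is used.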
cdef void set_gradientC(float* d_weights, const uint64_t* keys, const float* values, const int32_t* lengths, int batch_size, int nr_out, const float* d_scores, int nr_weight, From 1c813c44c4fce6ce4782474cbffdc23f14d4b83e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 2 Sep 2022 19:03:30 +0200 Subject: [PATCH 5/8] Fix incorrect merge fix --- thinc/tests/layers/test_layers_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thinc/tests/layers/test_layers_api.py b/thinc/tests/layers/test_layers_api.py index 38d399c19..873702e7d 100644 --- a/thinc/tests/layers/test_layers_api.py +++ b/thinc/tests/layers/test_layers_api.py @@ -129,7 +129,6 @@ def assert_data_match(Y, out_data): ("ParametricAttention.v1", {}, ragged, ragged), ("SparseLinear.v1", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d), ("SparseLinear.v2", {}, (numpy.asarray([1, 2, 3], dtype="uint64"), array1d, numpy.asarray([1, 1], dtype="i")), array2d), - ("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint) ("remap_ids.v1", {"dtype": "f"}, ["a", 1, 5.0], array2dint), ("remap_ids.v2", {"mapping_table": {}, "column": 1}, numpy.array([[1, 2, 3], [4, 5, 6]]).T, array2dint) # fmt: on From 682a1c0613660547157863a9daae3e9ebef2c168 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 14 Nov 2022 13:21:25 +0100 Subject: [PATCH 6/8] Add the `new` tag to the docs --- website/docs/api-layers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/docs/api-layers.md b/website/docs/api-layers.md index 3ce96970a..bf729c56a 100644 --- a/website/docs/api-layers.md +++ b/website/docs/api-layers.md @@ -820,7 +820,7 @@ bug. As a result, only a subset of the weights is used. Use https://github.com/explosion/thinc/blob/master/thinc/layers/sparselinear.pyx ``` -### SparseLinear_v2 {#sparselinear_v2 tag="function"} +### SparseLinear_v2 {#sparselinear_v2 tag="function" new="8.1.6"} From 4892c429747c669886ebcce787e8e6120d4ecfd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 14 Nov 2022 15:01:03 +0100 Subject: [PATCH 7/8] Check that the corrected hash function has the expected distribution --- thinc/tests/layers/test_sparse_linear.py | 35 +++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/thinc/tests/layers/test_sparse_linear.py b/thinc/tests/layers/test_sparse_linear.py index fe3337b30..87c5a3a75 100644 --- a/thinc/tests/layers/test_sparse_linear.py +++ b/thinc/tests/layers/test_sparse_linear.py @@ -1,6 +1,7 @@ +import math import numpy import pytest -from thinc.api import SGD, to_categorical, SparseLinear +from thinc.api import SGD, to_categorical, SparseLinear, SparseLinear_v2 @pytest.fixture @@ -42,3 +43,35 @@ def test_init(): assert scores.shape == (2, 3) d_feats = backprop(scores) assert len(d_feats) == 3 + + +def test_distribution(): + n_class = 10 + length = 2**18 + model = SparseLinear_v2(nO=n_class, length=length).initialize() + + ii64 = numpy.iinfo(numpy.uint64) + lengths = numpy.zeros((2,), dtype="int32") + + for p_nonzero in range(1, 12): + # Clear gradients from the previous iterarion. + model.set_grad("W", 0.0) + + n = 2**p_nonzero + keys = numpy.random.randint(ii64.min, ii64.max, size=(n,), dtype=numpy.uint64) + values = numpy.ones((n,), dtype="f") + lengths[0] = n // 2 + lengths[1] = n // 2 + + # Probability that a bit is set (2 because we use 2 hashes). 
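+        # Each of the n keys is hashed twice, giving 2 * n draws that land
+        # (approximately) uniformly in `length` buckets. A given bucket is
+        # missed by every draw with probability (1 - 1/length)**(2 * n), i.e.
+        # roughly exp(-2 * n / length), hence the expected non-zero rate: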
+ p_nonzero = 1 - math.exp(-2 * n / length) + + Y, backprop = model.begin_update((keys, values, lengths)) + backprop(numpy.ones_like(Y)) + + # Check that for each class we have the expected rate of non-zeros. + dW = model.get_grad("W").reshape(n_class, -1) + nonzero_empirical = numpy.count_nonzero(dW, axis=1) / dW.shape[1] + numpy.testing.assert_allclose( + nonzero_empirical, p_nonzero, rtol=1e-4, atol=1e-4 + ) From fa19f12dd6d7e85e71d4e05955a51a828ec201aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 14 Nov 2022 15:11:13 +0100 Subject: [PATCH 8/8] Symbol export fixes --- thinc/api.py | 2 +- thinc/layers/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/thinc/api.py b/thinc/api.py index 0fce9d646..328a9d107 100644 --- a/thinc/api.py +++ b/thinc/api.py @@ -91,7 +91,7 @@ "Dish", "HardSwish", "HardSwishMobilenet", "Swish", "Gelu", "PyTorchWrapper", "PyTorchRNNWrapper", "PyTorchLSTM", "TensorFlowWrapper", "keras_subclass", "MXNetWrapper", - "PyTorchWrapper_v2", "Softmax_v2", + "PyTorchWrapper_v2", "Softmax_v2", "SparseLinear_v2", "add", "bidirectional", "chain", "clone", "concatenate", "noop", "residual", "uniqued", "siamese", "list2ragged", "ragged2list", diff --git a/thinc/layers/__init__.py b/thinc/layers/__init__.py index 786252407..e4ff2c9d7 100644 --- a/thinc/layers/__init__.py +++ b/thinc/layers/__init__.py @@ -100,6 +100,7 @@ "Softmax", "Softmax_v2", "SparseLinear", + "SparseLinear_v2", "TensorFlowWrapper", "add", "bidirectional",