diff --git a/thinc/backends/numpy_ops.pyx b/thinc/backends/numpy_ops.pyx index 081738295..7525e6eed 100644 --- a/thinc/backends/numpy_ops.pyx +++ b/thinc/backends/numpy_ops.pyx @@ -396,7 +396,8 @@ class NumpyOps(Ops): assert O != 0 cdef np.ndarray maxes - cdef np.ndarray which = self.alloc(shape=(B, O), dtype="i", zeros=False) + # Needs to be zero-initialized as we start by assuming that the first element is the max value. + cdef np.ndarray which = self.alloc(shape=(B, O), dtype="i", zeros=True) if reals2d_ft is float2d_t: maxes = self.alloc(shape=(B, O), dtype="float32", zeros=False) cpu_reduce_max(maxes.data, which.data, &X[0, 0], &lengths[0], B, T, O) diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index 68dd33df7..a91e796cd 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -763,6 +763,9 @@ def test_reduce_max_sm(ops, dtype): lengths = ops.xp.array([2, 2, 2], dtype="i") maxes, which = ops.reduce_max(X, lengths) assert maxes.dtype == dtype + assert ops.xp.all(which >= 0) + assert ops.xp.all(which < X.shape[0]) + start = 0 for i, length in enumerate(lengths): truth = X[start : start + length].max(axis=0) @@ -781,6 +784,9 @@ def test_reduce_max(ops, dtype): # m[1, 3] = 3 maxes, which = ops.reduce_max(m, lengths) assert maxes.dtype == dtype + assert ops.xp.all(which >= 0) + assert ops.xp.all(which < m.shape[0]) + start = 0 for i, length in enumerate(lengths): truth = m[start : start + length].max(axis=0)