diff --git a/thinc/api.py b/thinc/api.py
index 64ed11703..c07518a41 100644
--- a/thinc/api.py
+++ b/thinc/api.py
@@ -27,7 +27,7 @@
 from .layers import CauchySimilarity, ParametricAttention, Logistic
 from .layers import resizable, sigmoid_activation, Sigmoid, SparseLinear
 from .layers import ClippedLinear, ReluK, HardTanh, HardSigmoid
-from .layers import HardSwish, HardSwishMobilenet, Swish, Gelu
+from .layers import Dish, HardSwish, HardSwishMobilenet, Swish, Gelu
 from .layers import PyTorchWrapper, PyTorchRNNWrapper, PyTorchLSTM
 from .layers import TensorFlowWrapper, keras_subclass, MXNetWrapper
 from .layers import PyTorchWrapper_v2, Softmax_v2
diff --git a/thinc/backends/_custom_kernels.cu b/thinc/backends/_custom_kernels.cu
index a0017c0e4..9c9fece1e 100644
--- a/thinc/backends/_custom_kernels.cu
+++ b/thinc/backends/_custom_kernels.cu
@@ -161,6 +161,20 @@ __global__ void clipped_linear(T* Y, const T* X, double slope, double offset, do
 }
 
 
+template <typename T>
+__global__ void dish(T* Y, const T* X, int N)
+{
+    int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
+    int _loop_stride = blockDim.x * gridDim.x;
+
+    for (int i = _loop_start; i < N; i += _loop_stride)
+    {
+        T x = X[i];
+        Y[i] = 0.5 * x * (x / sqrt(1 + x * x) + 1);
+    }
+}
+
+
 template <typename T>
 __global__ void gelu(T* Y, const T* X, double threshold, int N)
 {
@@ -414,6 +428,23 @@ __global__ void backprop_hard_swish_mobilenet(T* dX, const T* dY, const T* X, in
 }
 
 
+template <typename T>
+__global__ void backprop_dish(T* dX, const T* dY, const T* X, int N)
+{
+
+    int _loop_start = blockIdx.x * blockDim.x + threadIdx.x;
+    int _loop_stride = blockDim.x * gridDim.x;
+
+    for (int i = _loop_start; i < N; i += _loop_stride)
+    {
+        T x = X[i];
+        T x_sq = x * x;
+        T x_sq_plus_one = x_sq + 1.0;
+        dX[i] = dY[i] * (x / sqrt(x_sq_plus_one) - (0.5 * x * x_sq)
+            / pow(x_sq_plus_one, static_cast<T>(1.5)) + 0.5);
+    }
+}
+
 template <typename T>
 __global__ void backprop_gelu(T* dX, const T* dY, const T* X,
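For reference, the two kernels above implement the function and derivative below, so `backprop_dish` computes `dY * dish'(X)`; the default `Ops.dish`/`Ops.backprop_dish` further down use the same expressions:

```latex
% Function and derivative implemented by the dish / backprop_dish kernels.
\operatorname{dish}(x)  = \frac{x}{2}\left(\frac{x}{\sqrt{1 + x^{2}}} + 1\right)
\qquad
\operatorname{dish}'(x) = \frac{x}{\sqrt{1 + x^{2}}}
                        - \frac{x^{3}}{2\,(1 + x^{2})^{3/2}}
                        + \frac{1}{2}
```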
diff --git a/thinc/backends/_custom_kernels.py b/thinc/backends/_custom_kernels.py
index d2d1ea133..859405495 100644
--- a/thinc/backends/_custom_kernels.py
+++ b/thinc/backends/_custom_kernels.py
@@ -10,6 +10,8 @@
 KERNELS_LIST = [
     "backprop_clipped_linear<double>",
     "backprop_clipped_linear<float>",
+    "backprop_dish<double>",
+    "backprop_dish<float>",
     "backprop_gelu<double>",
     "backprop_gelu<float>",
     "backprop_hard_swish<double>",
@@ -32,6 +34,8 @@
     "backprop_swish<float>",
     "clipped_linear<double>",
     "clipped_linear<float>",
+    "dish<double>",
+    "dish<float>",
     "gather_add<double>",
     "gather_add<float>",
     "gelu<double>",
@@ -78,6 +82,8 @@ def compile_mmh(src):
 
 clipped_linear_kernel_float = _get_kernel("clipped_linear<float>")
 clipped_linear_kernel_double = _get_kernel("clipped_linear<double>")
+dish_kernel_float = _get_kernel("dish<float>")
+dish_kernel_double = _get_kernel("dish<double>")
 gather_add_kernel_float = _get_kernel("gather_add<float>")
 gather_add_kernel_double = _get_kernel("gather_add<double>")
 gelu_kernel_float = _get_kernel("gelu<float>")
@@ -98,6 +104,8 @@ def compile_mmh(src):
 
 backprop_clipped_linear_kernel_double = _get_kernel("backprop_clipped_linear<double>")
 backprop_clipped_linear_kernel_float = _get_kernel("backprop_clipped_linear<float>")
+backprop_dish_kernel_double = _get_kernel("backprop_dish<double>")
+backprop_dish_kernel_float = _get_kernel("backprop_dish<float>")
 backprop_gelu_kernel_double = _get_kernel("backprop_gelu<double>")
 backprop_gelu_kernel_float = _get_kernel("backprop_gelu<float>")
 backprop_hard_swish_kernel_double = _get_kernel("backprop_hard_swish<double>")
@@ -199,6 +207,19 @@ def gather_add(table, indices, *, threads_per_block=128, num_blocks=128):
     return out
 
 
+def dish(X, *, inplace=False, threads_per_block=128, num_blocks=128):
+    _is_float_array(X)
+
+    out = X
+    if not inplace:
+        out = _alloc_like(X, zeros=False)
+    if X.dtype == "float32":
+        dish_kernel_float((num_blocks,), (threads_per_block,), (out, X, X.size))
+    else:
+        dish_kernel_double((num_blocks,), (threads_per_block,), (out, X, X.size))
+    return out
+
+
 def gelu(X, *, inplace=False, threshold=6.0, threads_per_block=128, num_blocks=128):
     _is_float_array(X)
 
@@ -483,6 +504,33 @@ def backprop_hard_swish_mobilenet(
     return out
 
 
+def backprop_dish(
+    dY,
+    X,
+    *,
+    inplace: bool = False,
+    threads_per_block=128,
+    num_blocks=128,
+):
+    _is_float_array(dY)
+    _is_float_array(X, shape=dY.shape)
+
+    out = dY
+    if not inplace:
+        out = _alloc_like(dY, zeros=False)
+
+    if dY.dtype == "float32":
+        backprop_dish_kernel_float(
+            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
+        )
+    else:
+        backprop_dish_kernel_double(
+            (num_blocks,), (threads_per_block,), (out, dY, X, out.size)
+        )
+
+    return out
+
+
 def backprop_gelu(
     dY,
     X,
diff --git a/thinc/backends/cupy_ops.py b/thinc/backends/cupy_ops.py
index 924bfe955..6d263c155 100644
--- a/thinc/backends/cupy_ops.py
+++ b/thinc/backends/cupy_ops.py
@@ -36,6 +36,18 @@ def gather_add(self, table, indices):
         else:
             return super().gather_add(table, indices)
 
+    def dish(self, X, inplace=False):
+        if X.dtype in ("float32", "float64"):
+            return _custom_kernels.dish(X, inplace=inplace)
+        else:
+            return super().dish(X, inplace=inplace)
+
+    def backprop_dish(self, dY, X, inplace=False):
+        if X.dtype == dY.dtype and X.dtype in ("float32", "float64"):
+            return _custom_kernels.backprop_dish(dY, X, inplace=inplace)
+        else:
+            return super().backprop_dish(dY, X, inplace=inplace)
+
     def gelu(self, X, inplace=False):
         if X.dtype in ("float32", "float64"):
             return _custom_kernels.gelu(X, inplace=inplace, threshold=6.0)
diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py
index c9fb10aae..f0c05de42 100644
--- a/thinc/backends/ops.py
+++ b/thinc/backends/ops.py
@@ -976,6 +976,35 @@ def backprop_hard_swish_mobilenet(
             return dY
         return dX * dY
 
+    def dish(self, X: FloatsXdT, inplace: bool = False) -> FloatsXdT:
+        tmp = self.xp.square(X)
+        tmp += 1.0
+        self.xp.sqrt(tmp, out=tmp)
+        tmp = X / tmp
+        tmp += 1
+        tmp *= 0.5
+        if inplace:
+            X *= tmp
+            return X
+        else:
+            return X * tmp
+
+    def backprop_dish(
+        self, dY: FloatsXdT, X: FloatsXdT, inplace: bool = False
+    ) -> FloatsXdT:
+        x_sq = self.xp.square(X)
+        x_sq_plus_one = x_sq + 1.0
+        deriv = X / self.xp.sqrt(x_sq_plus_one)
+        second = 0.5 * X * x_sq
+        second /= x_sq_plus_one**1.5
+        deriv -= second
+        deriv += 0.5
+        if inplace:
+            dY *= deriv
+            return dY
+        else:
+            return dY * deriv
+
     # Code snippet taken from:
     # https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
     def erf(self, X: FloatsXdT) -> FloatsXdT:
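As a sanity check of the arithmetic in the default `Ops.dish`/`Ops.backprop_dish` above, here is a small standalone sketch (plain NumPy, not part of the patch) that compares the analytic gradient against a central finite difference:

```python
import numpy as np

def dish(x):
    # 0.5 * x * (x / sqrt(1 + x^2) + 1), as in Ops.dish
    return 0.5 * x * (x / np.sqrt(1 + x * x) + 1)

def backprop_dish(d_y, x):
    # d_y * dish'(x), as in Ops.backprop_dish
    x_sq = x * x
    x_sq_plus_one = x_sq + 1.0
    deriv = x / np.sqrt(x_sq_plus_one) - (0.5 * x * x_sq) / x_sq_plus_one**1.5 + 0.5
    return d_y * deriv

x = np.linspace(-5.0, 5.0, 21)
eps = 1e-6
numeric = (dish(x + eps) - dish(x - eps)) / (2 * eps)
assert np.allclose(backprop_dish(np.ones_like(x), x), numeric, atol=1e-5)
```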
diff --git a/thinc/layers/__init__.py b/thinc/layers/__init__.py
index b37e38a7a..73fa88f4e 100644
--- a/thinc/layers/__init__.py
+++ b/thinc/layers/__init__.py
@@ -1,5 +1,6 @@
 # Weights layers
 from .cauchysimilarity import CauchySimilarity
+from .dish import Dish
 from .dropout import Dropout
 from .embed import Embed
 from .expand_window import expand_window
diff --git a/thinc/layers/dish.py b/thinc/layers/dish.py
new file mode 100644
index 000000000..b085946b3
--- /dev/null
+++ b/thinc/layers/dish.py
@@ -0,0 +1,66 @@
+from typing import Tuple, Optional, Callable, cast
+
+from ..config import registry
+from ..model import Model
+from .chain import chain
+from .layernorm import LayerNorm
+from .dropout import Dropout
+from ..types import Floats1d, Floats2d
+from ..util import partial, get_width
+from ..initializers import he_normal_init, zero_init
+
+
+@registry.layers("Dish.v1")
+def Dish(
+    nO: Optional[int] = None,
+    nI: Optional[int] = None,
+    *,
+    init_W: Callable = he_normal_init,
+    init_b: Callable = zero_init,
+    dropout: Optional[float] = None,
+    normalize: bool = False,
+) -> Model[Floats2d, Floats2d]:
+    model: Model[Floats2d, Floats2d] = Model(
+        "dish",
+        forward,
+        init=partial(init, init_W, init_b),
+        dims={"nO": nO, "nI": nI},
+        params={"W": None, "b": None},
+    )
+    if normalize:
+        model = chain(model, LayerNorm(nI=nO))
+    if dropout is not None:
+        model = chain(model, cast(Model[Floats2d, Floats2d], Dropout(dropout)))
+    return model
+
+
+def forward(
+    model: Model[Floats2d, Floats2d], X: Floats2d, is_train: bool
+) -> Tuple[Floats2d, Callable]:
+    W = cast(Floats2d, model.get_param("W"))
+    b = cast(Floats1d, model.get_param("b"))
+    Y_preact = model.ops.affine(X, W, b)
+    Y = model.ops.dish(Y_preact)
+
+    def backprop(dY: Floats2d) -> Floats2d:
+        dY = model.ops.backprop_dish(dY, Y_preact, inplace=False)
+        model.inc_grad("b", dY.sum(axis=0))
+        model.inc_grad("W", model.ops.gemm(dY, X, trans1=True))
+        return model.ops.gemm(dY, W)
+
+    return Y, backprop
+
+
+def init(
+    init_W: Callable,
+    init_b: Callable,
+    model: Model[Floats2d, Floats2d],
+    X: Optional[Floats2d] = None,
+    Y: Optional[Floats2d] = None,
+) -> None:
+    if X is not None:
+        model.set_dim("nI", get_width(X))
+    if Y is not None:
+        model.set_dim("nO", get_width(Y))
+    model.set_param("W", init_W(model.ops, (model.get_dim("nO"), model.get_dim("nI"))))
+    model.set_param("b", init_b(model.ops, (model.get_dim("nO"),)))
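A minimal usage sketch for the new layer (not part of the patch; shapes and hyperparameters are arbitrary, and it assumes the patched package is installed so `Dish` is importable from `thinc.api` as per the `api.py` change above):

```python
import numpy
from thinc.api import Dish

X = numpy.random.uniform(-1, 1, (8, 4)).astype("float32")
Y = numpy.zeros((8, 2), dtype="float32")

# Dense layer with the Dish activation, optionally followed by LayerNorm and Dropout.
model = Dish(nO=2, nI=4, normalize=True, dropout=0.1)
model.initialize(X=X, Y=Y)

Yh, backprop = model.begin_update(X)
dX = backprop(Yh - Y)
assert dX.shape == X.shape
```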
diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py
index 7de864623..7a9bb1961 100644
--- a/thinc/tests/backends/test_ops.py
+++ b/thinc/tests/backends/test_ops.py
@@ -64,6 +64,9 @@ def torch_hard_swish_mobilenet(x):
     def torch_sigmoid(x):
         return torch.sigmoid(x)
 
+    def torch_dish(x):
+        return 0.5 * x * (x / (1 + x * x).sqrt() + 1)
+
     # https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py#L37
     def torch_gelu_approx(x):
         return (
@@ -89,6 +92,7 @@ def torch_gelu(x):
         ("swish", torch_swish),
         ("hard_swish", torch_hard_swish),
         ("hard_swish_mobilenet", torch_hard_swish_mobilenet),
+        ("dish", torch_dish),
         ("gelu_approx", torch_gelu_approx),
         ("gelu", torch_gelu),
         ("sigmoid", torch_sigmoid),
@@ -1043,6 +1047,7 @@ def test_mish(ops, X):
     "op",
     [
         "backprop_clipped_linear",
+        "backprop_dish",
        "backprop_gelu",
        "backprop_gelu_approx",
        "backprop_hard_sigmoid",
@@ -1160,6 +1165,16 @@ def test_gelu_approx(ops, X):
     assert not ops.xp.isnan(Y).any()
 
 
+@pytest.mark.parametrize("ops", ALL_OPS)
+@settings(max_examples=MAX_EXAMPLES, deadline=None)
+@given(X=strategies.arrays_BI())
+def test_dish(ops, X):
+    X = ops.asarray(X)
+    Y = ops.dish(X)
+    assert Y.shape == X.shape
+    assert not ops.xp.isnan(Y).any()
+
+
 @pytest.mark.parametrize("ops", ALL_OPS)
 @settings(max_examples=MAX_EXAMPLES, deadline=None)
 @given(X=strategies.arrays_BI())
@@ -1350,8 +1365,8 @@ def test_ngrams():
 @pytest.mark.parametrize("dtype", ["float32", "float64"])
 @pytest.mark.parametrize("torch_func", TORCH_FUNCS)
 @settings(max_examples=MAX_EXAMPLES, deadline=None)
-@given(x=strategies.floats(min_value=-30, max_value=30))
-def test_compare_activations_to_torch(ops, dtype, x, torch_func):
+@given(x=strategies.floats(min_value=-30, max_value=30), dY=strategies.floats(min_value=-1, max_value=1))
+def test_compare_activations_to_torch(ops, dtype, x, dY, torch_func):
     import torch
 
     func_name, pytorch_func = torch_func
@@ -1369,9 +1384,9 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
     y_think_inplace = forward(x_thinc, inplace=True)
     assert y_think_inplace is x_thinc
     assert ops.xp.isclose(y_thinc, y_think_inplace, atol=1e-06)
-    assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-06)
+    assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-05)
     x_thinc = ops.asarray([x], dtype=dtype)
-    dY_thinc = ops.asarray([1.0], dtype=dtype)
+    dY_thinc = ops.asarray([dY], dtype=dtype)
     dY_thinc_inplace = dY_thinc.copy()
 
     s = inspect.signature(backward)
@@ -1386,7 +1401,7 @@
         )
         assert dx_thinc_inplace is dY_thinc_inplace
         assert ops.xp.isclose(dx_thinc, dx_thinc_inplace)
-        assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06)
+        assert ops.xp.isclose(x_torch.grad.item() * dY, float(dx_thinc), atol=1e-06)
     elif params == {"Y", "dY"}:
         dx_thinc = backward(dY_thinc, Y=y_thinc)
         assert dx_thinc.dtype == x_thinc.dtype
@@ -1394,7 +1409,7 @@
             dx_thinc,
             backward(dY=dY_thinc_inplace, Y=y_thinc, inplace=True),
         )
-        assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06)
+        assert ops.xp.isclose(x_torch.grad.item() * dY, float(dx_thinc), atol=1e-06)
     elif params == {"dY", "X"}:
         dx_thinc = backward(dY_thinc, X=x_thinc)
         assert dx_thinc.dtype == x_thinc.dtype
@@ -1402,7 +1417,7 @@
             dx_thinc, backward(dY=dY_thinc_inplace, X=x_thinc, inplace=True)
         )
         assert ops.xp.isclose(
-            x_torch.grad.item(), float(backward(dY_thinc, X=x_thinc)), atol=1e-06
+            x_torch.grad.item() * dY, float(backward(dY_thinc, X=x_thinc)), atol=1e-06
         )
     else:
         raise NotImplementedError(
diff --git a/thinc/tests/layers/test_layers_api.py b/thinc/tests/layers/test_layers_api.py
index 5c922f074..3ebeb470a 100644
--- a/thinc/tests/layers/test_layers_api.py
+++ b/thinc/tests/layers/test_layers_api.py
@@ -57,6 +57,8 @@ def assert_data_match(Y, out_data):
 
 TEST_CASES_SUMMABLE = [
     # Array to array
+    ("Dish.v1", {}, array2d, array2d),
+    ("Dish.v1", {"nO": 4, "nI": 4}, array2d, array2d),
     ("Dropout.v1", {}, array2d, array2d),
     ("LayerNorm.v1", {}, array2d, array2d),
     ("Linear.v1", {}, array2d, array2d),
diff --git a/website/docs/api-backends.md b/website/docs/api-backends.md
index f2cdb03ce..c5a54cff8 100644
--- a/website/docs/api-backends.md
+++ b/website/docs/api-backends.md
@@ -927,6 +927,47 @@ Backpropagate the Swish activation
 | `inplace`   | bool     | If `True`, the `dY` array is modified in place. |
 | **RETURNS** | FloatsXd | The gradient of the input.                      |
 
+### Ops.dish {#dish tag="method" new="8.1.1"}
+
+- **default:**
+- **numpy:**
+- **cupy:**
+
+Dish or "Daniël's Swish-like activation" is an activation function with a
+non-monotonic shape similar to [GELU](#gelu), [Swish](#swish) and [Mish](#mish).
+However, Dish does not rely on transcendental functions like `exp` or `erf`,
+making it much
+[faster to compute](https://twitter.com/danieldekok/status/1484898130441166853)
+in most cases.
+
+| Argument    | Type     | Description                                |
+| ----------- | -------- | ------------------------------------------ |
+| `X`         | FloatsXd | The inputs.                                |
+| `inplace`   | bool     | If `True`, the array is modified in place. |
+| **RETURNS** | FloatsXd | The outputs.                               |
+
+### Ops.backprop_dish {#backprop_dish tag="method" new="8.1.1"}
+
+- **default:**
+- **numpy:**
+- **cupy:**
+
+Backpropagate the Dish activation.
+
+| Argument    | Type     | Description                                     |
+| ----------- | -------- | ----------------------------------------------- |
+| `dY`        | FloatsXd | Gradients of the output array.                  |
+| `X`         | FloatsXd | The inputs to the forward pass.                 |
+| `inplace`   | bool     | If `True`, the `dY` array is modified in place. |
+| **RETURNS** | FloatsXd | The gradient of the input.                      |
+
 ### Ops.gelu {#gelu tag="method"}
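The new `Ops` methods documented above can also be called directly on a backend; a short sketch (not part of the patch), using the existing `get_current_ops`, `asarray2f` and `alloc2f` helpers (`CupyOps` dispatches to the CUDA kernels, other backends use the default implementation):

```python
import numpy
from thinc.api import get_current_ops

ops = get_current_ops()
X = ops.asarray2f(numpy.random.uniform(-5, 5, (2, 3)).astype("float32"))

Y = ops.dish(X)                                      # forward pass
dX = ops.backprop_dish(ops.alloc2f(2, 3) + 1.0, X)   # gradient for a dY of ones
assert Y.shape == X.shape and dX.shape == X.shape
```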
diff --git a/website/docs/api-layers.md b/website/docs/api-layers.md
index 1c43a9d7a..a9fc9a385 100644
--- a/website/docs/api-layers.md
+++ b/website/docs/api-layers.md
@@ -44,6 +44,39 @@ Primarily used within [`siamese`](#siamese) neural networks.
 https://github.com/explosion/thinc/blob/master/thinc/layers/cauchysimilarity.py
 ```
 
+### Dish {#dish tag="function"}
+
+- **Input:** Floats2d
+- **Output:** Floats2d
+- **Parameters:** W, b
+
+A dense layer with the Dish activation function. Dish or "Daniël's Swish-like
+activation" is an activation function with a non-monotonic shape similar to
+[GELU](#gelu), [Swish](#swish) and [Mish](#mish). However, Dish does not rely on
+transcendental functions like `exp` or `erf`, making it much
+[faster to compute](https://twitter.com/danieldekok/status/1484898130441166853)
+in most cases.
+
+| Argument       | Type                      | Description                                                                                                          |
+| -------------- | ------------------------- | -------------------------------------------------------------------------------------------------------------------- |
+| `nO`           | Optional[int]             | The size of the output vectors.                                                                                      |
+| `nI`           | Optional[int]             | The size of the input vectors.                                                                                       |
+| _keyword-only_ |                           |                                                                                                                      |
+| `init_W`       | Callable                  | A function to initialize the weights matrix. Defaults to [`he_normal_init`](/docs/api-initializers#he_normal_init).  |
+| `init_b`       | Callable                  | A function to initialize the bias vector. Defaults to [`zero_init`](/docs/api-initializers#zero_init).               |
+| `dropout`      | Optional[float]           | Dropout rate to avoid overfitting.                                                                                   |
+| `normalize`    | bool                      | Whether or not to apply [layer normalization](#layernorm). Defaults to `False`.                                      |
+| **RETURNS**    | Model[Floats2d, Floats2d] | The created dense layer.                                                                                             |
+
+```python
+https://github.com/explosion/thinc/blob/master/thinc/layers/dish.py
+```
+
 ### Dropout {#dropout tag="function"}
@@ -835,8 +868,8 @@ https://github.com/explosion/thinc/blob/master/thinc/layers/reduce_last.py
 
 Pooling layer that reduces the dimensions of the data by selecting the maximum
-value for each feature. A `ValueError` is raised if any element in `lengths`
-is zero.
+value for each feature. A `ValueError` is raised if any element in `lengths` is
+zero.
 
 | Argument    | Type                             | Description                |
 | ----------- | -------------------------------- | -------------------------- |
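Since the layer is registered as `"Dish.v1"`, it can also be built from a Thinc config; a possible sketch (not from the patch, field values are arbitrary):

```python
from thinc.api import Config, registry

CONFIG = """
[model]
@layers = "Dish.v1"
nO = 10
nI = 64
normalize = true
"""

# Resolve the config section into a constructed Dish model.
model = registry.resolve(Config().from_str(CONFIG))["model"]
```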