diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py index e5a232d9b..2386e21cc 100644 --- a/thinc/backends/ops.py +++ b/thinc/backends/ops.py @@ -1001,6 +1001,8 @@ def erf(self, X: FloatsType) -> FloatsType: return out def sechsq(self, X: FloatsType) -> FloatsType: + # Avoid overflow in cosh. Clipping at |20| has an error of 1.7e-17. + X = self.xp.clip(X, -20.0, 20.0) return (1 / self.xp.cosh(X)) ** 2 def gelu_approx(self, X: FloatsType, inplace: bool = False) -> FloatsType: diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py index e095142b1..2c7cad2ca 100644 --- a/thinc/tests/backends/test_ops.py +++ b/thinc/tests/backends/test_ops.py @@ -62,7 +62,7 @@ def torch_hard_swish_mobilenet(x): return torch.nn.functional.hardswish(x) def torch_sigmoid(x): - return torch.nn.functional.sigmoid(x) + return torch.sigmoid(x) # https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py#L37 def torch_gelu_approx(x):