diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py
index a7015d44d..0a865e7b5 100644
--- a/thinc/backends/ops.py
+++ b/thinc/backends/ops.py
@@ -722,15 +722,15 @@     def as_contig(self, data: ArrayT, dtype: Optional[DTypes] = None) -> ArrayT:
         return self.xp.ascontiguousarray(data, **kwargs)
 
     def sigmoid(self, X: FloatsType, *, inplace: bool = False) -> FloatsType:
-        # To prevent overflows and help with regularization/numerical stability
-        X = self.xp.clip(X, -20.0, 20.0)
-        if inplace:
+        if inplace:
+            # To prevent overflows and help with regularization/numerical stability
+            X = self.xp.clip(X, -20.0, 20.0, out=X)
             self.xp.exp(-X, out=X)
             X += 1.0  # type: ignore[assignment]
             X **= -1.0  # type: ignore[assignment]
             return cast(FloatsType, X)
         else:
+            X = self.xp.clip(X, -20.0, 20.0)
             return cast(FloatsType, 1.0 / (1.0 + self.xp.exp(-X)))
 
     def backprop_sigmoid(
@@ -909,7 +909,7 @@     def backprop_relu_k(
         return self.backprop_clipped_linear(dY, X, max_val=n, inplace=inplace)
 
     def hard_sigmoid(self, X: FloatsType, inplace: bool = False) -> FloatsType:
-        return self.clipped_linear(X, slope=0.2, offset=0.5)
+        return self.clipped_linear(X, slope=0.2, offset=0.5, inplace=inplace)
 
     def backprop_hard_sigmoid(
         self, dY: FloatsType, X: FloatsType, inplace: bool = False
@@ -917,7 +917,7 @@     def backprop_hard_sigmoid(
         return self.backprop_clipped_linear(dY, X, slope=0.2, offset=0.5)
 
     def hard_tanh(self, X: FloatsType, inplace: bool = False) -> FloatsType:
-        return self.clipped_linear(X, min_val=-1.0, max_val=1.0)
+        return self.clipped_linear(X, min_val=-1.0, max_val=1.0, inplace=inplace)
 
     def backprop_hard_tanh(
         self, dY: FloatsType, X: FloatsType, inplace: bool = False
diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py
index c8cb6a9df..451f6aeec 100644
--- a/thinc/tests/backends/test_ops.py
+++ b/thinc/tests/backends/test_ops.py
@@ -1270,7 +1270,10 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
     y_thinc = forward(x_thinc)
     y.backward()
     assert x_thinc.dtype == y_thinc.dtype
-    assert ops.xp.isclose(y_thinc, forward(x_thinc, inplace=True), atol=1e-06)
+    assert y_thinc is not x_thinc
+    y_thinc_inplace = forward(x_thinc, inplace=True)
+    assert y_thinc_inplace is x_thinc
+    assert ops.xp.isclose(y_thinc, y_thinc_inplace, atol=1e-06)
     assert ops.xp.isclose(y_thinc, y.detach(), atol=1e-06)
     x_thinc = ops.asarray([x], dtype=dtype)
     dY_thinc = ops.asarray([1.0], dtype=dtype)
@@ -1282,10 +1285,12 @@ def test_compare_activations_to_torch(ops, dtype, x, torch_func):
     if params == {"dY", "X", "Y"}:
         dx_thinc = backward(dY_thinc, Y=y_thinc, X=x_thinc)
         assert dx_thinc.dtype == x_thinc.dtype
-        assert ops.xp.isclose(
-            dx_thinc,
-            backward(dY=dY_thinc_inplace, Y=y_thinc, X=x_thinc, inplace=True),
+        assert dx_thinc is not dY_thinc
+        dx_thinc_inplace = backward(
+            dY=dY_thinc_inplace, Y=y_thinc, X=x_thinc, inplace=True
         )
+        assert dx_thinc_inplace is dY_thinc_inplace
+        assert ops.xp.isclose(dx_thinc, dx_thinc_inplace)
         assert ops.xp.isclose(x_torch.grad.item(), float(dx_thinc), atol=1e-06)
     elif params == {"Y", "dY"}:
         dx_thinc = backward(dY_thinc, Y=y_thinc)
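
A minimal standalone sketch of the behaviour the sigmoid hunk corrects, assuming NumPy stands in for the `xp` backend used by `Ops` (the array values and variable names below are illustrative, not taken from the patch): `clip` without `out=` allocates a new array, so the old "in-place" branch mutated a copy and left the caller's buffer untouched, while the patched branch clips into `X` itself before the element-wise updates.

import numpy as np

# Sketch only, assuming NumPy as the `xp` backend: `clip` without `out=`
# returns a fresh array, so subsequent in-place ops never touch the input.
X = np.asarray([0.5, 25.0, -25.0], dtype="float32")

clipped = np.clip(X, -20.0, 20.0)   # allocates a new array
assert clipped is not X             # X itself is left unclipped

# The patched in-place branch: clip into X, then compute 1 / (1 + exp(-X))
# with element-wise in-place updates, so the same buffer is returned.
np.clip(X, -20.0, 20.0, out=X)
np.exp(-X, out=X)
X += 1.0
X **= -1.0

expected = 1.0 / (1.0 + np.exp(-np.asarray([0.5, 20.0, -20.0], dtype="float32")))
assert np.allclose(X, expected, atol=1e-6)

Forwarding `inplace=inplace` in `hard_sigmoid` and `hard_tanh` addresses the same concern: without it, `clipped_linear` always allocated a new array, which the tests' new identity assertions (`y_thinc_inplace is x_thinc`, `dx_thinc_inplace is dY_thinc_inplace`) would catch.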