diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 85988e3523..054dc3c80b 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -149,13 +149,10 @@ def silu_grad(x): self.const = float(silu(threshold)) def forward(self, x: torch.Tensor) -> torch.Tensor: - silu_part = F.silu(x) - mask = x >= self.threshold - if torch.any(mask): - tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const - return torch.where(x < self.threshold, silu_part, tanh_part) - else: - return silu_part + sig = torch.sigmoid(x) + silu = x * sig + tanh = torch.tanh(self.slope * (x - self.threshold)) + self.const + return torch.where(x >= self.threshold, tanh, silu) class ActivationFn(torch.nn.Module):