diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index cc432382ba15..e90bbc37c7ad 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -1328,8 +1328,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
 
     def __init__(self, num_positions, embedding_dim, padding_idx=None):
         super().__init__(num_positions, embedding_dim)
-        if embedding_dim % 2 != 0:
-            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
         self.weight = self._init_weight(self.weight)
 
     @staticmethod
@@ -1342,10 +1340,11 @@ def _init_weight(out: nn.Parameter):
         position_enc = np.array(
             [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
         )
-        out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # This line breaks for odd n_pos
-        out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
+        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
+        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
         out.detach_()
-        out.requires_grad = False
         return out
 
     @torch.no_grad()
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 2f085eb4981c..1b09a814f95c 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -620,8 +620,8 @@ def test_positional_emb_cache_logic(self):
         self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())
 
     def test_odd_embed_dim(self):
-        with self.assertRaises(NotImplementedError):
-            SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
+        # odd embedding_dim is allowed
+        SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
 
         # odd num_positions is allowed
         SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
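
For reference, here is a minimal standalone sketch (not part of the patch, and independent of transformers) illustrating why the `sentinel` split makes odd `embedding_dim` work: for odd `dim`, `position_enc[:, 0::2]` has one more column than `position_enc[:, 1::2]`, so the sin block must absorb the extra column. The helper name `build_sinusoidal_weight` is hypothetical.

```python
import numpy as np
import torch


def build_sinusoidal_weight(n_pos: int, dim: int) -> torch.Tensor:
    # Same construction as _init_weight, extracted for a shape check.
    out = torch.empty(n_pos, dim)
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    # Even dim: sin and cos halves are equal width (dim // 2 each).
    # Odd dim: 0::2 yields (dim // 2) + 1 columns, 1::2 yields dim // 2,
    # so the boundary shifts one column to the right.
    sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
    out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    return out


for dim in (4, 5):  # both even and odd embedding_dim now build cleanly
    weight = build_sinusoidal_weight(n_pos=4, dim=dim)
    assert weight.shape == (4, dim)
```

With the old `dim // 2` boundary, the `dim=5` case would attempt to assign a 3-column sin block into a 2-column slice and raise a shape mismatch, which is what the removed `NotImplementedError` guard was papering over.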