
Commit bfb5592

Add embed dim assertion
1 parent 984bbe4 commit bfb5592

File tree

1 file changed: +11 -6 lines changed

monai/networks/blocks/pos_embed_utils.py

Lines changed: 11 additions & 6 deletions
@@ -34,7 +34,7 @@ def parse(x):
 
 def build_fourier_position_embedding(
     grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
-):
+) -> torch.nn.Parameter:
     """
     Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
     spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant
@@ -55,7 +55,12 @@ def build_fourier_position_embedding
     to_tuple = _ntuple(spatial_dims)
     grid_size = to_tuple(grid_size)
 
-    scales = torch.tensor(scales)
+    if embed_dim % (2 * spatial_dims) != 0:
+        raise AssertionError(
+            f"Embed dimension must be divisible by {2 * spatial_dims} for {spatial_dims}D Fourier feature position embedding"
+        )
+
+    scales: torch.Tensor = torch.as_tensor(scales, dtype=torch.float)
     if scales.ndim > 1 and scales.ndim != spatial_dims:
         raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")
     if scales.ndim == 0:
@@ -65,15 +70,15 @@ def build_fourier_position_embedding
     gaussians = gaussians * scales
 
     positions = [torch.linspace(0, 1, x) for x in grid_size]
-    positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), axis=-1)
+    positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), dim=-1)
     positions = positions.flatten(end_dim=-2)
 
     x_proj = (2.0 * torch.pi * positions) @ gaussians.T
 
-    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], axis=-1)
-    pos_emb = pos_emb[None, :, :]
+    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+    pos_emb = nn.Parameter(pos_emb[None, :, :], requires_grad=False)
 
-    return nn.Parameter(pos_emb, requires_grad=False)
+    return pos_emb
 
 
 def build_sincos_position_embedding(
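
For context, a minimal usage sketch of the patched function (not part of the commit), assuming the import path mirrors the file path monai/networks/blocks/pos_embed_utils.py and relying only on the signature and checks visible in the diff; the output shape noted below is an expectation, not verified here.

from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

# embed_dim must be divisible by 2 * spatial_dims (2 * 3 = 6 for 3D);
# otherwise the newly added check raises an AssertionError.
pos_emb = build_fourier_position_embedding(
    grid_size=[8, 8, 8],  # patch grid, 8 * 8 * 8 = 512 positions
    embed_dim=96,         # 96 % 6 == 0, so the check passes
    spatial_dims=3,
    scales=1.0,
)
print(type(pos_emb))           # torch.nn.Parameter, frozen (requires_grad=False)
print(pos_emb.shape)           # expected (1, 512, 96): sin/cos features per position

# An embed_dim that is not a multiple of 2 * spatial_dims now fails fast
# instead of silently producing a mis-sized embedding:
try:
    build_fourier_position_embedding(grid_size=[8, 8, 8], embed_dim=100, spatial_dims=3)
except AssertionError as err:
    print(err)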
