
Commit 76c4391

NabJa authored and ericspod committed
8564 fourier positional encoding (Project-MONAI#8570)
Fixes Project-MONAI#8564.

### Description

Add Fourier feature positional encodings to `PatchEmbeddingBlock`. It has been shown that Fourier feature positional encodings are better suited for anisotropic images and videos: https://arxiv.org/abs/2509.02488

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: NabJa <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
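As a quick way to try the new option, here is a minimal usage sketch pieced together from the tests added below; the sizes and the `scales` values are illustrative choices, not recommendations:

```python
import torch

from monai.networks.blocks import PatchEmbeddingBlock

# Patch embedding with fixed (non-trainable) Fourier feature positional encodings.
# Passing per-axis `scales` makes the encoding anisotropic, e.g. for volumes whose
# last axis has a different resolution; the values here are illustrative only.
block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    num_heads=8,
    pos_embed_type="fourier",
    pos_embed_kwargs=dict(scales=[1.0, 1.0, 0.5]),
)

x = torch.randn(2, 1, 32, 32, 32)
out = block(x)  # (2, 64, 96): 4*4*4 = 64 patches, each embedded to hidden_size 96
```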
1 parent b4e0fcc commit 76c4391

File tree

3 files changed: +112 −4 lines changed

monai/networks/blocks/patchembedding.py

Lines changed: 19 additions & 3 deletions
```diff
@@ -12,21 +12,22 @@
 from __future__ import annotations

 from collections.abc import Sequence
+from typing import Optional

 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import LayerNorm

-from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
+from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding
 from monai.networks.layers import Conv, trunc_normal_
 from monai.utils import ensure_tuple_rep, optional_import
 from monai.utils.module import look_up_option

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
-SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
+SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}


 class PatchEmbeddingBlock(nn.Module):
@@ -53,6 +54,7 @@ def __init__(
         pos_embed_type: str = "learnable",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
+        pos_embed_kwargs: Optional[dict] = None,
     ) -> None:
         """
         Args:
@@ -65,6 +67,8 @@
             pos_embed_type: position embedding layer type.
             dropout_rate: fraction of the input units to drop.
             spatial_dims: number of spatial dimensions.
+            pos_embed_kwargs: additional arguments for position embedding. For `sincos`, it can contain
+                `temperature`, and for `fourier` it can contain `scales`.
         """

         super().__init__()
@@ -105,6 +109,8 @@
         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)

+        pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs
+
         if self.pos_embed_type == "none":
             pass
         elif self.pos_embed_type == "learnable":
@@ -114,7 +120,17 @@
             for in_size, pa_size in zip(img_size, patch_size):
                 grid_size.append(in_size // pa_size)

-            self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
+            self.position_embeddings = build_sincos_position_embedding(
+                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
+            )
+        elif self.pos_embed_type == "fourier":
+            grid_size = []
+            for in_size, pa_size in zip(img_size, patch_size):
+                grid_size.append(in_size // pa_size)
+
+            self.position_embeddings = build_fourier_position_embedding(
+                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
+            )
         else:
             raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
```
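Since `pos_embed_kwargs` is unpacked into whichever builder matches `pos_embed_type`, the existing sincos embedding also becomes configurable from the block. A small sketch under the same assumptions (the `temperature` value is illustrative):

```python
from monai.networks.blocks import PatchEmbeddingBlock

# The kwargs dict is forwarded as **pos_embed_kwargs, so it reaches
# build_sincos_position_embedding(..., temperature=...) here.
sincos_block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(16, 16, 16),
    patch_size=(4, 4, 4),
    hidden_size=96,  # divisible by 6, as the 3D sincos builder requires
    num_heads=8,
    pos_embed_type="sincos",
    pos_embed_kwargs=dict(temperature=30000.0),  # illustrative value
)
```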

monai/networks/blocks/pos_embed_utils.py

Lines changed: 54 additions & 1 deletion
```diff
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn

-__all__ = ["build_sincos_position_embedding"]
+__all__ = ["build_fourier_position_embedding", "build_sincos_position_embedding"]


 # From PyTorch internals
@@ -32,6 +32,59 @@ def parse(x):
     return parse


+def build_fourier_position_embedding(
+    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
+) -> torch.nn.Parameter:
+    """
+    Builds an (anisotropic) Fourier feature position embedding based on the given grid size, embed dimension,
+    spatial dimensions, and scales. The scales control the variance of the Fourier features; higher values make
+    distant points more distinguishable.
+    The position embedding is made anisotropic by allowing a different scale for each spatial dimension.
+    Reference: https://arxiv.org/abs/2509.02488
+
+    Args:
+        grid_size (int | List[int]): The size of the grid in each spatial dimension.
+        embed_dim (int): The dimension of the embedding.
+        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
+        scales (float | List[float]): The scale for every spatial dimension. If a single float is provided,
+            the same scale is used for all dimensions.
+
+    Returns:
+        pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+    """
+
+    to_tuple = _ntuple(spatial_dims)
+    grid_size_t = to_tuple(grid_size)
+    if len(grid_size_t) != spatial_dims:
+        raise ValueError(f"Length of grid_size ({len(grid_size_t)}) must be the same as spatial_dims.")
+
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be even for Fourier position embedding")
+
+    # Ensure scales is a tensor of shape (spatial_dims,)
+    if isinstance(scales, float):
+        scales_tensor = torch.full((spatial_dims,), scales, dtype=torch.float)
+    elif isinstance(scales, (list, tuple)):
+        if len(scales) != spatial_dims:
+            raise ValueError(f"Length of scales {len(scales)} does not match spatial_dims {spatial_dims}")
+        scales_tensor = torch.tensor(scales, dtype=torch.float)
+    else:
+        raise TypeError(f"scales must be float or list of floats, got {type(scales)}")
+
+    gaussians = torch.randn(embed_dim // 2, spatial_dims, dtype=torch.float32) * scales_tensor
+
+    position_indices = [torch.linspace(0, 1, x, dtype=torch.float32) for x in grid_size_t]
+    positions = torch.stack(torch.meshgrid(*position_indices, indexing="ij"), dim=-1)
+    positions = positions.flatten(end_dim=-2)
+
+    x_proj = (2.0 * torch.pi * positions) @ gaussians.T
+
+    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+    pos_emb = nn.Parameter(pos_emb[None, :, :], requires_grad=False)
+
+    return pos_emb
+
+
 def build_sincos_position_embedding(
     grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
 ) -> torch.nn.Parameter:
```
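In short, each grid position `p` (normalized to `[0, 1]` per axis) is projected by a random Gaussian matrix `B` whose entries are scaled per axis, and embedded as the concatenation of `sin(2*pi*B@p)` and `cos(2*pi*B@p)`; larger scales inject higher frequencies along that axis. A quick standalone sanity check of the builder (shapes only; the seed is arbitrary since `B` is random):

```python
import torch

from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

torch.manual_seed(0)  # the Gaussian projection is random; seed for reproducibility

# 2D grid of 4 x 8 patches, 16-dim embedding, higher frequencies on the second axis
pos = build_fourier_position_embedding(grid_size=[4, 8], embed_dim=16, spatial_dims=2, scales=[1.0, 4.0])

print(pos.shape)          # torch.Size([1, 32, 16]) -- one row per grid position
print(pos.requires_grad)  # False: registered as a fixed, non-trainable parameter

# A scales length that does not match spatial_dims raises ValueError,
# as exercised by the new test cases below.
```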

tests/networks/blocks/test_patchembedding.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -87,6 +87,19 @@ def test_sincos_pos_embed(self):

         self.assertEqual(net.position_embeddings.requires_grad, False)

+    def test_fourier_pos_embed(self):
+        net = PatchEmbeddingBlock(
+            in_channels=1,
+            img_size=(32, 32, 32),
+            patch_size=(8, 8, 8),
+            hidden_size=96,
+            num_heads=8,
+            pos_embed_type="fourier",
+            dropout_rate=0.5,
+        )
+
+        self.assertEqual(net.position_embeddings.requires_grad, False)
+
     def test_learnable_pos_embed(self):
         net = PatchEmbeddingBlock(
             in_channels=1,
@@ -101,6 +114,32 @@ def test_learnable_pos_embed(self):
         self.assertEqual(net.position_embeddings.requires_grad, True)

     def test_ill_arg(self):
+        with self.assertRaises(ValueError):
+            PatchEmbeddingBlock(
+                in_channels=1,
+                img_size=(128, 128, 128),
+                patch_size=(16, 16, 16),
+                hidden_size=128,
+                num_heads=12,
+                proj_type="conv",
+                dropout_rate=0.1,
+                pos_embed_type="fourier",
+                pos_embed_kwargs=dict(scales=[1.0, 1.0]),
+            )
+
+        with self.assertRaises(ValueError):
+            PatchEmbeddingBlock(
+                in_channels=1,
+                img_size=(128, 128),
+                patch_size=(16, 16),
+                hidden_size=128,
+                num_heads=12,
+                proj_type="conv",
+                dropout_rate=0.1,
+                pos_embed_type="fourier",
+                pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]),
+            )
+
         with self.assertRaises(ValueError):
             PatchEmbeddingBlock(
                 in_channels=1,
```
