
Commit 93c9354

Add Fourier feature positional encoding (Project-MONAI#8564)

1 parent: 725c8de

File tree: 3 files changed, +99 -4 lines


monai/networks/blocks/patchembedding.py

Lines changed: 16 additions & 3 deletions

@@ -19,14 +19,14 @@
 import torch.nn.functional as F
 from torch.nn import LayerNorm

-from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
+from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding
 from monai.networks.layers import Conv, trunc_normal_
 from monai.utils import ensure_tuple_rep, optional_import
 from monai.utils.module import look_up_option

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
-SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
+SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}


 class PatchEmbeddingBlock(nn.Module):
@@ -53,6 +53,7 @@ def __init__(
         pos_embed_type: str = "learnable",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
+        pos_embed_kwargs: dict = {},
     ) -> None:
         """
         Args:
@@ -65,6 +66,8 @@
             pos_embed_type: position embedding layer type.
             dropout_rate: fraction of the input units to drop.
             spatial_dims: number of spatial dimensions.
+            pos_embed_kwargs: additional arguments for the position embedding. For `sincos` it can
+                contain `temperature`, and for `fourier` it can contain `scales`.
         """

         super().__init__()
@@ -114,7 +117,17 @@
             for in_size, pa_size in zip(img_size, patch_size):
                 grid_size.append(in_size // pa_size)

-            self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
+            self.position_embeddings = build_sincos_position_embedding(
+                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
+            )
+        elif self.pos_embed_type == "fourier":
+            grid_size = []
+            for in_size, pa_size in zip(img_size, patch_size):
+                grid_size.append(in_size // pa_size)
+
+            self.position_embeddings = build_fourier_position_embedding(
+                grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
+            )
         else:
             raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")

monai/networks/blocks/pos_embed_utils.py

Lines changed: 45 additions & 1 deletion

@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn

-__all__ = ["build_sincos_position_embedding"]
+__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]


 # From PyTorch internals
@@ -32,6 +32,50 @@ def parse(x):
     return parse


+def build_fourier_position_embedding(
+    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
+) -> torch.nn.Parameter:
+    """
+    Builds an (anisotropic) Fourier feature based positional encoding from the given grid size, embedding
+    dimension, spatial dimensions, and scales. The scales control the variance of the Fourier features;
+    higher values make distant points more distinguishable.
+    Reference: https://arxiv.org/abs/2509.02488
+
+    Args:
+        grid_size (List[int]): The size of the grid in each spatial dimension.
+        embed_dim (int): The dimension of the embedding.
+        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
+        scales (List[float]): The scale for every spatial dimension. If a single float is provided,
+            the same scale is used for all dimensions.
+
+    Returns:
+        pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.
+    """
+
+    to_tuple = _ntuple(spatial_dims)
+    grid_size = to_tuple(grid_size)
+
+    scales = torch.tensor(scales)
+    if scales.ndim > 1 or (scales.ndim == 1 and len(scales) != spatial_dims):
+        raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")
+    if scales.ndim == 0:
+        scales = scales.repeat(spatial_dims)
+
+    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
+    gaussians = gaussians * scales
+
+    positions = [torch.linspace(0, 1, x) for x in grid_size]
+    positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), dim=-1)
+    positions = positions.flatten(end_dim=-2)

+    x_proj = (2.0 * torch.pi * positions) @ gaussians.T
+
+    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+    pos_emb = pos_emb[None, :, :]
+
+    return nn.Parameter(pos_emb, requires_grad=False)
+
+
 def build_sincos_position_embedding(
     grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
 ) -> torch.nn.Parameter:
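Because the Gaussian projection matrix is sampled once and then frozen, this is a random Fourier feature mapping: each normalized position p in [0, 1]^D is embedded as [sin(2*pi*p@B.T), cos(2*pi*p@B.T)], with the rows of B drawn from a normal distribution scaled per axis. A short sanity check of the returned parameter, assuming the function as defined above:

import torch
from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

torch.manual_seed(0)  # the mapping is random but fixed once built
pos = build_fourier_position_embedding(
    grid_size=(4, 4, 4), embed_dim=96, spatial_dims=3, scales=[1.0, 1.0, 4.0]
)
print(pos.shape)          # torch.Size([1, 64, 96]): (1, prod(grid_size), embed_dim)
print(pos.requires_grad)  # False: stored as a fixed (non-trainable) parameter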

tests/networks/blocks/test_patchembedding.py

Lines changed: 38 additions & 0 deletions

@@ -87,6 +87,19 @@ def test_sincos_pos_embed(self):

         self.assertEqual(net.position_embeddings.requires_grad, False)

+    def test_fourier_pos_embed(self):
+        net = PatchEmbeddingBlock(
+            in_channels=1,
+            img_size=(32, 32, 32),
+            patch_size=(8, 8, 8),
+            hidden_size=96,
+            num_heads=8,
+            pos_embed_type="fourier",
+            dropout_rate=0.5,
+        )
+
+        self.assertEqual(net.position_embeddings.requires_grad, False)
+
     def test_learnable_pos_embed(self):
         net = PatchEmbeddingBlock(
             in_channels=1,
@@ -101,6 +114,31 @@ def test_learnable_pos_embed(self):
         self.assertEqual(net.position_embeddings.requires_grad, True)

     def test_ill_arg(self):
+        with self.assertRaises(ValueError):
+            PatchEmbeddingBlock(
+                in_channels=1,
+                img_size=(128, 128, 128),
+                patch_size=(16, 16, 16),
+                hidden_size=128,
+                num_heads=12,
+                proj_type="conv",
+                dropout_rate=5.0,
+                pos_embed_type="fourier",
+                pos_embed_kwargs=dict(scales=[1.0, 1.0]),
+            )
+
+            PatchEmbeddingBlock(
+                in_channels=1,
+                img_size=(128, 128),
+                patch_size=(16, 16),
+                hidden_size=128,
+                num_heads=12,
+                proj_type="conv",
+                dropout_rate=5.0,
+                pos_embed_type="fourier",
+                pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]),
+            )
+
         with self.assertRaises(ValueError):
             PatchEmbeddingBlock(
                 in_channels=1,
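With the length check in pos_embed_utils.py above, a mismatched `scales` can also be triggered directly, outside the test suite (a sketch assuming a 3D grid; the grid size and embedding dimension here are illustrative):

from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding

try:
    # Two scales for a 3D grid: the list length must equal spatial_dims.
    build_fourier_position_embedding((4, 4, 4), embed_dim=96, spatial_dims=3, scales=[1.0, 1.0])
except ValueError as err:
    print(err)  # Scales must be either a float or a list of floats with length equal to spatial_dims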
