@@ -34,7 +34,7 @@ def parse(x):
3434
3535def build_fourier_position_embedding (
3636 grid_size : Union [int , List [int ]], embed_dim : int , spatial_dims : int = 3 , scales : Union [float , List [float ]] = 1.0
37- ):
37+ ) -> torch . nn . Parameter :
3838 """
3939 Builds a (Anistropic) Fourier Feature based positional encoding based on the given grid size, embed dimension,
4040 spatial dimensions, and scales. The scales control the variance of the Fourier features, higher values make distant
@@ -55,7 +55,12 @@ def build_fourier_position_embedding(
5555 to_tuple = _ntuple (spatial_dims )
5656 grid_size = to_tuple (grid_size )
5757
58- scales = torch .tensor (scales )
58+ if embed_dim % (2 * spatial_dims ) != 0 :
59+ raise AssertionError (
60+ f"Embed dimension must be divisible by { 2 * spatial_dims } for { spatial_dims } D Fourier feature position embedding"
61+ )
62+
63+ scales : torch .Tensor = torch .as_tensor (scales , dtype = torch .float )
5964 if scales .ndim > 1 and scales .ndim != spatial_dims :
6065 raise ValueError ("Scales must be either a float or a list of floats with length equal to spatial_dims" )
6166 if scales .ndim == 0 :
@@ -65,15 +70,15 @@ def build_fourier_position_embedding(
6570 gaussians = gaussians * scales
6671
6772 positions = [torch .linspace (0 , 1 , x ) for x in grid_size ]
68- positions = torch .stack (torch .meshgrid (* positions , indexing = "ij" ), axis = - 1 )
73+ positions = torch .stack (torch .meshgrid (* positions , indexing = "ij" ), dim = - 1 )
6974 positions = positions .flatten (end_dim = - 2 )
7075
7176 x_proj = (2.0 * torch .pi * positions ) @ gaussians .T
7277
73- pos_emb = torch .cat ([torch .sin (x_proj ), torch .cos (x_proj )], axis = - 1 )
74- pos_emb = pos_emb [None , :, :]
78+ pos_emb = torch .cat ([torch .sin (x_proj ), torch .cos (x_proj )], dim = - 1 )
79+ pos_emb = nn . Parameter ( pos_emb [None , :, :], requires_grad = False )
7580
76- return nn . Parameter ( pos_emb , requires_grad = False )
81+ return pos_emb
7782
7883
7984def build_sincos_position_embedding (
0 commit comments