@@ -820,9 +820,8 @@ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
 
         return rel_pos_resized[relative_coords.long()]
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: torch.Tensor,
         query: torch.Tensor,
         rel_pos_h: torch.Tensor,
         rel_pos_w: torch.Tensor,
@@ -834,8 +833,6 @@ def add_decomposed_rel_pos(
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`torch.Tensor`):
-                attention map.
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`torch.Tensor`):
@@ -848,8 +845,8 @@ def add_decomposed_rel_pos(
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`torch.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
@@ -860,10 +857,10 @@ def add_decomposed_rel_pos(
         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
         rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
-        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
-        return attn
+
+        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+        return decomposed_rel_pos
 
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
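For reference, the renamed `get_decomposed_rel_pos` now only builds the relative-position term and leaves it to the caller to add it to the attention logits. A minimal standalone sketch of that computation, with toy, made-up shapes (illustrative only, not part of the change):

```python
import torch

# Toy, made-up sizes: batch of 2, a 4x4 query/key grid, 8 channels per head.
batch_size, q_h, q_w, k_h, k_w, dim = 2, 4, 4, 4, 4, 8

# Assumed inputs: per-axis relative-position tables as get_rel_pos would return them,
# i.e. shapes (q_h, k_h, dim) and (q_w, k_w, dim).
rel_pos_h = torch.randn(q_h, k_h, dim)
rel_pos_w = torch.randn(q_w, k_w, dim)
query = torch.randn(batch_size, q_h * q_w, dim)

reshaped_query = query.reshape(batch_size, q_h, q_w, dim)
# Project the query onto each axis's relative-position table.
rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, rel_pos_h)  # (B, q_h, q_w, k_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, rel_pos_w)  # (B, q_h, q_w, k_w)

# Broadcast-add the two axes into the full 5-D bias.
decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
print(decomposed_rel_pos.shape)  # torch.Size([2, 4, 4, 4, 4])
```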
@@ -879,9 +876,11 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
 
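In the eager path the 5-D relative-position term is now flattened with `reshape_as` before being added to the logits, which gives the same values as the old reshape-add-reshape of the attention map. A quick sketch of that equivalence with toy, made-up shapes:

```python
import torch

# Toy, made-up shapes matching the eager forward: flat attention logits plus the 5-D bias.
batch_size, q_h, q_w, k_h, k_w = 2, 4, 4, 4, 4
attn_weights = torch.randn(batch_size, q_h * q_w, k_h * k_w)
decomposed_rel_pos = torch.randn(batch_size, q_h, q_w, k_h, k_w)

# Previous formulation: reshape the attention map to 5-D, add, reshape back.
old = attn_weights.reshape(batch_size, q_h, q_w, k_h, k_w) + decomposed_rel_pos
old = old.reshape(batch_size, q_h * q_w, k_h * k_w)

# New formulation: flatten the bias to the attention map's layout instead.
new = attn_weights + decomposed_rel_pos.reshape_as(attn_weights)

print(torch.equal(old, new))  # True: only the location of the reshape changes
```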
@@ -909,47 +908,19 @@ class SamVisionSdpaAttention(SamVisionAttention):
     def __init__(self, config, window_size):
         super().__init__(config, window_size)
 
-    def add_decomposed_rel_pos(
-        self,
-        query: torch.Tensor,
-        rel_pos_h: torch.Tensor,
-        rel_pos_w: torch.Tensor,
-        q_size: Tuple[int, int],
-        k_size: Tuple[int, int],
-    ) -> torch.Tensor:
-        """
-        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
-        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
-        This method is reimplemented to follow the implementation in:
-        https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
-        This implementation is more memory efficient when using SDPA in the forward method.
-        Args:
-            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
-            rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
-            rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
-            q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
-            k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
-
-        Returns:
-            attn (Tensor): attention map with added relative positional embeddings.
-        """
-        query_height, query_width = q_size
-        key_height, key_width = k_size
-        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
-        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
-
-        batch_size, _, dim = query.shape
-        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
-        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
-        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        rel_h = rel_h.unsqueeze(-1)
-        rel_w = rel_w.unsqueeze(-2)
-        rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
-        rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)
-
-        return rel_h, rel_w
-
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+        if output_attentions:
+            logger.warning_once(
+                "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "`output_attentions=True`. Falling back to the manual attention implementation, but "
+                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                output_attentions=output_attentions,
+            )
+
         batch_size, height, width, _ = hidden_states.shape
         # qkv with shape (3, B, nHead, H * W, C)
         qkv = (
@@ -960,25 +931,21 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         # q, k, v with shape (B * nHead, H * W, C)
         query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
 
-        rel_h, rel_w = None, None
+        attn_bias = None
         if self.use_rel_pos:
-            rel_h, rel_w = self.add_decomposed_rel_pos(
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
                 query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape(
+                batch_size, self.num_attention_heads, height * width, height * width
+            )
+            attn_bias = decomposed_rel_pos
 
         query = query.view(batch_size, self.num_attention_heads, height * width, -1)
         key = key.view(batch_size, self.num_attention_heads, height * width, -1)
         value = value.view(batch_size, self.num_attention_heads, height * width, -1)
 
-        if self.use_rel_pos:
-            rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
-            rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
-            attn_bias = (rel_h + rel_w).view(
-                batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
-            )
-            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
-        else:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
 
         attn_output = (
             attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
@@ -988,17 +955,7 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
 
         attn_output = self.proj(attn_output)
 
-        if output_attentions:
-            # For output_attentions, calculate the attention weights
-            attn_weights = (query @ key.transpose(-2, -1)) * self.scale
-            if attn_bias is not None:
-                attn_weights = attn_weights + attn_bias
-            attn_weights = F.softmax(attn_weights, dim=-1)
-            outputs = (attn_output, attn_weights)
-        else:
-            outputs = (attn_output, None)
-
-        return outputs
+        return attn_output, None
 
 
 SAM_VISION_ATTENTION_CLASSES = {
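Passing the reshaped relative-position term as a floating-point `attn_mask` relies on `torch.nn.functional.scaled_dot_product_attention` adding the mask to the scaled logits before the softmax, so it matches the eager path. A rough sanity-check sketch with toy, made-up shapes (illustrative only):

```python
import torch
import torch.nn.functional as F

# Toy, made-up shapes: (batch, heads, seq, head_dim) plus an additive bias per head.
batch, heads, seq, head_dim = 1, 2, 16, 8
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)
attn_bias = torch.randn(batch, heads, seq, seq)  # stands in for the reshaped rel-pos term

# Eager reference: scale the logits, add the bias, softmax, weight the values.
scale = head_dim**-0.5
attn_weights = (query * scale) @ key.transpose(-2, -1) + attn_bias
eager_out = torch.softmax(attn_weights, dim=-1) @ value

# SDPA: a floating-point attn_mask is added to the scaled logits before the softmax.
sdpa_out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True (up to numerics)
```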