
Commit c53d53d

🚨🚨🚨 Fix sdpa in SAM and refactor relative position embeddings (#36422)
* fall back to eager if output_attentions
* improve relative position embeddings
* run modular on got_ocr2
* run-slow: sam
* fix run-length encoding
* fix tf processor errors
* update tf_sam
* fix compile error
* re-run tests
1 parent fc8764c commit c53d53d
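
The core of the refactor: `add_decomposed_rel_pos`, which reshaped and mutated the attention map in place, becomes `get_decomposed_rel_pos`, which returns the relative-position term as a standalone bias. Below is a minimal sketch of the idea using simplified, standalone functions; the shapes and helper setup are illustrative assumptions, not the actual `SamVisionAttention` API.

import torch

# Sketch only: a standalone stand-in, not the actual transformers method.
def get_decomposed_rel_pos(query, rel_pos_h, rel_pos_w, q_size, k_size):
    # query: (B, q_h * q_w, C); rel_pos_h: (q_h, k_h, C); rel_pos_w: (q_w, k_w, C)
    # (in the real model these per-axis tables come from `get_rel_pos`)
    q_h, q_w = q_size
    k_h, k_w = k_size
    batch_size, _, dim = query.shape
    reshaped_query = query.reshape(batch_size, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, rel_pos_h)  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, rel_pos_w)  # (B, q_h, q_w, k_w)
    # Return a standalone bias instead of mutating the attention map in place
    return rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]  # (B, q_h, q_w, k_h, k_w)

batch, heads, height, width, dim = 2, 4, 8, 8, 16
query = torch.randn(batch * heads, height * width, dim)
rel_pos_h = torch.randn(height, height, dim)
rel_pos_w = torch.randn(width, width, dim)
bias = get_decomposed_rel_pos(query, rel_pos_h, rel_pos_w, (height, width), (height, width))

# Eager path: reshape the bias to the attention-weight shape and add it before softmax
attn_weights = torch.randn(batch * heads, height * width, height * width)
attn_weights = attn_weights + bias.reshape_as(attn_weights)

# SDPA path: the same bias, split per head, becomes the additive `attn_mask`
attn_bias = bias.reshape(batch, heads, height * width, height * width)
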

File tree

6 files changed: +62 -103 lines changed


src/transformers/models/got_ocr2/modeling_got_ocr2.py

Lines changed: 11 additions & 12 deletions
@@ -114,9 +114,8 @@ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
 
         return rel_pos_resized[relative_coords.long()]
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: torch.Tensor,
         query: torch.Tensor,
         rel_pos_h: torch.Tensor,
         rel_pos_w: torch.Tensor,
@@ -128,8 +127,6 @@ def add_decomposed_rel_pos(
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`torch.Tensor`):
-                attention map.
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`torch.Tensor`):
@@ -142,8 +139,8 @@ def add_decomposed_rel_pos(
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`torch.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
@@ -154,10 +151,10 @@ def add_decomposed_rel_pos(
         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
         rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
-        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
-        return attn
+
+        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+        return decomposed_rel_pos
 
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
@@ -173,9 +170,11 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)

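Note that the eager path here (and the matching one in `modeling_sam.py` below) is only restructured, not changed numerically: adding the flattened bias to the untouched attention map gives the same result as the old reshape/add/reshape sequence. A quick sanity check with random tensors; the shapes are chosen arbitrarily for illustration.

import torch

batch, q_h, q_w, k_h, k_w = 3, 4, 5, 4, 5
attn = torch.randn(batch, q_h * q_w, k_h * k_w)
rel_h = torch.randn(batch, q_h, q_w, k_h)
rel_w = torch.randn(batch, q_h, q_w, k_w)

# Old behaviour: reshape the attention map, add the two terms, flatten back
old = attn.reshape(batch, q_h, q_w, k_h, k_w)
old = old + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
old = old.reshape(batch, q_h * q_w, k_h * k_w)

# New behaviour: build the bias once, then add it to the untouched attention map
decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
new = attn + decomposed_rel_pos.reshape_as(attn)

assert torch.allclose(old, new)
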
src/transformers/models/sam/image_processing_sam.py

Lines changed: 5 additions & 3 deletions
@@ -1381,7 +1381,7 @@ def _mask_to_rle_pytorch(input_mask: "torch.Tensor"):
             continue
         btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
         counts = [] if input_mask[i, 0] == 0 else [0]
-        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
+        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1].item()]
         out.append({"size": [height, width], "counts": counts})
     return out
 
@@ -1401,7 +1401,7 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
     # Encode run length
     out = []
     for i in range(batch_size):
-        cur_idxs = change_indices[change_indices[:, 0] == i, 1] + 1
+        cur_idxs = change_indices[change_indices[:, 0] == i][:, 1] + 1
         if len(cur_idxs) == 0:
             # No changes => either all 0 or all 1
             # If the entire mask is 0, RLE is [height*width] or if the entire mask is 1, RLE is [0, height*width].
@@ -1412,7 +1412,9 @@ def _mask_to_rle_tf(input_mask: "tf.Tensor"):
             continue
         btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
         counts = [] if input_mask[i, 0] == 0 else [0]
-        counts += [cur_idxs[0].item()] + btw_idxs.tolist() + [height * width - cur_idxs[-1]]
+        counts += (
+            [cur_idxs[0].numpy().item()] + btw_idxs.numpy().tolist() + [height * width - cur_idxs[-1].numpy().item()]
+        )
         out.append({"size": [height, width], "counts": counts})
     return out

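Both RLE fixes are about keeping `counts` free of tensor objects: the last run length used to be `height * width - cur_idxs[-1]`, a 0-d tensor rather than a plain int, and the TF variant now goes through `.numpy()` explicitly. For reference, a minimal standalone sketch of the same column-major (Fortran-order) run-length encoding, reproducing the processor test case at the bottom of this commit; this is an illustrative reimplementation, not the library function.

import torch

def mask_to_rle(mask: torch.Tensor) -> dict:
    # mask: (H, W) binary tensor; RLE counts always start with the run of zeros (possibly 0).
    height, width = mask.shape
    flat = mask.permute(1, 0).flatten()                       # Fortran order, as in the SAM processor
    change_idxs = torch.where(flat[1:] != flat[:-1])[0] + 1   # positions where the value flips
    counts = [] if flat[0] == 0 else [0]
    prev = 0
    for idx in change_idxs.tolist():                          # .tolist() keeps counts as plain ints
        counts.append(idx - prev)
        prev = idx
    counts.append(height * width - prev)                      # plain int, no stray 0-d tensor
    return {"size": [height, width], "counts": counts}

mask = torch.tensor([[0, 1], [1, 1]])
print(mask_to_rle(mask))  # {'size': [2, 2], 'counts': [1, 3]}
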
src/transformers/models/sam/modeling_sam.py

Lines changed: 31 additions & 74 deletions
@@ -820,9 +820,8 @@ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
 
         return rel_pos_resized[relative_coords.long()]
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: torch.Tensor,
         query: torch.Tensor,
         rel_pos_h: torch.Tensor,
         rel_pos_w: torch.Tensor,
@@ -834,8 +833,6 @@ def add_decomposed_rel_pos(
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`torch.Tensor`):
-                attention map.
             query (`torch.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`torch.Tensor`):
@@ -848,8 +845,8 @@ def add_decomposed_rel_pos(
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`torch.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
@@ -860,10 +857,10 @@ def add_decomposed_rel_pos(
         reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
         rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
-        attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
-        attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
-        return attn
+
+        decomposed_rel_pos = rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+
+        return decomposed_rel_pos
 
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         batch_size, height, width, _ = hidden_states.shape
@@ -879,9 +876,11 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         attn_weights = (query * self.scale) @ key.transpose(-2, -1)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape_as(attn_weights)
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
 
@@ -909,47 +908,19 @@ class SamVisionSdpaAttention(SamVisionAttention):
     def __init__(self, config, window_size):
         super().__init__(config, window_size)
 
-    def add_decomposed_rel_pos(
-        self,
-        query: torch.Tensor,
-        rel_pos_h: torch.Tensor,
-        rel_pos_w: torch.Tensor,
-        q_size: Tuple[int, int],
-        k_size: Tuple[int, int],
-    ) -> torch.Tensor:
-        """
-        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
-        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
-        This method is reimplemented to follow the implementation in:
-        https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/modeling/image_encoder.py # noqa B950
-        This implementation is more memory efficient when using SDPA in the forward method.
-        Args:
-            q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
-            rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
-            rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
-            q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
-            k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
-
-        Returns:
-            attn (Tensor): attention map with added relative positional embeddings.
-        """
-        query_height, query_width = q_size
-        key_height, key_width = k_size
-        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
-        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
-
-        batch_size, _, dim = query.shape
-        reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
-        rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
-        rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        rel_h = rel_h.unsqueeze(-1)
-        rel_w = rel_w.unsqueeze(-2)
-        rel_h = rel_h.reshape(batch_size, query_height * query_width, key_height, 1)
-        rel_w = rel_w.reshape(batch_size, query_height * query_width, 1, key_width)
-
-        return rel_h, rel_w
-
     def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
+        if output_attentions:
+            logger.warning_once(
+                "`SamVisionSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "`output_attentions=True`. Falling back to the manual attention implementation, but "
+                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                output_attentions=output_attentions,
+            )
+
         batch_size, height, width, _ = hidden_states.shape
         # qkv with shape (3, B, nHead, H * W, C)
         qkv = (
@@ -960,25 +931,21 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
         # q, k, v with shape (B * nHead, H * W, C)
         query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
 
-        rel_h, rel_w = None, None
+        attn_bias = None
         if self.use_rel_pos:
-            rel_h, rel_w = self.add_decomposed_rel_pos(
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
                 query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = decomposed_rel_pos.reshape(
+                batch_size, self.num_attention_heads, height * width, height * width
+            )
+            attn_bias = decomposed_rel_pos
 
         query = query.view(batch_size, self.num_attention_heads, height * width, -1)
         key = key.view(batch_size, self.num_attention_heads, height * width, -1)
         value = value.view(batch_size, self.num_attention_heads, height * width, -1)
 
-        if self.use_rel_pos:
-            rel_h = rel_h.view(batch_size, self.num_attention_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
-            rel_w = rel_w.view(batch_size, self.num_attention_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
-            attn_bias = (rel_h + rel_w).view(
-                batch_size, self.num_attention_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
-            )
-            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
-        else:
-            attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)
 
         attn_output = (
             attn_output.view(batch_size, self.num_attention_heads, height, width, -1)
@@ -988,17 +955,7 @@ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
 
         attn_output = self.proj(attn_output)
 
-        if output_attentions:
-            # For output_attentions, calculate the attention weights
-            attn_weights = (query @ key.transpose(-2, -1)) * self.scale
-            if attn_bias is not None:
-                attn_weights = attn_weights + attn_bias
-            attn_weights = F.softmax(attn_weights, dim=-1)
-            outputs = (attn_output, attn_weights)
-        else:
-            outputs = (attn_output, None)
-
-        return outputs
+        return attn_output, None
 
 
 SAM_VISION_ATTENTION_CLASSES = {

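With the relative-position term available as a plain additive tensor, the SDPA branch can hand it to `scaled_dot_product_attention` as `attn_mask` (simply `None` when `use_rel_pos` is off) and, when `output_attentions=True`, fall back to the eager parent class instead of recomputing weights by hand. A small sketch showing that an additive float `attn_mask` matches the eager formulation; the shapes are illustrative, not the SAM modules:

import torch
import torch.nn.functional as F

batch, heads, seq, dim = 2, 4, 16, 8
query = torch.randn(batch, heads, seq, dim)
key = torch.randn(batch, heads, seq, dim)
value = torch.randn(batch, heads, seq, dim)
attn_bias = torch.randn(batch, heads, seq, seq)  # e.g. the flattened decomposed rel-pos term

# SDPA adds a float mask to the scaled scores before softmax
sdpa_out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias)

# Manual (eager-style) reference computation
attn_weights = (query @ key.transpose(-2, -1)) / (dim ** 0.5) + attn_bias
manual_out = F.softmax(attn_weights, dim=-1) @ value

assert torch.allclose(sdpa_out, manual_out, atol=1e-5)
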
src/transformers/models/sam/modeling_tf_sam.py

Lines changed: 13 additions & 12 deletions
@@ -982,9 +982,8 @@ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
 
         return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
 
-    def add_decomposed_rel_pos(
+    def get_decomposed_rel_pos(
         self,
-        attn: tf.Tensor,
         query: tf.Tensor,
         rel_pos_h: tf.Tensor,
         rel_pos_w: tf.Tensor,
@@ -996,8 +995,6 @@ def add_decomposed_rel_pos(
         https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
 
         Args:
-            attn (`tf.Tensor`):
-                attention map.
             query (`tf.Tensor`):
                 query q in the attention layer with shape (batch_size, query_height * query_width, channel).
             rel_pos_h (`tf.Tensor`):
@@ -1010,8 +1007,8 @@ def add_decomposed_rel_pos(
                 spatial sequence size of key k with (key_height, key_width).
 
         Returns:
-            attn (`tf.Tensor`):
-                attention map with added relative positional embeddings.
+            decomposed_rel_pos (`torch.Tensor`):
+                decomposed relative position embeddings.
         """
         query_height, query_width = q_size
         key_height, key_width = k_size
@@ -1022,10 +1019,12 @@ def add_decomposed_rel_pos(
         reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
         rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
         rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
-        attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
-        attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
-        attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
-        return attn
+
+        rel_h = tf.expand_dims(rel_h, axis=-1)
+        rel_w = tf.expand_dims(rel_w, axis=-2)
+        decomposed_rel_pos = rel_h + rel_w
+
+        return decomposed_rel_pos
 
     def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
         batch_size, height, width, _ = shape_list(hidden_states)
@@ -1039,9 +1038,11 @@ def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
         attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
 
         if self.use_rel_pos:
-            attn_weights = self.add_decomposed_rel_pos(
-                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
+            decomposed_rel_pos = self.get_decomposed_rel_pos(
+                query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
             )
+            decomposed_rel_pos = tf.reshape(decomposed_rel_pos, shape_list(attn_weights))
+            attn_weights = attn_weights + decomposed_rel_pos
 
         attn_weights = tf.nn.softmax(attn_weights, axis=-1)

src/transformers/processing_utils.py

Lines changed: 1 addition & 1 deletion
@@ -979,7 +979,7 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs
                     kwarg_value = kwargs.get(modality_key, "__empty__")
                 else:
                     kwarg_value = "__empty__"
-                if kwarg_value != "__empty__":
+                if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
                     output_kwargs[modality][modality_key] = kwarg_value
                     used_keys.add(modality_key)

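The `processing_utils.py` tweak guards the `"__empty__"` sentinel check: comparing a non-string kwarg (an array or tensor, as in the tf processor errors mentioned in the commit message) against a string with `!=` does not give a reliable single boolean, so the condition now short-circuits on `isinstance`. A minimal illustration of the fixed logic; the helper name is mine, not the library's:

import numpy as np

def should_keep(kwarg_value) -> bool:
    # Mirrors the fixed condition: only the string sentinel "__empty__" is dropped.
    # Non-string values (arrays, tensors, lists, ...) short-circuit on isinstance
    # and never reach the `!=` comparison against a string.
    return not isinstance(kwarg_value, str) or kwarg_value != "__empty__"

print(should_keep("__empty__"))        # False -> kwarg was never passed, drop it
print(should_keep("longest"))          # True  -> ordinary string kwarg is kept
print(should_keep(np.zeros((2, 2))))   # True  -> array kwargs are kept without string comparison
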
tests/models/sam/test_processor_sam.py

Lines changed: 1 addition & 1 deletion
@@ -312,7 +312,7 @@ def test_rle_encoding(self):
         # This is shape (1, 2, 2).
         # Flattened in Fortran order -> [0, 1, 1, 1].
         # The RLE for [0,1,1,1] is [1, 3].
-        input_mask = tf.tensor([[[0, 1], [1, 1]]], dtype=tf.int64)
+        input_mask = tf.constant([[[0, 1], [1, 1]]], dtype=tf.int64)
         rle = _mask_to_rle_tf(input_mask)
 
         self.assertEqual(len(rle), 1)
