
Commit a7b133e

Merge pull request #456 from datalab-to/dev
Enable setting attention method
2 parents b0b1817 + ebf5ec7 commit a7b133e

File tree

18 files changed, +766 -556 lines changed


poetry.lock

Lines changed: 42 additions & 40 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.16.5"
+version = "0.16.6"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"
@@ -13,7 +13,7 @@ packages = [

 [tool.poetry.dependencies]
 python = "^3.10"
-transformers = ">=4.51.2,<4.54.0"
+transformers = ">=4.56.1"
 torch = "^2.7.0"
 pydantic = "^2.5.3"
 pydantic-settings = "^2.1.0"
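
The dependency change above drops the old transformers upper bound and raises the floor to 4.56.1. If you are upgrading an existing environment, a quick illustrative check (not part of this diff) that the new floor is met; it assumes `packaging` is importable, which it normally is since transformers itself depends on it:

    # Illustrative only: verify the installed transformers satisfies the new ">=4.56.1" floor.
    from importlib.metadata import version
    from packaging.version import Version

    assert Version(version("transformers")) >= Version("4.56.1"), "upgrade transformers to >=4.56.1"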

surya/common/adetr/decoder.py

Lines changed: 239 additions & 77 deletions
Large diffs are not rendered by default.

surya/common/donut/encoder.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,6 @@
 from torch import nn

 from transformers.activations import ACT2FN
-from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import (
     find_pruneable_heads_and_indices,
     meshgrid,
@@ -17,6 +16,7 @@
 from transformers.utils import ModelOutput
 from transformers import DonutSwinConfig

+from surya.common.pretrained import SuryaPreTrainedModel
 from surya.common.util import mark_step

 _EXPECTED_OUTPUT_SHAPE = [1, 49, 1024]
@@ -932,7 +932,7 @@ def forward(


 # Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin
-class DonutSwinPreTrainedModel(PreTrainedModel):
+class DonutSwinPreTrainedModel(SuryaPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
     models.
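
This base-class swap, repeated below for the decoder and encoder modules, is what routes every Surya model through the new attention-selection hook. A small illustrative sanity check (not part of the diff), assuming the package is importable:

    # Illustrative only: the Donut encoder base class now inherits the attention override.
    from surya.common.donut.encoder import DonutSwinPreTrainedModel
    from surya.common.pretrained import SuryaPreTrainedModel

    assert issubclass(DonutSwinPreTrainedModel, SuryaPreTrainedModel)
    assert "_check_and_adjust_attn_implementation" in vars(SuryaPreTrainedModel)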

surya/common/pretrained.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+from typing import Optional
+
+from transformers import PreTrainedModel
+from transformers.utils import is_flash_attn_2_available
+
+
+class SuryaPreTrainedModel(PreTrainedModel):
+    # No-op if we pass attention, so we can set attention however we want in the config
+    def _check_and_adjust_attn_implementation(
+        self, attn_implementation: Optional[str], **kwargs
+    ):
+        if attn_implementation is None:
+            try:
+                self._sdpa_can_dispatch(True)
+                attn_implementation = "sdpa"
+            except (ValueError, ImportError):
+                attn_implementation = "eager"
+
+            if self._supports_flash_attn and is_flash_attn_2_available():
+                attn_implementation = "flash_attention_2"
+
+        return attn_implementation
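
A standalone sketch (illustrative only, not part of the diff) of the selection order the override implements, with the transformers internals (`_sdpa_can_dispatch`, `_supports_flash_attn`, `is_flash_attn_2_available`) replaced by plain booleans: an explicit request is returned untouched, while `None` auto-selects sdpa, falls back to eager, and upgrades to flash-attn 2 when available.

    from typing import Optional

    def pick_attn_implementation(
        requested: Optional[str],
        sdpa_ok: bool,
        supports_flash: bool,
        flash_installed: bool,
    ) -> str:
        # An explicit request passes through unchanged, so the config stays in control.
        if requested is not None:
            return requested
        # Otherwise prefer sdpa, fall back to eager, then upgrade to flash-attn 2 when possible.
        chosen = "sdpa" if sdpa_ok else "eager"
        if supports_flash and flash_installed:
            chosen = "flash_attention_2"
        return chosen

    assert pick_attn_implementation("eager", True, True, True) == "eager"
    assert pick_attn_implementation(None, True, False, False) == "sdpa"
    assert pick_attn_implementation(None, False, True, True) == "flash_attention_2"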

surya/common/surya/__init__.py

Lines changed: 56 additions & 38 deletions
@@ -5,11 +5,11 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
-from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.cache_utils import Cache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter

+from surya.common.pretrained import SuryaPreTrainedModel
 from surya.common.s3 import S3DownloaderMixin
 from surya.common.surya.config import SuryaModelConfig
 from surya.common.surya.decoder import SuryaDecoderModel
@@ -56,6 +56,7 @@ class FlashAttentionKwargs(TypedDict, total=False):

 class KwargsForCausalLM(FlashAttentionKwargs): ...

+
 class DistanceProjection(nn.Module):
     def __init__(self, in_features: int, out_features: int):
         super().__init__()
@@ -75,7 +76,8 @@ def init_weights(self):
         nn.init.zeros_(self.fc1.bias)
         nn.init.zeros_(self.fc2.bias)

-class SuryaModel(S3DownloaderMixin, PreTrainedModel):
+
+class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
     config_class = SuryaModelConfig
     supports_gradient_checkpointing = True
     _skip_keys_device_placement = ["past_key_values"]
@@ -95,8 +97,9 @@ def __init__(
         embedder: SimpleTokenEmbedder = None,
         vision_encoder: SuryaEncoderModel = None,
         decoder: SuryaDecoderModel = None,
+        **kwargs,
     ):
-        super().__init__(config)
+        super().__init__(config, **kwargs)

         if vision_encoder is None:
             vision_encoder = SuryaEncoderModel(config.vision_encoder)
@@ -166,29 +169,30 @@ def maybe_static_pad_image_inputs(
         chunk_pixels: torch.Tensor,
         chunk_grid_thw: torch.Tensor,
         actual_chunk_len: int,
-        encoder_chunk_size: int
+        encoder_chunk_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        valid_embed_len = actual_chunk_len // (self.vision_encoder.spatial_merge_size ** 2)
+        valid_embed_len = actual_chunk_len // (
+            self.vision_encoder.spatial_merge_size**2
+        )
         if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size:
             padding_len = encoder_chunk_size - actual_chunk_len
             padding = torch.zeros(
-                padding_len,
+                padding_len,
                 *chunk_pixels.shape[1:],
                 device=chunk_pixels.device,
-                dtype=chunk_pixels.dtype
+                dtype=chunk_pixels.dtype,
             )
             chunk_pixels = torch.cat([chunk_pixels, padding], dim=0)
-
+
             padding_grid = torch.tensor(
                 [[1, 2, padding_len // 2]],
                 device=chunk_grid_thw.device,
-                dtype=chunk_grid_thw.dtype
+                dtype=chunk_grid_thw.dtype,
             )
             chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0)

         return chunk_pixels, chunk_grid_thw, valid_embed_len

-
     def get_image_embeddings(
         self,
         pixel_values: torch.Tensor,
@@ -225,15 +229,18 @@ def get_image_embeddings(
             end = chunks[i + 1]
             grid_start = grid_chunks[i]
             grid_end = grid_chunks[i + 1]
-
+
             chunk_pixels = pixel_values[start:end]
             chunk_grid_thw = grid_thw[grid_start:grid_end]
             actual_chunk_len = end - start
-            chunk_pixels, chunk_grid_thw, valid_embed_len = self.maybe_static_pad_image_inputs(chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size)
+            chunk_pixels, chunk_grid_thw, valid_embed_len = (
+                self.maybe_static_pad_image_inputs(
+                    chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size
+                )
+            )

             chunk_embeddings = self.vision_encoder.embed_images(
-                image_batch=chunk_pixels,
-                grid_thw=chunk_grid_thw
+                image_batch=chunk_pixels, grid_thw=chunk_grid_thw
             )
             embeddings.append(chunk_embeddings[:valid_embed_len])

@@ -340,28 +347,30 @@ def get_2d_learned_embeddings(
         )  # Shape is num_image_tokens x embed_dim

     def get_logits(self, hidden_states):
-        assert hidden_states.shape[1] == 1, "Multi output predictions only applied on the last token"
+        assert hidden_states.shape[1] == 1, (
+            "Multi output predictions only applied on the last token"
+        )

         all_lm_logits = []
         all_bbox_logits = []
-
+
         current_hidden = hidden_states
-
+
         # Loop includes initial prediction (i=0) plus multi_out_distance additional predictions
         for i in range(self.config.multi_output_distance + 1):
             if i > 0:
-                current_hidden = self.multi_output_projections[i-1](current_hidden)
-
+                current_hidden = self.multi_output_projections[i - 1](current_hidden)
+
             lm_logits = self.lm_head(current_hidden)
             bbox_logits = F.sigmoid(self.bbox_head(current_hidden))
-
+
             all_lm_logits.append(lm_logits)
             all_bbox_logits.append(bbox_logits)
-
+
         # Concatenate along sequence dimension (dim=1)
         final_lm_logits = torch.cat(all_lm_logits, dim=1)
         final_bbox_logits = torch.cat(all_bbox_logits, dim=1)
-
+
         return final_lm_logits, final_bbox_logits

     def forward(
@@ -387,24 +396,25 @@ def forward(
         **kwargs: KwargsForCausalLM,
     ):
         # Process the mixed batch if provided
-        if any([
-            input_ids is None,
-            (prefill and (image_tiles is None or grid_thw is None)),
-            position_ids is None,
-            cache_position is None
-        ]):
-            raise ValueError("`input_ids`, `position_ids`, and `cache_position` **must** be specified. `image_tiles` and `grid_thw` are required for prefill")
+        if any(
+            [
+                input_ids is None,
+                (prefill and (image_tiles is None or grid_thw is None)),
+                position_ids is None,
+                cache_position is None,
+            ]
+        ):
+            raise ValueError(
+                "`input_ids`, `position_ids`, and `cache_position` **must** be specified. `image_tiles` and `grid_thw` are required for prefill"
+            )

         inputs_embeds = self.embed_ids_boxes_images(
             input_ids, image_tiles, grid_thw, encoder_chunk_size
         )

         # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder
         # Skipped during decoding since not required
-        if (
-            self.decoder.config._attn_implementation == "flash_attention_2"
-            and prefill
-        ):
+        if self.decoder.config._attn_implementation == "flash_attention_2" and prefill:
             batch_size, query_length, _ = inputs_embeds.shape
             indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
                 attention_mask
@@ -451,7 +461,9 @@ def forward(
             bbox_logits = None
             vocab_size = lm_logits.shape[-1]
             labels = torch.roll(labels, shifts=-1, dims=-1)
-            loss = F.cross_entropy(lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean")
+            loss = F.cross_entropy(
+                lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean"
+            )
         else:
             lm_logits, bbox_logits = self.get_logits(hidden_states)

@@ -561,9 +573,15 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 device=device,
             )
             # Batch-aware diagonal attend mask
-            diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(0) > cache_position.unsqueeze(-1)
-            causal_mask = causal_mask.unsqueeze(0) * diagonal_attend_mask  # (batch_size, seq_len, target_len)
-            causal_mask = causal_mask[:, None, :, :]  # (batch_size, 1, seq_len, target_len)
+            diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(
+                0
+            ) > cache_position.unsqueeze(-1)
+            causal_mask = (
+                causal_mask.unsqueeze(0) * diagonal_attend_mask
+            )  # (batch_size, seq_len, target_len)
+            causal_mask = causal_mask[
+                :, None, :, :
+            ]  # (batch_size, 1, seq_len, target_len)
             if attention_mask is not None:
                 causal_mask = (
                     causal_mask.clone()
@@ -578,4 +596,4 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 causal_mask[:, :, :, :mask_length] = causal_mask[
                     :, :, :, :mask_length
                 ].masked_fill(padding_mask, min_dtype)
-        return causal_mask
+        return causal_mask
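
The batch-aware mask construction in the last two hunks is easier to follow outside diff form. A minimal runnable sketch (illustrative only, with assumed toy shapes: batch of 2, 3 query positions, target length 6) of the same broadcasting:

    import torch

    batch_size, seq_len, target_len = 2, 3, 6
    min_dtype = torch.finfo(torch.float32).min

    # Start from a fully blocked (seq_len, target_len) mask, as the surrounding code does.
    causal_mask = torch.full((seq_len, target_len), min_dtype)

    # Per-sample cache positions (assumed shape: batch_size x seq_len).
    cache_position = torch.tensor([[0, 1, 2], [2, 3, 4]])

    # Positions beyond each sample's cache position stay blocked; the rest become 0 (attendable).
    diagonal_attend_mask = torch.arange(target_len).unsqueeze(0) > cache_position.unsqueeze(-1)
    mask = causal_mask.unsqueeze(0) * diagonal_attend_mask  # (batch_size, seq_len, target_len)
    mask = mask[:, None, :, :]                              # (batch_size, 1, seq_len, target_len)

    assert mask.shape == (2, 1, 3, 6)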

surya/common/surya/decoder/__init__.py

Lines changed: 13 additions & 15 deletions
@@ -12,11 +12,13 @@
     BaseModelOutputWithPast,
 )
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.processing_utils import Unpack
 from transformers.utils import (
     logging,
 )
+
+from surya.common.pretrained import SuryaPreTrainedModel
 from surya.common.surya.decoder.config import SuryaDecoderConfig

 from transformers.utils import is_flash_attn_2_available
@@ -180,15 +182,15 @@ def forward(

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism
+            # cache_idxs, num_valid_tokens, and prefill add support for our new caching mechanism
             cache_kwargs = {
                 "sin": sin,
                 "cos": cos,
                 "cache_position": cache_position,
                 "cache_idxs": cache_idxs,
                 "num_valid_tokens": num_valid_tokens,
                 "prefill": prefill,
-                "text_lengths": text_lengths
+                "text_lengths": text_lengths,
             }
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx, cache_kwargs
@@ -406,7 +408,7 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-class Qwen2PreTrainedModel(PreTrainedModel):
+class Qwen2PreTrainedModel(SuryaPreTrainedModel):
     config_class = SuryaDecoderConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
@@ -482,22 +484,18 @@ def forward(
         )

         if inputs_embeds is None:
-            raise ValueError(
-                "You must specify inputs_embeds"
-            )
+            raise ValueError("You must specify inputs_embeds")

         if cache_position is None:
-            raise ValueError(
-                "You must specify cache_position"
-            )
+            raise ValueError("You must specify cache_position")

         if position_ids is None:
-            raise ValueError(
-                "You must specify position_ids"
-            )
+            raise ValueError("You must specify position_ids")

         hidden_states = inputs_embeds
-        causal_mask = attention_mask  # We make the 4D mask in the combined model when needed
+        causal_mask = (
+            attention_mask  # We make the 4D mask in the combined model when needed
+        )

         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
@@ -528,4 +526,4 @@ def forward(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
         )
-        return output if return_dict else output.to_tuple()
+        return output if return_dict else output.to_tuple()

surya/common/surya/encoder/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.utils import is_flash_attn_2_available

+from surya.common.pretrained import SuryaPreTrainedModel
 from surya.common.surya.encoder.config import SuryaEncoderConfig
 from surya.logging import get_logger

@@ -472,7 +472,7 @@ def forward(
     """


-class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
+class Qwen2_5_VLPreTrainedModel(SuryaPreTrainedModel):
     config_class = SuryaEncoderConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
