Commit a7ffa7e

Merge pull request #457 from datalab-to/dev
Move flash attention funcs
2 parents a7b133e + 5234bc0 commit a7ffa7e

File tree

5 files changed: +19 -24 lines


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.16.6"
+version = "0.16.7"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri <[email protected]>"]
 readme = "README.md"

surya/common/surya/__init__.py

Lines changed: 3 additions & 5 deletions
@@ -17,13 +17,8 @@
 from surya.common.surya.encoder import SuryaEncoderModel
 from surya.settings import settings

-from transformers.utils import is_flash_attn_2_available
-
 from surya.logging import get_logger

-if is_flash_attn_2_available():
-    from surya.common.surya.flash_attn_utils import _get_unpad_data
-
 logger = get_logger()


@@ -415,6 +410,9 @@ def forward(
         # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder
         # Skipped during decoding since not required
         if self.decoder.config._attn_implementation == "flash_attention_2" and prefill:
+            # Needed for CPU -> GPU
+            from surya.common.surya.flash_attn_utils import _get_unpad_data
+
             batch_size, query_length, _ = inputs_embeds.shape
             indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
                 attention_mask
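
The hunk above is the pattern this PR applies in every file: the flash_attn helper is imported inside forward(), only when the flash_attention_2 implementation is active during prefill, instead of conditionally at module import time. The "Needed for CPU -> GPU" comment presumably refers to is_flash_attn_2_available() also checking whether CUDA is visible when it runs, so an import-time guard can come up False for a model that is loaded on CPU and moved to the GPU later. A minimal sketch of the idea, using a hypothetical helper name and simplified signature (not code from the PR):

import torch


def _flash_attn_unpad_kwargs(
    attention_mask: torch.Tensor, attn_implementation: str, prefill: bool
) -> dict:
    """Build flash-attention unpadding kwargs only when they are actually needed."""
    if attn_implementation != "flash_attention_2" or not prefill:
        return {}
    # Deferred import: evaluated only on the flash-attention path, so a
    # CPU-only environment never has to import flash_attn at all.
    from surya.common.surya.flash_attn_utils import _get_unpad_data

    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    return {
        "indices_k": indices_k,
        "cu_seqlens_k": cu_seqlens_k,
        "max_seqlen_in_batch_k": max_seqlen_in_batch_k,
    }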

surya/common/surya/decoder/__init__.py

Lines changed: 6 additions & 7 deletions
@@ -21,13 +21,6 @@
 from surya.common.pretrained import SuryaPreTrainedModel
 from surya.common.surya.decoder.config import SuryaDecoderConfig

-from transformers.utils import is_flash_attn_2_available
-
-if is_flash_attn_2_available():
-    from surya.common.surya.flash_attn_utils import (
-        flash_attn_decode,
-        flash_attn_prefill,
-    )

 logger = logging.get_logger(__name__)


@@ -206,6 +199,12 @@ def forward(
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            elif self.config._attn_implementation == "flash_attention_2":
+                # Needed for CPU -> GPU
+                from surya.common.surya.flash_attn_utils import (
+                    flash_attn_decode,
+                    flash_attn_prefill,
+                )
+
                if prefill:
                    attention_interface = flash_attn_prefill
                else:
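
The decoder applies the same deferral: flash_attn_prefill and flash_attn_decode are imported only once the flash_attention_2 branch of forward() is reached. A hedged sketch of that selection logic, pulled out into a hypothetical free function (the diff truncates the else branch, which presumably assigns flash_attn_decode):

from typing import Callable, Optional


def _select_attention_interface(config, prefill: bool) -> Optional[Callable]:
    """Pick the attention kernel the way the decoder does after this change."""
    if config._attn_implementation == "flash_attention_2":
        # Needed for CPU -> GPU: flash_attn is only touched when this branch runs.
        from surya.common.surya.flash_attn_utils import (
            flash_attn_decode,
            flash_attn_prefill,
        )

        return flash_attn_prefill if prefill else flash_attn_decode
    # Eager / SDPA implementations are handled by the surrounding code.
    return None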

surya/common/surya/encoder/__init__.py

Lines changed: 4 additions & 6 deletions
@@ -5,17 +5,11 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers.activations import ACT2FN
-from transformers.utils import is_flash_attn_2_available

 from surya.common.pretrained import SuryaPreTrainedModel
 from surya.common.surya.encoder.config import SuryaEncoderConfig
 from surya.logging import get_logger

-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb  # noqa
-
-
 logger = get_logger()


@@ -127,6 +121,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def apply_rotary_pos_emb_flashatt(
     q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
+    from flash_attn.layers.rotary import apply_rotary_emb
+
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()
     q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)

@@ -148,6 +144,8 @@ def forward(
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
+        from flash_attn import flash_attn_varlen_func
+
         bsz = hidden_states.shape[0]
         seq_length = hidden_states.shape[1]
         q, k, v = (
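
A side note on the design choice: because Python caches imported modules in sys.modules, repeating a from-import inside a hot forward() costs roughly a dict lookup after the first call, not a re-import. A rough stand-in measurement (math.sqrt standing in for flash_attn_varlen_func; numbers are machine-dependent):

import sys
import timeit


def repeated_import():
    # Same shape as the deferred import inside the encoder's forward().
    from math import sqrt
    return sqrt


assert "math" in sys.modules  # cached after the first import anywhere in the process
total = timeit.timeit(repeated_import, number=1_000_000)
print(f"{total:.3f} s for 1,000,000 calls")  # typically well under a microsecond each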

surya/ocr_error/model/encoder.py

Lines changed: 5 additions & 5 deletions
@@ -16,16 +16,11 @@
 )

 from transformers.utils import (
-    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
 )

 from surya.common.pretrained import SuryaPreTrainedModel

-if is_flash_attn_2_available():
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-
 from surya.common.s3 import S3DownloaderMixin
 from surya.ocr_error.model.config import DistilBertConfig


@@ -342,6 +337,9 @@ def _flash_attention_forward(
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
         """
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import pad_input
+
         if not self._flash_attn_uses_top_left_mask:
             causal = self.is_causal
         else:

@@ -397,6 +395,8 @@ def _upad_input(
     def _upad_input(
         self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
+        from flash_attn.bert_padding import index_first_axis, unpad_input
+
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
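
The ocr_error encoder gets the same treatment, so none of the changed modules import flash_attn at module scope anymore; availability is decided where the flash-attention code actually runs. A rough, assumed check of the net effect on a machine without flash_attn installed (not part of the PR):

import importlib

for name in (
    "surya.common.surya",
    "surya.common.surya.decoder",
    "surya.common.surya.encoder",
    "surya.ocr_error.model.encoder",
):
    # Imports succeed even without flash_attn; only the flash_attention_2
    # code paths would need it, and only at call time.
    importlib.import_module(name)

print("all changed modules imported without flash_attn")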
