@@ -5,11 +5,11 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
-from transformers import PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.cache_utils import Cache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
+from surya.common.pretrained import SuryaPreTrainedModel
 from surya.common.s3 import S3DownloaderMixin
 from surya.common.surya.config import SuryaModelConfig
 from surya.common.surya.decoder import SuryaDecoderModel
@@ -56,6 +56,7 @@ class FlashAttentionKwargs(TypedDict, total=False):
 
 class KwargsForCausalLM(FlashAttentionKwargs): ...
 
+
 class DistanceProjection(nn.Module):
     def __init__(self, in_features: int, out_features: int):
         super().__init__()
@@ -75,7 +76,8 @@ def init_weights(self):
         nn.init.zeros_(self.fc1.bias)
         nn.init.zeros_(self.fc2.bias)
 
-class SuryaModel(S3DownloaderMixin, PreTrainedModel):
+
+class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
     config_class = SuryaModelConfig
     supports_gradient_checkpointing = True
     _skip_keys_device_placement = ["past_key_values"]
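The hunks above show only `DistanceProjection`'s constructor signature and its zero-initialised biases. A minimal two-layer projection consistent with those lines, written as a standalone sketch (the activation and the layer widths are assumptions, not taken from the diff):

```python
import torch
from torch import nn


class TwoLayerProjection(nn.Module):
    """Hypothetical stand-in for DistanceProjection: fc1 -> activation -> fc2."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.fc2 = nn.Linear(out_features, out_features)
        self.act = nn.GELU()  # assumption: the real activation is not visible in the diff

    def init_weights(self):
        # Mirrors the visible init_weights lines: both biases start at zero.
        nn.init.zeros_(self.fc1.bias)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


proj = TwoLayerProjection(in_features=1024, out_features=1024)
proj.init_weights()
out = proj(torch.randn(2, 1, 1024))  # shape preserved: (2, 1, 1024)
```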
@@ -95,8 +97,9 @@ def __init__(
         embedder: SimpleTokenEmbedder = None,
         vision_encoder: SuryaEncoderModel = None,
         decoder: SuryaDecoderModel = None,
+        **kwargs,
     ):
-        super().__init__(config)
+        super().__init__(config, **kwargs)
 
         if vision_encoder is None:
             vision_encoder = SuryaEncoderModel(config.vision_encoder)
@@ -166,29 +169,30 @@ def maybe_static_pad_image_inputs(
         chunk_pixels: torch.Tensor,
         chunk_grid_thw: torch.Tensor,
         actual_chunk_len: int,
-        encoder_chunk_size: int
+        encoder_chunk_size: int,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        valid_embed_len = actual_chunk_len // (self.vision_encoder.spatial_merge_size ** 2)
+        valid_embed_len = actual_chunk_len // (
+            self.vision_encoder.spatial_merge_size**2
+        )
         if settings.FOUNDATION_STATIC_CACHE and actual_chunk_len < encoder_chunk_size:
             padding_len = encoder_chunk_size - actual_chunk_len
             padding = torch.zeros(
-                padding_len,
+                padding_len,
                 *chunk_pixels.shape[1:],
                 device=chunk_pixels.device,
-                dtype=chunk_pixels.dtype
+                dtype=chunk_pixels.dtype,
             )
             chunk_pixels = torch.cat([chunk_pixels, padding], dim=0)
-
+
             padding_grid = torch.tensor(
                 [[1, 2, padding_len // 2]],
                 device=chunk_grid_thw.device,
-                dtype=chunk_grid_thw.dtype
+                dtype=chunk_grid_thw.dtype,
             )
             chunk_grid_thw = torch.cat([chunk_grid_thw, padding_grid], dim=0)
 
         return chunk_pixels, chunk_grid_thw, valid_embed_len
 
-
     def get_image_embeddings(
         self,
         pixel_values: torch.Tensor,
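`maybe_static_pad_image_inputs` keeps the vision encoder's input length fixed at `encoder_chunk_size` when `FOUNDATION_STATIC_CACHE` is set, while remembering how many merged embeddings are real. A self-contained sketch of that idea with hypothetical names (`spatial_merge_size=2` assumed, the grid padding omitted):

```python
import torch


def pad_chunk_to_static_size(
    chunk_pixels: torch.Tensor,  # (actual_chunk_len, patch_dim)
    actual_chunk_len: int,
    encoder_chunk_size: int,
    spatial_merge_size: int = 2,  # assumed value; comes from the vision encoder config
):
    # Each merged embedding consumes spatial_merge_size**2 patch rows,
    # so only this many output embeddings correspond to real pixels.
    valid_embed_len = actual_chunk_len // (spatial_merge_size**2)
    if actual_chunk_len < encoder_chunk_size:
        padding = torch.zeros(
            encoder_chunk_size - actual_chunk_len,
            *chunk_pixels.shape[1:],
            device=chunk_pixels.device,
            dtype=chunk_pixels.dtype,
        )
        chunk_pixels = torch.cat([chunk_pixels, padding], dim=0)
    return chunk_pixels, valid_embed_len


# A 100-row chunk padded up to a 128-row static chunk keeps 25 valid embeddings.
pixels = torch.randn(100, 1176)
padded, valid = pad_chunk_to_static_size(pixels, 100, 128)
assert padded.shape[0] == 128 and valid == 25
```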
@@ -225,15 +229,18 @@ def get_image_embeddings(
             end = chunks[i + 1]
             grid_start = grid_chunks[i]
             grid_end = grid_chunks[i + 1]
-
+
             chunk_pixels = pixel_values[start:end]
             chunk_grid_thw = grid_thw[grid_start:grid_end]
             actual_chunk_len = end - start
-            chunk_pixels, chunk_grid_thw, valid_embed_len = self.maybe_static_pad_image_inputs(chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size)
+            chunk_pixels, chunk_grid_thw, valid_embed_len = (
+                self.maybe_static_pad_image_inputs(
+                    chunk_pixels, chunk_grid_thw, actual_chunk_len, encoder_chunk_size
+                )
+            )
 
             chunk_embeddings = self.vision_encoder.embed_images(
-                image_batch=chunk_pixels,
-                grid_thw=chunk_grid_thw
+                image_batch=chunk_pixels, grid_thw=chunk_grid_thw
             )
             embeddings.append(chunk_embeddings[:valid_embed_len])
 
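The surrounding loop in `get_image_embeddings` encodes the pixel rows chunk by chunk and drops any embeddings that only exist because of static padding before concatenating. Reduced to its essentials (hypothetical helper, not the surya API):

```python
from typing import Callable, List

import torch


def embed_in_chunks(
    pixel_values: torch.Tensor,
    chunk_bounds: List[int],  # e.g. [0, 128, 256, 300]
    valid_lens: List[int],    # valid embeddings per chunk after static padding
    encode_fn: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    embeddings = []
    for i in range(len(chunk_bounds) - 1):
        chunk = pixel_values[chunk_bounds[i] : chunk_bounds[i + 1]]
        chunk_embeddings = encode_fn(chunk)
        # Keep only the embeddings that correspond to real pixels.
        embeddings.append(chunk_embeddings[: valid_lens[i]])
    return torch.cat(embeddings, dim=0)


# Toy usage: an "encoder" that merges every 4 input rows into one embedding.
encode = lambda x: x.reshape(-1, 4 * x.shape[-1])
out = embed_in_chunks(torch.randn(256, 8), [0, 128, 256], [32, 32], encode)
assert out.shape == (64, 32)
```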
@@ -340,28 +347,30 @@ def get_2d_learned_embeddings(
         )  # Shape is num_image_tokens x embed_dim
 
     def get_logits(self, hidden_states):
-        assert hidden_states.shape[1] == 1, "Multi output predictions only applied on the last token"
+        assert hidden_states.shape[1] == 1, (
+            "Multi output predictions only applied on the last token"
+        )
 
         all_lm_logits = []
         all_bbox_logits = []
-
+
         current_hidden = hidden_states
-
+
         # Loop includes initial prediction (i=0) plus multi_out_distance additional predictions
         for i in range(self.config.multi_output_distance + 1):
             if i > 0:
-                current_hidden = self.multi_output_projections[i - 1](current_hidden)
-
+                current_hidden = self.multi_output_projections[i - 1](current_hidden)
+
             lm_logits = self.lm_head(current_hidden)
             bbox_logits = F.sigmoid(self.bbox_head(current_hidden))
-
+
             all_lm_logits.append(lm_logits)
             all_bbox_logits.append(bbox_logits)
-
+
         # Concatenate along sequence dimension (dim=1)
         final_lm_logits = torch.cat(all_lm_logits, dim=1)
         final_bbox_logits = torch.cat(all_bbox_logits, dim=1)
-
+
         return final_lm_logits, final_bbox_logits
 
     def forward(
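`get_logits` turns the single last hidden state into `multi_output_distance + 1` predictions by repeatedly applying the learned projections and decoding at each step. A stripped-down, runnable sketch of the same pattern (hypothetical module, bbox head omitted):

```python
import torch
from torch import nn


class MultiOutputHead(nn.Module):
    """Sketch of multi-token prediction: project the last hidden state forward
    `extra_steps` times and decode logits at every step."""

    def __init__(self, hidden: int, vocab: int, extra_steps: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.projections = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(extra_steps)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        assert hidden_states.shape[1] == 1, "expects only the last token position"
        all_logits = []
        current = hidden_states
        for i in range(len(self.projections) + 1):
            if i > 0:
                current = self.projections[i - 1](current)
            all_logits.append(self.lm_head(current))
        # (batch, extra_steps + 1, vocab): one prediction per future position
        return torch.cat(all_logits, dim=1)


head = MultiOutputHead(hidden=64, vocab=100, extra_steps=3)
logits = head(torch.randn(2, 1, 64))
assert logits.shape == (2, 4, 100)
```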
@@ -387,24 +396,25 @@ def forward(
         **kwargs: KwargsForCausalLM,
     ):
         # Process the mixed batch if provided
-        if any([
-            input_ids is None,
-            (prefill and (image_tiles is None or grid_thw is None)),
-            position_ids is None,
-            cache_position is None
-        ]):
-            raise ValueError("`input_ids`, `position_ids`, and `cache_position` **must** be specified. `image_tiles` and `grid_thw` are required for prefill")
+        if any(
+            [
+                input_ids is None,
+                (prefill and (image_tiles is None or grid_thw is None)),
+                position_ids is None,
+                cache_position is None,
+            ]
+        ):
+            raise ValueError(
+                "`input_ids`, `position_ids`, and `cache_position` **must** be specified. `image_tiles` and `grid_thw` are required for prefill"
+            )
 
         inputs_embeds = self.embed_ids_boxes_images(
             input_ids, image_tiles, grid_thw, encoder_chunk_size
         )
 
         # Handling flash attention kwargs outside the decoder to speed up + avoid graph breaks inside the decoder
         # Skipped during decoding since not required
-        if (
-            self.decoder.config._attn_implementation == "flash_attention_2"
-            and prefill
-        ):
+        if self.decoder.config._attn_implementation == "flash_attention_2" and prefill:
             batch_size, query_length, _ = inputs_embeds.shape
             indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
                 attention_mask
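The prefill path precomputes flash-attention unpadding metadata outside the decoder. `_get_unpad_data` itself is not shown in this diff; transformers-style implementations are typically along these lines (a sketch, not the exact helper used here):

```python
import torch
import torch.nn.functional as F


def get_unpad_data(attention_mask: torch.Tensor):
    """Sketch of the usual flash-attention unpadding metadata.

    attention_mask: (batch, seq_len) with 1 for real tokens, 0 for padding.
    Returns the flat indices of real tokens, cumulative sequence lengths,
    and the longest sequence in the batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return indices, cu_seqlens, max_seqlen_in_batch


mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
indices, cu_seqlens, max_len = get_unpad_data(mask)
# indices -> [0, 1, 2, 4, 5]; cu_seqlens -> [0, 3, 5]; max_len -> 3
```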
@@ -451,7 +461,9 @@ def forward(
             bbox_logits = None
             vocab_size = lm_logits.shape[-1]
             labels = torch.roll(labels, shifts=-1, dims=-1)
-            loss = F.cross_entropy(lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean")
+            loss = F.cross_entropy(
+                lm_logits.view(-1, vocab_size), labels.view(-1), reduction="mean"
+            )
         else:
             lm_logits, bbox_logits = self.get_logits(hidden_states)
 
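The training branch computes a next-token loss by rolling the labels one step to the left, so the logits at position t are scored against the token at position t + 1. In isolation:

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab_size = 2, 8, 32
lm_logits = torch.randn(batch, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch, seq_len))

# Shift targets left: the logits at position t are scored against token t + 1.
shifted = torch.roll(labels, shifts=-1, dims=-1)
loss = F.cross_entropy(
    lm_logits.view(-1, vocab_size), shifted.view(-1), reduction="mean"
)
# In practice the wrapped-around final position would carry an ignore label
# (F.cross_entropy skips targets equal to ignore_index, -100 by default).
```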
@@ -561,9 +573,15 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             device=device,
         )
         # Batch-aware diagonal attend mask
-        diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(0) > cache_position.unsqueeze(-1)
-        causal_mask = causal_mask.unsqueeze(0) * diagonal_attend_mask  # (batch_size, seq_len, target_len)
-        causal_mask = causal_mask[:, None, :, :]  # (batch_size, 1, seq_len, target_len)
+        diagonal_attend_mask = torch.arange(target_length, device=device).unsqueeze(
+            0
+        ) > cache_position.unsqueeze(-1)
+        causal_mask = (
+            causal_mask.unsqueeze(0) * diagonal_attend_mask
+        )  # (batch_size, seq_len, target_len)
+        causal_mask = causal_mask[
+            :, None, :, :
+        ]  # (batch_size, 1, seq_len, target_len)
         if attention_mask is not None:
             causal_mask = (
                 causal_mask.clone()
@@ -578,4 +596,4 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             causal_mask[:, :, :, :mask_length] = causal_mask[
                 :, :, :, :mask_length
             ].masked_fill(padding_mask, min_dtype)
-        return causal_mask
+        return causal_mask
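`_prepare_4d_causal_attention_mask_with_cache_position` builds an additive `(batch, 1, query_len, target_len)` mask: everything starts blocked at the dtype minimum, the batch-aware diagonal re-opens key positions at or before each query's cache position, and the padding mask closes padded keys again. A compact sketch of that construction (simplified signature, assuming a batched `cache_position`):

```python
import torch


def build_4d_causal_mask(attention_mask, cache_position, target_length, dtype=torch.float32):
    """Sketch of a batch-aware additive causal mask.

    attention_mask: (batch, target_length), 1 for real tokens, 0 for padding.
    cache_position: (batch, query_length), absolute position of each query token.
    Returns (batch, 1, query_length, target_length): 0 where attention is
    allowed, the dtype minimum where it is blocked.
    """
    min_dtype = torch.finfo(dtype).min
    query_length = cache_position.shape[-1]

    causal_mask = torch.full((query_length, target_length), min_dtype, dtype=dtype)
    # Block key positions strictly after each query's own cache position.
    diagonal_attend_mask = (
        torch.arange(target_length).unsqueeze(0) > cache_position.unsqueeze(-1)
    )  # (batch, query_length, target_length)
    causal_mask = causal_mask.unsqueeze(0) * diagonal_attend_mask
    causal_mask = causal_mask[:, None, :, :]  # add the head dimension

    # Additionally block padded key positions.
    padding_mask = (causal_mask + attention_mask[:, None, None, :]) == 0
    return causal_mask.masked_fill(padding_mask, min_dtype)


mask = build_4d_causal_mask(
    attention_mask=torch.tensor([[1, 1, 1, 0]]),
    cache_position=torch.tensor([[0, 1, 2, 3]]),
    target_length=4,
)
assert mask.shape == (1, 1, 4, 4)
```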