@@ -3361,6 +3361,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
         media_offset = torch.stack([torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None]
         inputs['_data'] = {'pixel_values': image_inputs['pixel_values']}
         inputs['media_offset'] = media_offset
+        inputs['num_images'] = image_inputs['pixel_values'].shape[0]
         inputs['input_ids'] = input_ids
         inputs['labels'] = labels
         return inputs, {}
@@ -3372,9 +3373,25 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]:
     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super().data_collator(batch, padding_to)
         image_embeds = [b['image_embeds'] for b in batch if 'image_embeds' in b]
+        num_images = [b['num_images'] if 'num_images' in b else 0 for b in batch]
         if image_embeds:
             res['image_embeds'] = torch.concat(image_embeds)
-        media_offset = [b['media_offset'] for b in batch if 'media_offset' in b]
+        media_offset = []
+        cusum_offset = 0
+
+
+        for bi, b in enumerate(batch):
+            if 'media_offset' in b:
+                max_sequence_length = res['input_ids'].shape[1]
+                curr_media_offset = b['media_offset']
+                if curr_media_offset.shape[1] < max_sequence_length:
+                    padding = curr_media_offset[:, -1:, :].expand(curr_media_offset.shape[0], max_sequence_length - curr_media_offset.shape[1], curr_media_offset.shape[2])
+                    curr_media_offset = torch.concat([curr_media_offset, padding], dim=1)
+                media_offset.append(curr_media_offset + cusum_offset)
+                cusum_offset += num_images[bi]
+
+        # media_offset = [b['media_offset'] for b in batch if 'media_offset' in b]
+
         if media_offset:
             res['media_offset'] = torch.concat(media_offset)
         return res
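Taken together, the patch records `num_images` per sample in `_encode`, and in `data_collator` pads each sample's `media_offset` to the batch's max sequence length (by repeating the last position) and shifts its image indices by the number of images contributed by earlier samples, so the indices remain valid once `image_embeds` are concatenated across the batch. Below is a minimal standalone sketch of that collation step; the helper name `collate_media_offset` and the toy tensors are illustrative and not part of the patch.

```python
import torch

def collate_media_offset(media_offsets, num_images, max_sequence_length):
    # media_offsets: list of per-sample tensors shaped [1, seq_len_i, 2]
    # num_images:    list of ints, images contributed by each sample
    collated = []
    cusum_offset = 0
    for offset, n_img in zip(media_offsets, num_images):
        if offset.shape[1] < max_sequence_length:
            # Repeat the last sequence position to reach the padded length.
            padding = offset[:, -1:, :].expand(
                offset.shape[0],
                max_sequence_length - offset.shape[1],
                offset.shape[2])
            offset = torch.concat([offset, padding], dim=1)
        # Shift image indices past the images of all preceding samples.
        collated.append(offset + cusum_offset)
        cusum_offset += n_img
    return torch.concat(collated)

# Toy example: sample A has 1 image (seq_len 3), sample B has 2 images (seq_len 4).
sample_a = torch.tensor([[[0, 0], [0, 0], [0, 0]]])
sample_b = torch.tensor([[[0, 0], [0, 0], [0, 1], [0, 1]]])
out = collate_media_offset([sample_a, sample_b], [1, 2], max_sequence_length=4)
print(out.shape)  # torch.Size([2, 4, 2])
print(out[1, 0])  # tensor([1, 1]) -> sample B's indices shifted by sample A's 1 image
```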