
Commit 230a1a2

LukeForeverYoung authored and Jintao-Huang committed
Fix the issue with media_offset in owl3 when batch_size > 1. (#2100)
* mplugowl3 media_offset issue
* padding of media_offset
1 parent 5791297 commit 230a1a2


swift/llm/utils/template.py

Lines changed: 18 additions & 1 deletion
@@ -3361,6 +3361,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         media_offset = torch.stack([torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None]
         inputs['_data'] = {'pixel_values': image_inputs['pixel_values']}
         inputs['media_offset'] = media_offset
+        inputs['num_images'] = image_inputs['pixel_values'].shape[0]
         inputs['input_ids'] = input_ids
         inputs['labels'] = labels
         return inputs, {}
@@ -3372,9 +3373,25 @@ def _post_encode(self, model, data: Any) -> Dict[str, Any]:
     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super().data_collator(batch, padding_to)
         image_embeds = [b['image_embeds'] for b in batch if 'image_embeds' in b]
+        num_images = [b['num_images'] if 'num_images' in b else 0 for b in batch]
         if image_embeds:
             res['image_embeds'] = torch.concat(image_embeds)
-        media_offset = [b['media_offset'] for b in batch if 'media_offset' in b]
+        media_offset = []
+        cusum_offset = 0
+
+        for bi, b in enumerate(batch):
+            if 'media_offset' in b:
+                max_sequence_length = res['input_ids'].shape[1]
+                curr_media_offset = b['media_offset']
+                if curr_media_offset.shape[1] < max_sequence_length:
+                    padding = curr_media_offset[:, -1:, :].expand(curr_media_offset.shape[0], max_sequence_length - curr_media_offset.shape[1], curr_media_offset.shape[2])
+                    curr_media_offset = torch.concat([curr_media_offset, padding], dim=1)
+                media_offset.append(curr_media_offset + cusum_offset)
+                cusum_offset += num_images[bi]
+
+        # media_offset = [b['media_offset'] for b in batch if 'media_offset' in b]
         if media_offset:
             res['media_offset'] = torch.concat(media_offset)
         return res
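
The collator change above can be read as a pad-and-shift step over the per-sample media_offset tensors. Below is a minimal standalone sketch of that idea; it is not part of the commit, the helper name pad_and_shift_media_offset and the toy tensor values are illustrative assumptions, and the shapes follow the diff: each sample's media_offset is built in _encode as torch.stack([zeros, matrix], dim=-1)[None], giving shape (1, seq_len, 2).

import torch


def pad_and_shift_media_offset(media_offsets, num_images, max_seq_len):
    # Hypothetical helper mirroring the collator logic in the diff above.
    # Each per-sample media_offset has shape (1, seq_len, 2): pad it along the
    # sequence dimension by repeating its last entry up to the batch's padded
    # length, then shift it by the number of images contributed by earlier
    # samples so its indices still line up once image_embeds from the whole
    # batch are concatenated into a single tensor.
    out = []
    cusum_offset = 0
    for offsets, n_images in zip(media_offsets, num_images):
        if offsets.shape[1] < max_seq_len:
            padding = offsets[:, -1:, :].expand(offsets.shape[0], max_seq_len - offsets.shape[1], offsets.shape[2])
            offsets = torch.concat([offsets, padding], dim=1)
        out.append(offsets + cusum_offset)
        cusum_offset += n_images
    return torch.concat(out)


# Toy batch: sample 0 has 2 images and 4 tokens, sample 1 has 1 image and 6 tokens
# (the values are illustrative, not taken from the template).
m0 = torch.tensor([[[0, 0], [0, 1], [0, 2], [0, 2]]], dtype=torch.long)
m1 = torch.tensor([[[0, 0], [0, 0], [0, 1], [0, 1], [0, 1], [0, 1]]], dtype=torch.long)
merged = pad_and_shift_media_offset([m0, m1], num_images=[2, 1], max_seq_len=6)
print(merged.shape)   # torch.Size([2, 6, 2])
print(merged[0, -1])  # tensor([0, 2]) - sample 0 padded by repeating its last entry
print(merged[1, 2])   # tensor([2, 3]) - sample 1 shifted by sample 0's 2 images

Without the shift, every sample's media_offset would index images starting from 0, which is only correct for the first sample once the batch's image_embeds are stacked together; that appears to be the batch_size > 1 issue the commit title refers to.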
