diff --git a/recipes/dev/generate_v2.py b/recipes/dev/generate_v2.py
index e63ea2dcb0..3ce95a9fdf 100644
--- a/recipes/dev/generate_v2.py
+++ b/recipes/dev/generate_v2.py
@@ -152,7 +152,10 @@ def generate(self, cfg: DictConfig):
         batch = {}
         if is_multimodal_input:
             batch = padded_collate_tiled_images_and_mask(
-                [model_inputs], pad_direction="left", pad_max_images=1
+                [model_inputs],
+                pad_direction="left",
+                pad_max_images=1,
+                pad_max_tiles=self.model_transform.max_num_tiles,
             )
             batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
         prompt = batch.pop("tokens").to(self._device)
diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py
index 590e4f902a..68503ff63c 100644
--- a/recipes/eleuther_eval.py
+++ b/recipes/eleuther_eval.py
@@ -187,6 +187,7 @@ def tok_batch_multimodal_encode(
             all_encoded_messages,
             pad_direction="left",
             pad_max_images=self._max_images_per_sample,
+            pad_max_tiles=self._transform.max_num_tiles,
         )
 
         utils.batch_to_device(tok_batch, self.device)
diff --git a/tests/torchtune/data/test_collate.py b/tests/torchtune/data/test_collate.py
index ca4bccd5d5..02477ad7e5 100644
--- a/tests/torchtune/data/test_collate.py
+++ b/tests/torchtune/data/test_collate.py
@@ -56,26 +56,31 @@ def test_batch_pad_sequence(self):
 
 
 class TestPaddedCollateTiledImagesAndMask:
+    img_shape = 1, 1, 1
+    tokens_per_tile = 5
+
     @pytest.fixture
     def batch(self):
+        c, h, w = self.img_shape
+        s = self.tokens_per_tile
         return [
             {
                 "tokens": [1, 2, 1, 3],
                 "labels": [4, 5, 6, 7],
                 "encoder_input": {
-                    "images": [torch.ones(2, 1, 1, 1), torch.ones(3, 1, 1, 1)],
+                    "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
                     "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
                 },
-                "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
+                "encoder_mask": [torch.ones(4, s * 2), torch.ones(4, s * 3)],
             },
             {
                 "tokens": [1, 4],
                 "labels": [8, 9],
                 "encoder_input": {
-                    "images": [torch.ones(4, 1, 1, 1)],
+                    "images": [torch.ones(4, c, h, w)],
                     "aspect_ratio": [torch.tensor([2, 2])],
                 },
-                "encoder_mask": [torch.ones(2, 5 * 4)],
+                "encoder_mask": [torch.ones(2, s * 4)],
             },
         ]
 
@@ -83,6 +88,9 @@ def test_right_pad_sequence(self, batch):
         actual = padded_collate_tiled_images_and_mask(
             batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="right"
         )
+        imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
+        seq_len = actual["encoder_mask"].shape[-1]
+        assert imgs * tiles * self.tokens_per_tile == seq_len
 
         mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
         mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
@@ -126,15 +134,23 @@ def test_left_pad_sequence(self, batch):
             ignore_idx=-100,
             pad_direction="left",
             pad_max_images=4,
+            pad_max_tiles=5,
         )
+        imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
+        seq_len = actual["encoder_mask"].shape[-1]
+        assert 5 * 4 * self.tokens_per_tile == seq_len
 
-        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
-        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
+        # pad 3 extra tiles
+        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1)
+        # pad 2 extra tiles
+        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1)
+        # Left pad text tokens
         mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0)
+        mask_3 = F.pad(mask_3, (0, 5), value=0)  # pad 5th tile
         sample_1 = torch.stack([mask_1, mask_2])
-        sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
+        sample_2 = torch.stack([mask_3, torch.zeros(4, 25)])
         expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)
-        expected_mask = F.pad(expected_mask, (0, 40), value=0)
+        expected_mask = F.pad(expected_mask, (0, 50), value=0)
 
         expected = {
             "tokens": torch.tensor([[1, 2, 1, 3], [0, 0, 1, 4]]),
@@ -142,12 +158,12 @@
                 "images": torch.tensor(
                     [
                         [
-                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
-                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
                         ],
                         [
-                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]]],
-                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
+                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
                         ],
                     ]
                 ),
diff --git a/tests/torchtune/modules/transforms/test_transforms.py b/tests/torchtune/modules/transforms/test_transforms.py
index 4bdc137dcf..1e24a06d41 100644
--- a/tests/torchtune/modules/transforms/test_transforms.py
+++ b/tests/torchtune/modules/transforms/test_transforms.py
@@ -10,7 +10,6 @@
 
 
 IMAGE_TOKEN_ID = 1
-MAX_NUM_TILES = 4
 
 
 class TestVisionCrossAttentionMask:
@@ -54,7 +53,6 @@ def cross_attn_mask_transform(self, tile_size, patch_size):
             tile_size=tile_size,
             patch_size=patch_size,
             image_token_id=IMAGE_TOKEN_ID,
-            max_num_tiles=MAX_NUM_TILES,
         )
 
     def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
@@ -89,7 +87,7 @@ def test_inference_call(
         sample.update(dummy_kwargs)
         actual = cross_attn_mask_transform(sample, inference=True)
         expected = [
-            torch.zeros(len(tokens), image_num_tokens * 2, dtype=torch.bool)
+            torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool)
             for _ in range(len(images))
         ]
         expected[0][2:6, :image_num_tokens] = True
diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py
index 055ab77350..5157f4a7fa 100644
--- a/torchtune/data/_collate.py
+++ b/torchtune/data/_collate.py
@@ -426,7 +426,8 @@ def padded_collate_tiled_images_and_mask(
     if pad_max_images is not None:
         _, _, img_seq = concat_masks.shape
         concat_masks = F.pad(
-            concat_masks, (0, pad_max_images * image_seq_len - img_seq)
+            concat_masks,
+            (0, pad_max_images * max_num_tiles * tokens_per_tile - img_seq),
         )
 
     batch_dict = {
diff --git a/torchtune/models/llama3_2_vision/_transform.py b/torchtune/models/llama3_2_vision/_transform.py
index 4dc4f781e9..cee3478b6c 100644
--- a/torchtune/models/llama3_2_vision/_transform.py
+++ b/torchtune/models/llama3_2_vision/_transform.py
@@ -93,11 +93,11 @@ def __init__(
             tile_size=tile_size,
             patch_size=patch_size,
            image_token_id=self.tokenizer.image_id,
-            max_num_tiles=max_num_tiles,
         )
 
         self.stop_tokens = self.tokenizer.stop_tokens
         self.max_seq_len = max_seq_len
+        self.max_num_tiles = max_num_tiles
         self.image_seq_len = max_num_tiles * (self.xattn_mask.patches_per_tile + 1)
         self.prompt_template = prompt_template
         self.pad_id = self.tokenizer.pad_id
diff --git a/torchtune/modules/transforms/_transforms.py b/torchtune/modules/transforms/_transforms.py
index 006f224144..fd1e60e33d 100644
--- a/torchtune/modules/transforms/_transforms.py
+++ b/torchtune/modules/transforms/_transforms.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Mapping, Optional, Protocol
+from typing import Any, List, Mapping, Protocol
 
 import torch
 
@@ -57,8 +57,6 @@ class VisionCrossAttentionMask(Transform):
             E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches
             with shape (40, 40) each.
         image_token_id (int): Token ID of the image special token.
-        max_num_tiles (Optional[int]): Maximum number of tiles in an image, used to
-            pad mask during inference. Defaults to None
     """
 
     def __init__(
@@ -66,12 +64,10 @@ def __init__(
         tile_size: int,
         patch_size: int,
         image_token_id: int,
-        max_num_tiles: Optional[int] = None,
     ):
         patch_grid_size = tile_size // patch_size
         self.patches_per_tile = patch_grid_size**2
         self.image_token_id = image_token_id
-        self.max_num_tiles = max_num_tiles
 
     def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
         """
@@ -163,9 +159,6 @@ def __call__(
         # which can vary based on number of tiles since they are not yet tile padded.
         # The masks are padded and concatenated together in the batch collator
         text_seq_len = len(tokens)
-        max_image_size = None
-        if inference and self.max_num_tiles is not None:
-            max_image_size = self.max_num_tiles * (self.patches_per_tile + 1)
         masks = []
         for image_num, interval in enumerate(intervals):
             # Identify what part of text sequence should be attended
@@ -178,9 +171,7 @@
             # to a single image, so text tokens attend to all the image's tokens.
             # The mask is text_seq_len x mask_image_size if defined, otherwise
             # it uses current text/image sequence lengths.
-            mask = torch.zeros(
-                text_seq_len, max_image_size or image_seq_len, dtype=torch.bool
-            )
+            mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
             mask[start:end, :image_seq_len] = True
             masks.append(mask)
 
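
Illustrative usage (not part of the patch): a minimal sketch of how the new pad_max_tiles argument behaves, mirroring the updated left-pad test above. The import path and the sample shapes/values below are assumptions for demonstration only; the point of the change is that the collated encoder mask's image dimension is now padded to pad_max_images * pad_max_tiles * tokens_per_tile, so it lines up with the tile-padded image tensor.

import torch

from torchtune.data import padded_collate_tiled_images_and_mask

tokens_per_tile = 5  # assumed value, matching the constant used in the tests above

# One sample, one image split into 2 tiles of shape (c=1, h=1, w=1),
# with a 4-token text sequence attending to the image's 2 * 5 image tokens.
batch = [
    {
        "tokens": [1, 2, 1, 3],
        "labels": [4, 5, 6, 7],
        "encoder_input": {
            "images": [torch.ones(2, 1, 1, 1)],
            "aspect_ratio": [torch.tensor([1, 2])],
        },
        "encoder_mask": [torch.ones(4, tokens_per_tile * 2)],
    }
]

out = padded_collate_tiled_images_and_mask(
    batch=batch,
    padding_idx=0,
    ignore_idx=-100,
    pad_direction="left",
    pad_max_images=4,
    pad_max_tiles=5,
)

# Tiles are padded from 2 up to pad_max_tiles=5, and the mask's last dimension
# is padded to pad_max_images * pad_max_tiles * tokens_per_tile = 4 * 5 * 5.
assert out["encoder_mask"].shape[-1] == 4 * 5 * tokens_per_tile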