|
29 | 29 |
|
30 | 30 | from vllm_omni.diffusion.layers.rope import RotaryEmbedding |
31 | 31 |
|
| 32 | + |
| 33 | +def patchify(imgs, p): |
|  34 | +    """ |
|  35 | +    imgs: (N, 3, H, W) or (3, H, W); H and W must be divisible by p |
|  36 | +    returns x: (N, L, p**2 * 3) or (L, p**2 * 3), where L = (H // p) * (W // p) |
|  37 | +    """ |
| 38 | + is_batch = imgs.ndim == 4 |
| 39 | + if not is_batch: |
| 40 | + imgs = imgs.unsqueeze(0) |
| 41 | + |
| 42 | + # n: batch, c: channel, h: grid_h, p: patch_h, w: grid_w, q: patch_w |
| 43 | + x = imgs.reshape(imgs.shape[0], 3, imgs.shape[2] // p, p, imgs.shape[3] // p, p) |
| 44 | + # Permute to (n, grid_h, grid_w, c, patch_h, patch_w) to match Conv2d (c, h, w) flattening |
| 45 | + x = torch.einsum("nchpwq->nhwcpq", x) |
| 46 | + x = x.reshape(imgs.shape[0], -1, 3 * p**2) |
| 47 | + |
| 48 | + if not is_batch: |
| 49 | + x = x.squeeze(0) |
| 50 | + return x |
| 51 | + |
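| | +# e.g. patchify(torch.randn(2, 3, 224, 224), 14) -> shape (2, 256, 588), |
| | +# since (224 // 14) ** 2 = 256 patches of 14 * 14 * 3 = 588 values each |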
| 52 | + |
| 53 | +class MLPconnector(nn.Module): |
| 54 | + def __init__(self, input_dim, output_dim, activation="gelu_pytorch_tanh"): |
| 55 | + super().__init__() |
| 56 | + self.fc1 = nn.Linear(input_dim, output_dim) |
| 57 | + if activation == "gelu": |
| 58 | + self.act = nn.GELU() |
| 59 | + elif activation == "gelu_pytorch_tanh": |
| 60 | + self.act = nn.GELU(approximate="tanh") |
| 61 | + else: |
| 62 | + self.act = nn.ReLU() |
| 63 | + self.fc2 = nn.Linear(output_dim, output_dim) |
| 64 | + |
| 65 | + def forward(self, x): |
| 66 | + return self.fc2(self.act(self.fc1(x))) |
| 67 | + |
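| | +# MLPconnector is used below to project ViT features into the LLM embedding |
| | +# space (see Bagel.connector: vit_hidden_size -> hidden_size) |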
| 68 | + |
32 | 69 | torch._dynamo.config.cache_size_limit = 512 |
33 | 70 | torch._dynamo.config.accumulated_cache_size_limit = 4096 |
34 | 71 | flex_attention = torch.compile(flex_attention) |
@@ -600,51 +637,73 @@ def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_pat |
600 | 637 | class BagelConfig(PretrainedConfig): |
601 | 638 | def __init__( |
602 | 639 | self, |
| 640 | + visual_gen=True, |
| 641 | + visual_und=True, |
603 | 642 | llm_config=None, |
| 643 | + vit_config=None, |
604 | 644 | vae_config=None, |
605 | 645 | latent_patch_size=2, |
606 | 646 | max_latent_size=32, |
| 647 | + vit_max_num_patch_per_side=70, |
| 648 | + connector_act="gelu_pytorch_tanh", |
| 649 | + interpolate_pos=False, |
607 | 650 | timestep_shift=1.0, |
608 | 651 | **kwargs, |
609 | 652 | ): |
610 | 653 | super().__init__(**kwargs) |
| 654 | + self.visual_gen = visual_gen |
| 655 | + self.visual_und = visual_und |
611 | 656 | self.llm_config = llm_config |
| 657 | + self.vit_config = vit_config |
612 | 658 | self.vae_config = vae_config |
613 | 659 | self.latent_patch_size = latent_patch_size |
614 | 660 | self.max_latent_size = max_latent_size |
| 661 | + self.vit_max_num_patch_per_side = vit_max_num_patch_per_side |
| 662 | + self.connector_act = connector_act |
| 663 | + self.interpolate_pos = interpolate_pos |
615 | 664 | self.timestep_shift = timestep_shift |
616 | 665 |
|
617 | 666 |
|
618 | 667 | class Bagel(torch.nn.Module): |
619 | 668 | config_class = BagelConfig |
620 | 669 | base_model_prefix = "bagel" |
621 | 670 |
|
622 | | - def __init__(self, language_model, config: BagelConfig): |
| 671 | + def __init__(self, language_model, vit_model, config: BagelConfig): |
623 | 672 | super().__init__() |
624 | 673 | self.language_model = language_model |
625 | 674 | self.hidden_size = config.llm_config.hidden_size |
626 | 675 | self.use_moe = "Mo" in config.llm_config.layer_module |
627 | 676 | self.num_heads = config.llm_config.num_attention_heads |
628 | 677 |
|
629 | | - self.latent_patch_size = config.latent_patch_size |
630 | | - self.timestep_shift = config.timestep_shift |
631 | | - self.latent_downsample = config.vae_config.downsample * config.latent_patch_size |
632 | | - self.max_latent_size = config.max_latent_size |
633 | | - self.latent_channel = config.vae_config.z_channels |
634 | | - self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel |
635 | | - self.time_embedder = TimestepEmbedder(self.hidden_size) |
636 | | - self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size) |
637 | | - self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim) |
638 | | - self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size) |
| 678 | + if config.visual_gen: |
| 679 | + self.latent_patch_size = config.latent_patch_size |
| 680 | + self.timestep_shift = config.timestep_shift |
| 681 | + self.latent_downsample = config.vae_config.downsample * config.latent_patch_size |
| 682 | + self.max_latent_size = config.max_latent_size |
| 683 | + self.latent_channel = config.vae_config.z_channels |
| 684 | + self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel |
| 685 | + self.time_embedder = TimestepEmbedder(self.hidden_size) |
| 686 | + self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size) |
| 687 | + self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim) |
| 688 | + self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size) |
| 689 | + |
| 690 | + if config.visual_und: |
| 691 | + self.vit_model = vit_model |
| 692 | + self.vit_patch_size = config.vit_config.patch_size |
| 693 | + self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side |
| 694 | + self.vit_hidden_size = config.vit_config.hidden_size |
| 695 | + self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act) |
| 696 | + self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size) |
639 | 697 |
|
640 | 698 | self.get_flattened_position_ids = get_flattened_position_ids_extrapolate |
641 | 699 |
|
642 | 700 | self.config = config |
643 | 701 | self._init_weights() |
644 | 702 |
|
645 | 703 | def _init_weights(self): |
646 | | - nn.init.constant_(self.llm2vae.weight, 0) |
647 | | - nn.init.constant_(self.llm2vae.bias, 0) |
| 704 | + if self.config.visual_gen: |
| 705 | + nn.init.constant_(self.llm2vae.weight, 0) |
| 706 | + nn.init.constant_(self.llm2vae.bias, 0) |
648 | 707 |
|
649 | 708 | def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids): |
650 | 709 | packed_text_ids = list() |
@@ -713,6 +772,261 @@ def forward_cache_update_text( |
713 | 772 |
|
714 | 773 | return past_key_values |
715 | 774 |
|
| 775 | + def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0): |
| 776 | + patchified_vae_latent_shapes, packed_vae_position_ids = list(), list() |
| 777 | + packed_vae_token_indexes = list() |
| 778 | + packed_text_ids, packed_text_indexes = list(), list() |
| 779 | + packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() |
| 780 | + packed_key_value_indexes = list() |
| 781 | + |
| 782 | + _curr = curr = 0 |
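| | +        # curr walks the global KV-cache layout (existing entries included); |
| | +        # _curr walks the newly packed query sequence being built here |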
| 783 | + vae_image_tensors = list() |
| 784 | + newlens, new_rope = list(), list() |
| 785 | + for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): |
| 786 | + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) |
| 787 | + curr += curr_kvlen |
| 788 | + |
| 789 | + packed_text_ids.append(new_token_ids["start_of_image"]) |
| 790 | + packed_text_indexes.append(_curr) |
| 791 | + packed_indexes.append(curr) |
| 792 | + curr += 1 |
| 793 | + _curr += 1 |
| 794 | + |
| 795 | + image_tensor = transforms(image) |
| 796 | + vae_image_tensors.append(image_tensor) |
| 797 | + vae_position_ids = self.get_flattened_position_ids( |
| 798 | + image_tensor.size(1), |
| 799 | + image_tensor.size(2), |
| 800 | + self.latent_downsample, |
| 801 | + max_num_patches_per_side=self.max_latent_size, |
| 802 | + ) |
| 803 | + packed_vae_position_ids.append(vae_position_ids) |
| 804 | + H, W = image_tensor.shape[1:] |
| 805 | + h = H // self.latent_downsample |
| 806 | + w = W // self.latent_downsample |
| 807 | + patchified_vae_latent_shapes.append((h, w)) |
| 808 | + |
| 809 | + num_img_tokens = w * h |
| 810 | + packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens)) |
| 811 | + packed_indexes.extend(range(curr, curr + num_img_tokens)) |
| 812 | + curr += num_img_tokens |
| 813 | + _curr += num_img_tokens |
| 814 | + |
| 815 | + packed_text_ids.append(new_token_ids["end_of_image"]) |
| 816 | + packed_text_indexes.append(_curr) |
| 817 | + packed_indexes.append(curr) |
| 818 | + curr += 1 |
| 819 | + _curr += 1 |
| 820 | + |
| 821 | + packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) |
| 822 | + packed_seqlens.append(num_img_tokens + 2) |
| 823 | + newlens.append(curr_kvlen + num_img_tokens + 2) |
| 824 | + new_rope.append(curr_position_id + 1) |
| 825 | + |
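| | +        # Zero-pad every image to the batch-wide max H/W so the VAE can encode |
| | +        # them in one batch; the valid region of each latent is recovered later |
| | +        # from patchified_vae_latent_shapes |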
| 826 | + image_sizes = [item.shape for item in vae_image_tensors] |
| 827 | + max_image_size = [max(item) for item in list(zip(*image_sizes))] |
| 828 | + padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size)) |
| 829 | + for i, image_tensor in enumerate(vae_image_tensors): |
| 830 | + padded_images[i, :, : image_tensor.shape[1], : image_tensor.shape[2]] = image_tensor |
| 831 | + |
| 832 | + generation_input = { |
| 833 | + "padded_images": padded_images, |
| 834 | + "patchified_vae_latent_shapes": patchified_vae_latent_shapes, |
| 835 | + "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0), |
| 836 | + "packed_timesteps": torch.tensor([timestep]), |
| 837 | + "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long), |
| 838 | + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), |
| 839 | + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), |
| 840 | + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), |
| 841 | + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), |
| 842 | + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), |
| 843 | + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), |
| 844 | + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), |
| 845 | + } |
| 846 | + |
| 847 | + return generation_input, newlens, new_rope |
| 848 | + |
| 849 | + @torch.no_grad |
| 850 | + def forward_cache_update_vae( |
| 851 | + self, |
| 852 | + vae_model, |
| 853 | + past_key_values: NaiveCache, |
| 854 | + padded_images: torch.Tensor, |
| 855 | + patchified_vae_latent_shapes: list, |
| 856 | + packed_vae_position_ids: torch.LongTensor, |
| 857 | + packed_timesteps: torch.Tensor, |
| 858 | + packed_vae_token_indexes: torch.LongTensor, |
| 859 | + packed_text_ids: torch.LongTensor, |
| 860 | + packed_text_indexes: torch.LongTensor, |
| 861 | + packed_position_ids: torch.LongTensor, |
| 862 | + packed_seqlens: torch.IntTensor, |
| 863 | + packed_indexes: torch.LongTensor, |
| 864 | + key_values_lens: torch.IntTensor, |
| 865 | + packed_key_value_indexes: torch.Tensor, |
| 866 | + ): |
| 867 | + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) |
| 868 | + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) |
| 869 | + packed_sequence[packed_text_indexes] = packed_text_embedding |
| 870 | + |
| 871 | + padded_latent = vae_model.encode(padded_images) |
| 872 | + |
| 873 | + p = self.latent_patch_size |
| 874 | + packed_latent = list() |
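| | +        # Crop away the padded region, then patchify each latent: |
| | +        # (C, h*p, w*p) -> (h*w, p*p*C), analogous to patchify() above |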
| 875 | + for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes): |
| 876 | + latent = latent[:, : h * p, : w * p].reshape(self.latent_channel, h, p, w, p) |
| 877 | + latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel) |
| 878 | + packed_latent.append(latent) |
| 879 | + packed_latent = torch.cat(packed_latent, dim=0) |
| 880 | + packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids) |
| 881 | + packed_timestep_embeds = self.time_embedder(packed_timesteps) |
| 882 | + packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed |
| 883 | + if packed_latent.dtype != packed_sequence.dtype: |
| 884 | + packed_latent = packed_latent.to(packed_sequence.dtype) |
| 885 | + packed_sequence[packed_vae_token_indexes] = packed_latent |
| 886 | + |
| 887 | + extra_inputs = {} |
| 888 | + if self.use_moe: |
| 889 | + extra_inputs = { |
| 890 | + "mode": "gen", |
| 891 | + "packed_vae_token_indexes": packed_vae_token_indexes, |
| 892 | + "packed_text_indexes": packed_text_indexes, |
| 893 | + } |
| 894 | + |
| 895 | + output = self.language_model.forward( |
| 896 | + packed_query_sequence=packed_sequence, |
| 897 | + query_lens=packed_seqlens, |
| 898 | + packed_query_position_ids=packed_position_ids, |
| 899 | + packed_query_indexes=packed_indexes, |
| 900 | + past_key_values=past_key_values, |
| 901 | + key_values_lens=key_values_lens, |
| 902 | + packed_key_value_indexes=packed_key_value_indexes, |
| 903 | + update_past_key_values=True, |
| 904 | + is_causal=False, |
| 905 | + **extra_inputs, |
| 906 | + ) |
| 907 | + past_key_values = output.past_key_values |
| 908 | + |
| 909 | + return past_key_values |
| 910 | + |
| 911 | + def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids): |
| 912 | + packed_vit_token_indexes = list() |
| 913 | + vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list() |
| 914 | + packed_text_ids, packed_text_indexes = list(), list() |
| 915 | + packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list() |
| 916 | + packed_key_value_indexes = list() |
| 917 | + |
| 918 | + _curr = curr = 0 |
| 919 | + newlens, new_rope = list(), list() |
| 920 | + for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope): |
| 921 | + packed_key_value_indexes.extend(range(curr, curr + curr_kvlen)) |
| 922 | + curr += curr_kvlen |
| 923 | + |
| 924 | + packed_text_ids.append(new_token_ids["start_of_image"]) |
| 925 | + packed_text_indexes.append(_curr) |
| 926 | + packed_indexes.append(curr) |
| 927 | + curr += 1 |
| 928 | + _curr += 1 |
| 929 | + |
| 930 | + image_tensor = transforms(image) |
| 931 | + vit_position_ids = self.get_flattened_position_ids( |
| 932 | + image_tensor.size(1), |
| 933 | + image_tensor.size(2), |
| 934 | + self.vit_patch_size, |
| 935 | + max_num_patches_per_side=self.vit_max_num_patch_per_side, |
| 936 | + ) |
| 937 | + vit_tokens = patchify(image_tensor, self.vit_patch_size) |
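| | +            # each row of vit_tokens is one flattened patch_size**2 * 3 patch |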
| 938 | + packed_vit_tokens.append(vit_tokens) |
| 939 | + num_img_tokens = vit_tokens.shape[0] |
| 940 | + packed_vit_position_ids.append(vit_position_ids) |
| 941 | + vit_token_seqlens.append(num_img_tokens) |
| 942 | + packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens)) |
| 943 | + packed_indexes.extend(range(curr, curr + num_img_tokens)) |
| 944 | + curr += num_img_tokens |
| 945 | + _curr += num_img_tokens |
| 946 | + |
| 947 | + packed_text_ids.append(new_token_ids["end_of_image"]) |
| 948 | + packed_text_indexes.append(_curr) |
| 949 | + packed_indexes.append(curr) |
| 950 | + curr += 1 |
| 951 | + _curr += 1 |
| 952 | + |
| 953 | + packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2)) |
| 954 | + packed_seqlens.append(num_img_tokens + 2) |
| 955 | + newlens.append(curr_kvlen + num_img_tokens + 2) |
| 956 | + new_rope.append(curr_position_id + 1) |
| 957 | + |
| 958 | + generation_input = { |
| 959 | + "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long), |
| 960 | + "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long), |
| 961 | + "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int), |
| 962 | + "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0), |
| 963 | + "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0), |
| 964 | + "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long), |
| 965 | + "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long), |
| 966 | + "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int), |
| 967 | + "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long), |
| 968 | + "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long), |
| 969 | + "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int), |
| 970 | + } |
| 971 | + |
| 972 | + return generation_input, newlens, new_rope |
| 973 | + |
| 974 | + @torch.no_grad |
| 975 | + def forward_cache_update_vit( |
| 976 | + self, |
| 977 | + past_key_values: NaiveCache, |
| 978 | + packed_text_ids: torch.LongTensor, |
| 979 | + packed_text_indexes: torch.LongTensor, |
| 980 | + packed_vit_tokens: torch.Tensor, |
| 981 | + packed_vit_token_indexes: torch.LongTensor, |
| 982 | + packed_vit_position_ids: torch.LongTensor, |
| 983 | + vit_token_seqlens: torch.IntTensor, |
| 984 | + packed_position_ids: torch.LongTensor, |
| 985 | + packed_seqlens: torch.IntTensor, |
| 986 | + packed_indexes: torch.LongTensor, |
| 987 | + packed_key_value_indexes: torch.LongTensor, |
| 988 | + key_values_lens: torch.IntTensor, |
| 989 | + ): |
| 990 | + packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids) |
| 991 | + packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size)) |
| 992 | + packed_sequence[packed_text_indexes] = packed_text_embedding |
| 993 | + |
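| | +        # Varlen-attention metadata: cumulative per-image token counts |
| | +        # (flash-attn style cu_seqlens) plus the longest single image |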
| 994 | + cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0)) |
| 995 | + cu_seqlens = cu_seqlens.to(torch.int32) |
| 996 | + max_seqlen = torch.max(vit_token_seqlens).item() |
| 997 | + packed_vit_token_embed = self.vit_model( |
| 998 | + packed_pixel_values=packed_vit_tokens, |
| 999 | + packed_flattened_position_ids=packed_vit_position_ids, |
| 1000 | + cu_seqlens=cu_seqlens, |
| 1001 | + max_seqlen=max_seqlen, |
| 1002 | + ) |
| 1003 | + packed_vit_token_embed = self.connector(packed_vit_token_embed) |
| 1004 | + pos_emb = self.vit_pos_embed(packed_vit_position_ids) |
| 1005 | + packed_vit_token_embed = packed_vit_token_embed + pos_emb |
| 1006 | + if packed_vit_token_embed.dtype != packed_sequence.dtype: |
| 1007 | + packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype) |
| 1008 | + packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed |
| 1009 | + |
| 1010 | + extra_inputs = {} |
| 1011 | + if self.use_moe: |
| 1012 | + extra_inputs = {"mode": "und"} |
| 1013 | + |
| 1014 | + output = self.language_model.forward( |
| 1015 | + packed_query_sequence=packed_sequence, |
| 1016 | + query_lens=packed_seqlens, |
| 1017 | + packed_query_position_ids=packed_position_ids, |
| 1018 | + packed_query_indexes=packed_indexes, |
| 1019 | + past_key_values=past_key_values, |
| 1020 | + packed_key_value_indexes=packed_key_value_indexes, |
| 1021 | + key_values_lens=key_values_lens, |
| 1022 | + update_past_key_values=True, |
| 1023 | + is_causal=False, |
| 1024 | + **extra_inputs, |
| 1025 | + ) |
| 1026 | + past_key_values = output.past_key_values |
| 1027 | + |
| 1028 | + return past_key_values |
| 1029 | + |
716 | 1030 | def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None): |
717 | 1031 | packed_text_ids, packed_text_indexes = list(), list() |
718 | 1032 | packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list() |
|