
Commit 205ad8f

princepride authored and DefTruth committed

[Bagel] Add image edit (vllm-project#588)

Signed-off-by: princepride <[email protected]>
Signed-off-by: DefTruth <[email protected]>
1 parent cd62359 commit 205ad8f

2 files changed: 477 additions & 18 deletions

File tree

vllm_omni/diffusion/models/bagel/bagel_transformer.py

Lines changed: 327 additions & 13 deletions
@@ -29,6 +29,43 @@
 
 from vllm_omni.diffusion.layers.rope import RotaryEmbedding
 
+
+def patchify(imgs, p):
+    """
+    imgs: (N, 3, H, W) or (3, H, W)
+    x: (N, L, patch_size**2 * 3) or (L, patch_size**2 * 3)
+    """
+    is_batch = imgs.ndim == 4
+    if not is_batch:
+        imgs = imgs.unsqueeze(0)
+
+    # n: batch, c: channel, h: grid_h, p: patch_h, w: grid_w, q: patch_w
+    x = imgs.reshape(imgs.shape[0], 3, imgs.shape[2] // p, p, imgs.shape[3] // p, p)
+    # Permute to (n, grid_h, grid_w, c, patch_h, patch_w) to match Conv2d (c, h, w) flattening
+    x = torch.einsum("nchpwq->nhwcpq", x)
+    x = x.reshape(imgs.shape[0], -1, 3 * p**2)
+
+    if not is_batch:
+        x = x.squeeze(0)
+    return x
+
+
+class MLPconnector(nn.Module):
+    def __init__(self, input_dim, output_dim, activation="gelu_pytorch_tanh"):
+        super().__init__()
+        self.fc1 = nn.Linear(input_dim, output_dim)
+        if activation == "gelu":
+            self.act = nn.GELU()
+        elif activation == "gelu_pytorch_tanh":
+            self.act = nn.GELU(approximate="tanh")
+        else:
+            self.act = nn.ReLU()
+        self.fc2 = nn.Linear(output_dim, output_dim)
+
+    def forward(self, x):
+        return self.fc2(self.act(self.fc1(x)))
+
+
 torch._dynamo.config.cache_size_limit = 512
 torch._dynamo.config.accumulated_cache_size_limit = 4096
 flex_attention = torch.compile(flex_attention)
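For context (not part of the diff): patchify produces the same flattening a Conv2d patch-embedding stem would, which is why the permute order in the comment above matters. A minimal shape sketch, assuming a hypothetical 224×224 RGB input and patch size 14:

import torch

# Hypothetical sizes, for illustration only.
img = torch.randn(3, 224, 224)        # (C, H, W), unbatched
tokens = patchify(img, p=14)          # (224 // 14)**2 = 256 patches, each 14 * 14 * 3 = 588 values
assert tokens.shape == (256, 588)

batch = torch.randn(2, 3, 224, 224)   # (N, C, H, W), batched
assert patchify(batch, p=14).shape == (2, 256, 588)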
@@ -600,51 +637,73 @@ def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side)
 class BagelConfig(PretrainedConfig):
     def __init__(
         self,
+        visual_gen=True,
+        visual_und=True,
         llm_config=None,
+        vit_config=None,
         vae_config=None,
         latent_patch_size=2,
         max_latent_size=32,
+        vit_max_num_patch_per_side=70,
+        connector_act="gelu_pytorch_tanh",
+        interpolate_pos=False,
         timestep_shift=1.0,
         **kwargs,
     ):
         super().__init__(**kwargs)
+        self.visual_gen = visual_gen
+        self.visual_und = visual_und
         self.llm_config = llm_config
+        self.vit_config = vit_config
         self.vae_config = vae_config
         self.latent_patch_size = latent_patch_size
         self.max_latent_size = max_latent_size
+        self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
+        self.connector_act = connector_act
+        self.interpolate_pos = interpolate_pos
         self.timestep_shift = timestep_shift
 
 
 class Bagel(torch.nn.Module):
     config_class = BagelConfig
     base_model_prefix = "bagel"
 
-    def __init__(self, language_model, config: BagelConfig):
+    def __init__(self, language_model, vit_model, config: BagelConfig):
         super().__init__()
         self.language_model = language_model
         self.hidden_size = config.llm_config.hidden_size
         self.use_moe = "Mo" in config.llm_config.layer_module
         self.num_heads = config.llm_config.num_attention_heads
 
-        self.latent_patch_size = config.latent_patch_size
-        self.timestep_shift = config.timestep_shift
-        self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
-        self.max_latent_size = config.max_latent_size
-        self.latent_channel = config.vae_config.z_channels
-        self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
-        self.time_embedder = TimestepEmbedder(self.hidden_size)
-        self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
-        self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
-        self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
+        if config.visual_gen:
+            self.latent_patch_size = config.latent_patch_size
+            self.timestep_shift = config.timestep_shift
+            self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
+            self.max_latent_size = config.max_latent_size
+            self.latent_channel = config.vae_config.z_channels
+            self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
+            self.time_embedder = TimestepEmbedder(self.hidden_size)
+            self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
+            self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
+            self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
+
+        if config.visual_und:
+            self.vit_model = vit_model
+            self.vit_patch_size = config.vit_config.patch_size
+            self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
+            self.vit_hidden_size = config.vit_config.hidden_size
+            self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
+            self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
 
         self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
 
         self.config = config
         self._init_weights()
 
     def _init_weights(self):
-        nn.init.constant_(self.llm2vae.weight, 0)
-        nn.init.constant_(self.llm2vae.bias, 0)
+        if self.config.visual_gen:
+            nn.init.constant_(self.llm2vae.weight, 0)
+            nn.init.constant_(self.llm2vae.bias, 0)
 
     def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
         packed_text_ids = list()
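For context (not part of the diff): the two new flags let either tower be built independently; visual_gen gates the VAE/diffusion (generation and editing) path and visual_und gates the ViT (understanding) path. A minimal construction sketch, with the sub-config attribute values and the language_model/vit_model objects assumed rather than taken from this repo:

from types import SimpleNamespace

# Hypothetical values; only the fields Bagel.__init__ reads are filled in.
llm_config = SimpleNamespace(hidden_size=3584, layer_module="Qwen2MoTDecoderLayer", num_attention_heads=28)
vit_config = SimpleNamespace(patch_size=14, hidden_size=1152)
vae_config = SimpleNamespace(downsample=8, z_channels=16)

config = BagelConfig(
    visual_gen=True,    # build vae2llm/llm2vae and the latent embeddings
    visual_und=True,    # build the ViT connector and its position embedding
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
)
model = Bagel(language_model, vit_model, config)  # language_model / vit_model supplied by the caller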
@@ -713,6 +772,261 @@ def forward_cache_update_text(
 
         return past_key_values
 
+    def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
+        patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
+        packed_vae_token_indexes = list()
+        packed_text_ids, packed_text_indexes = list(), list()
+        packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
+        packed_key_value_indexes = list()
+
+        _curr = curr = 0
+        vae_image_tensors = list()
+        newlens, new_rope = list(), list()
+        for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
+            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
+            curr += curr_kvlen
+
+            packed_text_ids.append(new_token_ids["start_of_image"])
+            packed_text_indexes.append(_curr)
+            packed_indexes.append(curr)
+            curr += 1
+            _curr += 1
+
+            image_tensor = transforms(image)
+            vae_image_tensors.append(image_tensor)
+            vae_position_ids = self.get_flattened_position_ids(
+                image_tensor.size(1),
+                image_tensor.size(2),
+                self.latent_downsample,
+                max_num_patches_per_side=self.max_latent_size,
+            )
+            packed_vae_position_ids.append(vae_position_ids)
+            H, W = image_tensor.shape[1:]
+            h = H // self.latent_downsample
+            w = W // self.latent_downsample
+            patchified_vae_latent_shapes.append((h, w))
+
+            num_img_tokens = w * h
+            packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
+            packed_indexes.extend(range(curr, curr + num_img_tokens))
+            curr += num_img_tokens
+            _curr += num_img_tokens
+
+            packed_text_ids.append(new_token_ids["end_of_image"])
+            packed_text_indexes.append(_curr)
+            packed_indexes.append(curr)
+            curr += 1
+            _curr += 1
+
+            packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
+            packed_seqlens.append(num_img_tokens + 2)
+            newlens.append(curr_kvlen + num_img_tokens + 2)
+            new_rope.append(curr_position_id + 1)
+
+        image_sizes = [item.shape for item in vae_image_tensors]
+        max_image_size = [max(item) for item in list(zip(*image_sizes))]
+        padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
+        for i, image_tensor in enumerate(vae_image_tensors):
+            padded_images[i, :, : image_tensor.shape[1], : image_tensor.shape[2]] = image_tensor
+
+        generation_input = {
+            "padded_images": padded_images,
+            "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
+            "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
+            "packed_timesteps": torch.tensor([timestep]),
+            "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
+            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
+            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
+            "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
+            "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
+            "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
+            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
+            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
+        }
+
+        return generation_input, newlens, new_rope
+
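For intuition (not part of the diff): each image contributes h * w latent tokens plus the two delimiter tokens, and all of those positions reuse the image's single RoPE position id. A worked example with hypothetical numbers:

# vae downsample 8 and latent_patch_size 2 give 16 pixels per latent-token side
latent_downsample = 8 * 2
h = w = 512 // latent_downsample     # 32 x 32 latent grid for a 512x512 image
num_img_tokens = h * w               # 1024 VAE tokens
seqlen = num_img_tokens + 2          # + start_of_image / end_of_image = 1026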
+    @torch.no_grad
+    def forward_cache_update_vae(
+        self,
+        vae_model,
+        past_key_values: NaiveCache,
+        padded_images: torch.Tensor,
+        patchified_vae_latent_shapes: list,
+        packed_vae_position_ids: torch.LongTensor,
+        packed_timesteps: torch.Tensor,
+        packed_vae_token_indexes: torch.LongTensor,
+        packed_text_ids: torch.LongTensor,
+        packed_text_indexes: torch.LongTensor,
+        packed_position_ids: torch.LongTensor,
+        packed_seqlens: torch.IntTensor,
+        packed_indexes: torch.LongTensor,
+        key_values_lens: torch.IntTensor,
+        packed_key_value_indexes: torch.Tensor,
+    ):
+        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
+        packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
+        packed_sequence[packed_text_indexes] = packed_text_embedding
+
+        padded_latent = vae_model.encode(padded_images)
+
+        p = self.latent_patch_size
+        packed_latent = list()
+        for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
+            latent = latent[:, : h * p, : w * p].reshape(self.latent_channel, h, p, w, p)
+            latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
+            packed_latent.append(latent)
+        packed_latent = torch.cat(packed_latent, dim=0)
+        packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
+        packed_timestep_embeds = self.time_embedder(packed_timesteps)
+        packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
+        if packed_latent.dtype != packed_sequence.dtype:
+            packed_latent = packed_latent.to(packed_sequence.dtype)
+        packed_sequence[packed_vae_token_indexes] = packed_latent
+
+        extra_inputs = {}
+        if self.use_moe:
+            extra_inputs = {
+                "mode": "gen",
+                "packed_vae_token_indexes": packed_vae_token_indexes,
+                "packed_text_indexes": packed_text_indexes,
+            }
+
+        output = self.language_model.forward(
+            packed_query_sequence=packed_sequence,
+            query_lens=packed_seqlens,
+            packed_query_position_ids=packed_position_ids,
+            packed_query_indexes=packed_indexes,
+            past_key_values=past_key_values,
+            key_values_lens=key_values_lens,
+            packed_key_value_indexes=packed_key_value_indexes,
+            update_past_key_values=True,
+            is_causal=False,
+            **extra_inputs,
+        )
+        past_key_values = output.past_key_values
+
+        return past_key_values
+
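For context (not part of the diff): image editing conditions the KV cache on the source image by chaining the two methods above, since the dict returned by prepare_vae_images unpacks directly into forward_cache_update_vae. A hedged sketch of that flow, with vae_model, transforms, new_token_ids, and the running cache assumed from the surrounding generation code:

past_key_values = ...  # existing NaiveCache built from the text prompt
gen_input, newlens, new_rope = model.prepare_vae_images(
    curr_kvlens, curr_rope, [edit_image], transforms, new_token_ids, timestep=0,
)
past_key_values = model.forward_cache_update_vae(vae_model, past_key_values, **gen_input)
curr_kvlens, curr_rope = newlens, new_rope  # later segments append after the packed image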
+    def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
+        packed_vit_token_indexes = list()
+        vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
+        packed_text_ids, packed_text_indexes = list(), list()
+        packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
+        packed_key_value_indexes = list()
+
+        _curr = curr = 0
+        newlens, new_rope = list(), list()
+        for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
+            packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
+            curr += curr_kvlen
+
+            packed_text_ids.append(new_token_ids["start_of_image"])
+            packed_text_indexes.append(_curr)
+            packed_indexes.append(curr)
+            curr += 1
+            _curr += 1
+
+            image_tensor = transforms(image)
+            vit_position_ids = self.get_flattened_position_ids(
+                image_tensor.size(1),
+                image_tensor.size(2),
+                self.vit_patch_size,
+                max_num_patches_per_side=self.vit_max_num_patch_per_side,
+            )
+            vit_tokens = patchify(image_tensor, self.vit_patch_size)
+            packed_vit_tokens.append(vit_tokens)
+            num_img_tokens = vit_tokens.shape[0]
+            packed_vit_position_ids.append(vit_position_ids)
+            vit_token_seqlens.append(num_img_tokens)
+            packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
+            packed_indexes.extend(range(curr, curr + num_img_tokens))
+            curr += num_img_tokens
+            _curr += num_img_tokens
+
+            packed_text_ids.append(new_token_ids["end_of_image"])
+            packed_text_indexes.append(_curr)
+            packed_indexes.append(curr)
+            curr += 1
+            _curr += 1
+
+            packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
+            packed_seqlens.append(num_img_tokens + 2)
+            newlens.append(curr_kvlen + num_img_tokens + 2)
+            new_rope.append(curr_position_id + 1)
+
+        generation_input = {
+            "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
+            "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
+            "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
+            "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
+            "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
+            "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
+            "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
+            "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
+            "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
+            "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
+            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
+        }
+
+        return generation_input, newlens, new_rope
+
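For intuition (not part of the diff): unlike the VAE path, the ViT path packs raw patchified pixels (via patchify above) and needs no batch padding, because the per-image token runs are concatenated and later delimited by cu_seqlens. A worked example with hypothetical numbers:

image_tensor = transforms(image)      # say (3, 392, 392) after resizing
num_img_tokens = (392 // 14) ** 2     # 784 patch tokens for vit_patch_size = 14
seqlen = num_img_tokens + 2           # plus start_of_image / end_of_image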
+    @torch.no_grad
+    def forward_cache_update_vit(
+        self,
+        past_key_values: NaiveCache,
+        packed_text_ids: torch.LongTensor,
+        packed_text_indexes: torch.LongTensor,
+        packed_vit_tokens: torch.Tensor,
+        packed_vit_token_indexes: torch.LongTensor,
+        packed_vit_position_ids: torch.LongTensor,
+        vit_token_seqlens: torch.IntTensor,
+        packed_position_ids: torch.LongTensor,
+        packed_seqlens: torch.IntTensor,
+        packed_indexes: torch.LongTensor,
+        packed_key_value_indexes: torch.LongTensor,
+        key_values_lens: torch.IntTensor,
+    ):
+        packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
+        packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
+        packed_sequence[packed_text_indexes] = packed_text_embedding
+
+        cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
+        cu_seqlens = cu_seqlens.to(torch.int32)
+        max_seqlen = torch.max(vit_token_seqlens).item()
+        packed_vit_token_embed = self.vit_model(
+            packed_pixel_values=packed_vit_tokens,
+            packed_flattened_position_ids=packed_vit_position_ids,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        packed_vit_token_embed = self.connector(packed_vit_token_embed)
+        pos_emb = self.vit_pos_embed(packed_vit_position_ids)
+        packed_vit_token_embed = packed_vit_token_embed + pos_emb
+        if packed_vit_token_embed.dtype != packed_sequence.dtype:
+            packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
+        packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
+
+        extra_inputs = {}
+        if self.use_moe:
+            extra_inputs = {"mode": "und"}
+
+        output = self.language_model.forward(
+            packed_query_sequence=packed_sequence,
+            query_lens=packed_seqlens,
+            packed_query_position_ids=packed_position_ids,
+            packed_query_indexes=packed_indexes,
+            past_key_values=past_key_values,
+            packed_key_value_indexes=packed_key_value_indexes,
+            key_values_lens=key_values_lens,
+            update_past_key_values=True,
+            is_causal=False,
+            **extra_inputs,
+        )
+        past_key_values = output.past_key_values
+
+        return past_key_values
+
     def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None):
         packed_text_ids, packed_text_indexes = list(), list()
         packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
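For context (not part of the diff): the understanding path mirrors the VAE flow above, and the dict returned by prepare_vit_images likewise unpacks directly into forward_cache_update_vit. A hedged sketch, with the reference image and ViT transforms assumed:

gen_input, newlens, new_rope = model.prepare_vit_images(
    curr_kvlens, curr_rope, [ref_image], vit_transforms, new_token_ids,
)
past_key_values = model.forward_cache_update_vit(past_key_values, **gen_input)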

0 commit comments
