340 changes: 327 additions & 13 deletions vllm_omni/diffusion/models/bagel/bagel_transformer.py
@@ -29,6 +29,43 @@

from vllm_omni.diffusion.layers.rope import RotaryEmbedding


def patchify(imgs, p):
"""
imgs: (N, 3, H, W) or (3, H, W)
x: (N, L, patch_size**2 *3) or (L, patch_size**2 *3)
"""
is_batch = imgs.ndim == 4
if not is_batch:
imgs = imgs.unsqueeze(0)

# n: batch, c: channel, h: grid_h, p: patch_h, w: grid_w, q: patch_w
x = imgs.reshape(imgs.shape[0], 3, imgs.shape[2] // p, p, imgs.shape[3] // p, p)
# Permute to (n, grid_h, grid_w, c, patch_h, patch_w) to match Conv2d (c, h, w) flattening
x = torch.einsum("nchpwq->nhwcpq", x)
x = x.reshape(imgs.shape[0], -1, 3 * p**2)

if not is_batch:
x = x.squeeze(0)
return x


class MLPconnector(nn.Module):
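"""Two-layer MLP that projects ViT hidden states into the LLM hidden size."""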
def __init__(self, input_dim, output_dim, activation="gelu_pytorch_tanh"):
super().__init__()
self.fc1 = nn.Linear(input_dim, output_dim)
if activation == "gelu":
self.act = nn.GELU()
elif activation == "gelu_pytorch_tanh":
self.act = nn.GELU(approximate="tanh")
else:
self.act = nn.ReLU()
self.fc2 = nn.Linear(output_dim, output_dim)

def forward(self, x):
return self.fc2(self.act(self.fc1(x)))


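# Raise Dynamo cache limits so repeated flex_attention recompilations for many
# distinct sequence shapes do not hit the default cap (assumed intent).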
torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.accumulated_cache_size_limit = 4096
flex_attention = torch.compile(flex_attention)
@@ -600,51 +637,73 @@ def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_pat
class BagelConfig(PretrainedConfig):
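"""Configuration for Bagel: toggles visual generation/understanding and bundles
the LLM, ViT, and VAE sub-configs plus latent and ViT patch sizing options."""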
def __init__(
self,
visual_gen=True,
visual_und=True,
llm_config=None,
vit_config=None,
vae_config=None,
latent_patch_size=2,
max_latent_size=32,
vit_max_num_patch_per_side=70,
connector_act="gelu_pytorch_tanh",
interpolate_pos=False,
timestep_shift=1.0,
**kwargs,
):
super().__init__(**kwargs)
self.visual_gen = visual_gen
self.visual_und = visual_und
self.llm_config = llm_config
self.vit_config = vit_config
self.vae_config = vae_config
self.latent_patch_size = latent_patch_size
self.max_latent_size = max_latent_size
self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
self.connector_act = connector_act
self.interpolate_pos = interpolate_pos
self.timestep_shift = timestep_shift


class Bagel(torch.nn.Module):
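"""Bagel multimodal transformer. Wraps a packed-sequence language model and,
depending on the config, a VAE-latent generation path (visual_gen) and a
ViT-based image understanding path (visual_und)."""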
config_class = BagelConfig
base_model_prefix = "bagel"

def __init__(self, language_model, config: BagelConfig):
def __init__(self, language_model, vit_model, config: BagelConfig):
super().__init__()
self.language_model = language_model
self.hidden_size = config.llm_config.hidden_size
self.use_moe = "Mo" in config.llm_config.layer_module
self.num_heads = config.llm_config.num_attention_heads

self.latent_patch_size = config.latent_patch_size
self.timestep_shift = config.timestep_shift
self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
self.max_latent_size = config.max_latent_size
self.latent_channel = config.vae_config.z_channels
self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
self.time_embedder = TimestepEmbedder(self.hidden_size)
self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
if config.visual_gen:
self.latent_patch_size = config.latent_patch_size
self.timestep_shift = config.timestep_shift
self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
self.max_latent_size = config.max_latent_size
self.latent_channel = config.vae_config.z_channels
self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
self.time_embedder = TimestepEmbedder(self.hidden_size)
self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)

if config.visual_und:
self.vit_model = vit_model
self.vit_patch_size = config.vit_config.patch_size
self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
self.vit_hidden_size = config.vit_config.hidden_size
self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)

self.get_flattened_position_ids = get_flattened_position_ids_extrapolate

self.config = config
self._init_weights()

def _init_weights(self):
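# Zero-init the LLM-to-VAE projection: with zero weight and bias the initial
# latent prediction is zero (applies only when visual generation is enabled).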
nn.init.constant_(self.llm2vae.weight, 0)
nn.init.constant_(self.llm2vae.bias, 0)
if self.config.visual_gen:
nn.init.constant_(self.llm2vae.weight, 0)
nn.init.constant_(self.llm2vae.bias, 0)

def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
packed_text_ids = list()
Expand Down Expand Up @@ -713,6 +772,261 @@ def forward_cache_update_text(

return past_key_values

def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
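"""Pack a start_of_image token, the VAE latent tokens, and an end_of_image token
per image into one flat sequence, recording the text/VAE token indexes, position
ids, and KV-cache indexes consumed by forward_cache_update_vae."""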
patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
packed_vae_token_indexes = list()
packed_text_ids, packed_text_indexes = list(), list()
packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
packed_key_value_indexes = list()

_curr = curr = 0
vae_image_tensors = list()
newlens, new_rope = list(), list()
for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen

packed_text_ids.append(new_token_ids["start_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1

image_tensor = transforms(image)
vae_image_tensors.append(image_tensor)
vae_position_ids = self.get_flattened_position_ids(
image_tensor.size(1),
image_tensor.size(2),
self.latent_downsample,
max_num_patches_per_side=self.max_latent_size,
)
packed_vae_position_ids.append(vae_position_ids)
H, W = image_tensor.shape[1:]
h = H // self.latent_downsample
w = W // self.latent_downsample
patchified_vae_latent_shapes.append((h, w))

num_img_tokens = w * h
packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
packed_indexes.extend(range(curr, curr + num_img_tokens))
curr += num_img_tokens
_curr += num_img_tokens

packed_text_ids.append(new_token_ids["end_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1

packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
packed_seqlens.append(num_img_tokens + 2)
newlens.append(curr_kvlen + num_img_tokens + 2)
new_rope.append(curr_position_id + 1)

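# Zero-pad every image to the largest (H, W) in the batch so the VAE can encode
# them in a single forward pass; the true latent extent of each image is kept in
# patchified_vae_latent_shapes.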
image_sizes = [item.shape for item in vae_image_tensors]
max_image_size = [max(item) for item in list(zip(*image_sizes))]
padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
for i, image_tensor in enumerate(vae_image_tensors):
padded_images[i, :, : image_tensor.shape[1], : image_tensor.shape[2]] = image_tensor

generation_input = {
"padded_images": padded_images,
"patchified_vae_latent_shapes": patchified_vae_latent_shapes,
"packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
"packed_timesteps": torch.tensor([timestep]),
"packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}

return generation_input, newlens, new_rope

@torch.no_grad
def forward_cache_update_vae(
self,
vae_model,
past_key_values: NaiveCache,
padded_images: torch.Tensor,
patchified_vae_latent_shapes: list,
packed_vae_position_ids: torch.LongTensor,
packed_timesteps: torch.Tensor,
packed_vae_token_indexes: torch.LongTensor,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
packed_key_value_indexes: torch.Tensor,
):
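"""Encode the images with the VAE, patchify the latents, embed them with timestep
and position embeddings, scatter them (plus the image boundary text tokens) into a
packed sequence, and run one non-causal prefill to update the KV cache."""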
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
packed_sequence[packed_text_indexes] = packed_text_embedding

padded_latent = vae_model.encode(padded_images)

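# Crop each latent to its unpadded extent and flatten it into h * w tokens of
# size p * p * latent_channel, matching the vae2llm input layout.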
p = self.latent_patch_size
packed_latent = list()
for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
latent = latent[:, : h * p, : w * p].reshape(self.latent_channel, h, p, w, p)
latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
packed_latent.append(latent)
packed_latent = torch.cat(packed_latent, dim=0)
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
packed_timestep_embeds = self.time_embedder(packed_timesteps)
packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
if packed_latent.dtype != packed_sequence.dtype:
packed_latent = packed_latent.to(packed_sequence.dtype)
packed_sequence[packed_vae_token_indexes] = packed_latent

extra_inputs = {}
if self.use_moe:
extra_inputs = {
"mode": "gen",
"packed_vae_token_indexes": packed_vae_token_indexes,
"packed_text_indexes": packed_text_indexes,
}

output = self.language_model.forward(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=True,
is_causal=False,
**extra_inputs,
)
past_key_values = output.past_key_values

return past_key_values

def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
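"""Same packing scheme as prepare_vae_images, but images are split into ViT
patches via patchify() and indexed against the ViT position-embedding grid."""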
packed_vit_token_indexes = list()
vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
packed_text_ids, packed_text_indexes = list(), list()
packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
packed_key_value_indexes = list()

_curr = curr = 0
newlens, new_rope = list(), list()
for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen

packed_text_ids.append(new_token_ids["start_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1

image_tensor = transforms(image)
vit_position_ids = self.get_flattened_position_ids(
image_tensor.size(1),
image_tensor.size(2),
self.vit_patch_size,
max_num_patches_per_side=self.vit_max_num_patch_per_side,
)
vit_tokens = patchify(image_tensor, self.vit_patch_size)
packed_vit_tokens.append(vit_tokens)
num_img_tokens = vit_tokens.shape[0]
packed_vit_position_ids.append(vit_position_ids)
vit_token_seqlens.append(num_img_tokens)
packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
packed_indexes.extend(range(curr, curr + num_img_tokens))
curr += num_img_tokens
_curr += num_img_tokens

packed_text_ids.append(new_token_ids["end_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1

packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
packed_seqlens.append(num_img_tokens + 2)
newlens.append(curr_kvlen + num_img_tokens + 2)
new_rope.append(curr_position_id + 1)

generation_input = {
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
"packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
"packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
"packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}

return generation_input, newlens, new_rope

@torch.no_grad
def forward_cache_update_vit(
self,
past_key_values: NaiveCache,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_vit_tokens: torch.Tensor,
packed_vit_token_indexes: torch.LongTensor,
packed_vit_position_ids: torch.LongTensor,
vit_token_seqlens: torch.IntTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_indexes: torch.LongTensor,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
):
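"""Run the ViT over the packed image patches, project its features into the LLM
hidden size through the connector, scatter them into the packed sequence, and run
a non-causal prefill to update the KV cache."""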
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
packed_sequence[packed_text_indexes] = packed_text_embedding

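# Cumulative sequence lengths (with a leading 0) for variable-length attention
# over the packed ViT token sequences.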
cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
cu_seqlens = cu_seqlens.to(torch.int32)
max_seqlen = torch.max(vit_token_seqlens).item()
packed_vit_token_embed = self.vit_model(
packed_pixel_values=packed_vit_tokens,
packed_flattened_position_ids=packed_vit_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
packed_vit_token_embed = self.connector(packed_vit_token_embed)
pos_emb = self.vit_pos_embed(packed_vit_position_ids)
packed_vit_token_embed = packed_vit_token_embed + pos_emb
if packed_vit_token_embed.dtype != packed_sequence.dtype:
packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed

extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}

output = self.language_model.forward(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
key_values_lens=key_values_lens,
update_past_key_values=True,
is_causal=False,
**extra_inputs,
)
past_key_values = output.past_key_values

return past_key_values

def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None):
packed_text_ids, packed_text_indexes = list(), list()
packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()