17 changes: 16 additions & 1 deletion examples/offline_inference/bagel/end2end.py
@@ -46,6 +46,12 @@ def parse_args():
parser.add_argument("--stage-configs-path", type=str, default=None)
parser.add_argument("--steps", type=int, default=50, help="Number of inference steps.")

parser.add_argument("--cfg-text-scale", type=float, default=4.0, help="Text CFG scale (default: 4.0)")
parser.add_argument("--cfg-img-scale", type=float, default=1.5, help="Image CFG scale (default: 1.5)")
parser.add_argument(
"--negative-prompt", type=str, default=None, help="Negative prompt (not yet supported, reserved for future)"
)

args = parser.parse_args()
return args

@@ -102,6 +108,10 @@ def main():
seed=52,
need_kv_receive=False,
num_inference_steps=args.steps,
extra_args={
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
},
),
)

@@ -158,7 +168,12 @@ def main():
if args.modality == "text2img":
params_list[0].max_tokens = 1 # type: ignore # The first stage is a SamplingParam (vllm)
if len(params_list) > 1:
params_list[1].num_inference_steps = args.steps # type: ignore # The second stage is an OmniDiffusionSamplingParam
diffusion_params = params_list[1]
diffusion_params.num_inference_steps = args.steps # type: ignore
diffusion_params.extra_args = { # type: ignore
"cfg_text_scale": args.cfg_text_scale,
"cfg_img_scale": args.cfg_img_scale,
}

omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))

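For context, a hypothetical invocation of the updated example with the new flags might look like the following (the values match the argparse defaults added above; the existing option is assumed to be spelled --modality, matching args.modality in the script):

python examples/offline_inference/bagel/end2end.py \
    --modality text2img \
    --steps 50 \
    --cfg-text-scale 4.0 \
    --cfg-img-scale 1.5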
177 changes: 157 additions & 20 deletions vllm_omni/diffusion/models/bagel/bagel_transformer.py
@@ -10,7 +10,6 @@
import math
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
@@ -1251,6 +1250,39 @@ def prepare_input(self, curr_kvlens, curr_rope, image_sizes, new_token_ids=None)
def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
return self.prepare_input(curr_kvlens, curr_rope, image_sizes, new_token_ids)

def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()

query_curr = curr = 0
for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen

packed_indexes.append(curr)
curr += 1
query_curr += 1

h, w = H // self.latent_downsample, W // self.latent_downsample
num_image_tokens = h * w
packed_indexes.extend(range(curr, curr + num_image_tokens))
curr += num_image_tokens
query_curr += num_image_tokens

packed_indexes.append(curr)
curr += 1
query_curr += 1

packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))

generation_input = {
"cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
"cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
}

return generation_input

def generate_image(
self,
packed_text_ids: torch.LongTensor,
@@ -1266,11 +1298,25 @@ def generate_image(
packed_key_value_indexes: torch.LongTensor,
num_timesteps: int = 24,
timestep_shift: float = 1.0,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
cfg_interval: tuple[float, float] = (0.0, 1.0),
# cfg_text
cfg_text_scale: float = 1.0,
cfg_text_packed_query_indexes: torch.LongTensor | None = None,
cfg_text_packed_position_ids: torch.LongTensor | None = None,
cfg_text_past_key_values: NaiveCache | None = None,
cfg_text_key_values_lens: torch.IntTensor | None = None,
cfg_text_packed_key_value_indexes: torch.LongTensor | None = None,
# cfg_img
cfg_img_scale: float = 1.0,
cfg_img_packed_query_indexes: torch.LongTensor | None = None,
cfg_img_packed_position_ids: torch.LongTensor | None = None,
cfg_img_past_key_values: NaiveCache | None = None,
cfg_img_key_values_lens: torch.IntTensor | None = None,
cfg_img_packed_key_value_indexes: torch.LongTensor | None = None,
cfg_type: str = "parallel",
):
model_pred_cache_dic, model_pred_current = None, None
model_pred_text_cache_dic, model_pred_text_current = None, None
model_pred_img_cache_dic, model_pred_img_current = None, None

x_t = packed_init_noises

timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
@@ -1280,6 +1326,12 @@

for i, t in enumerate(timesteps):
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
if t > cfg_interval[0] and t <= cfg_interval[1]:
cfg_text_scale_ = cfg_text_scale
cfg_img_scale_ = cfg_img_scale
else:
cfg_text_scale_ = 1.0
cfg_img_scale_ = 1.0
v_t = self._forward_flow(
x_t=x_t,
timestep=timestep,
@@ -1293,17 +1345,26 @@
key_values_lens=key_values_lens,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
# cache
model_pred_cache_dic=model_pred_cache_dic,
model_pred_current=model_pred_current,
model_pred_text_cache_dic=model_pred_text_cache_dic,
model_pred_text_current=model_pred_text_current,
model_pred_img_cache_dic=model_pred_img_cache_dic,
model_pred_img_current=model_pred_img_current,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
# cfg_text
cfg_text_scale=cfg_text_scale_,
cfg_text_packed_position_ids=cfg_text_packed_position_ids,
cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
cfg_text_key_values_lens=cfg_text_key_values_lens,
cfg_text_past_key_values=cfg_text_past_key_values,
cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
# cfg_img
cfg_img_scale=cfg_img_scale_,
cfg_img_packed_position_ids=cfg_img_packed_position_ids,
cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
cfg_img_key_values_lens=cfg_img_key_values_lens,
cfg_img_past_key_values=cfg_img_past_key_values,
cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
cfg_type=cfg_type,
)

x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise

unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
return unpacked_latent

@@ -1321,13 +1382,23 @@ def _forward_flow(
key_values_lens: torch.IntTensor,
past_key_values: NaiveCache,
packed_key_value_indexes: torch.LongTensor,
# cache
model_pred_cache_dic: dict[str, Any] | None = None,
model_pred_current: int | None = None,
model_pred_text_cache_dic: dict[str, Any] | None = None,
model_pred_text_current: int | None = None,
model_pred_img_cache_dic: dict[str, Any] | None = None,
model_pred_img_current: int | None = None,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
# cfg_text
cfg_text_scale: float = 1.0,
cfg_text_packed_position_ids: torch.LongTensor | None = None,
cfg_text_packed_query_indexes: torch.LongTensor | None = None,
cfg_text_key_values_lens: torch.Tensor | None = None,
cfg_text_past_key_values: NaiveCache | None = None,
cfg_text_packed_key_value_indexes: torch.LongTensor | None = None,
# cfg_img
cfg_img_scale: float = 1.0,
cfg_img_packed_position_ids: torch.LongTensor | None = None,
cfg_img_packed_query_indexes: torch.LongTensor | None = None,
cfg_img_key_values_lens: torch.Tensor | None = None,
cfg_img_past_key_values: NaiveCache | None = None,
cfg_img_packed_key_value_indexes: torch.LongTensor | None = None,
cfg_type: str = "parallel",
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
@@ -1364,4 +1435,70 @@ def _forward_flow(
v_t = self.llm2vae(output.packed_query_sequence)
v_t = v_t[packed_vae_token_indexes]

if cfg_text_scale > 1.0:
cfg_text_output = self.language_model.forward(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=cfg_text_packed_position_ids,
packed_query_indexes=cfg_text_packed_query_indexes,
past_key_values=cfg_text_past_key_values,
key_values_lens=cfg_text_key_values_lens,
packed_key_value_indexes=cfg_text_packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]

if cfg_img_scale > 1.0:
cfg_img_output = self.language_model.forward(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=cfg_img_packed_position_ids,
packed_query_indexes=cfg_img_packed_query_indexes,
past_key_values=cfg_img_past_key_values,
key_values_lens=cfg_img_key_values_lens,
packed_key_value_indexes=cfg_img_packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]

if cfg_text_scale > 1.0:
if cfg_renorm_type == "text_channel":
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
v_t_text = v_t_text_ * scale
if cfg_img_scale > 1.0:
v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
else:
v_t = v_t_text
else:
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)

if cfg_img_scale > 1.0:
v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
else:
v_t_ = v_t_text_

# NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
if cfg_renorm_type == "global":
norm_v_t = torch.norm(v_t)
norm_v_t_ = torch.norm(v_t_)
elif cfg_renorm_type == "channel":
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
else:
raise NotImplementedError(f"{cfg_renorm_type} is not supported")
scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
v_t = v_t_ * scale
else:
# No CFG
pass

return v_t
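For readers skimming the diff, the guidance combination that _forward_flow now performs for the "global" and "channel" renorm types can be summarized by this standalone sketch (a simplification: the "text_channel" branch and the packed-sequence bookkeeping above are omitted, and the function and variable names are illustrative, not part of the PR):

import torch

def combine_cfg_velocities(
    v_cond: torch.Tensor,      # velocity from the fully conditioned pass (v_t above)
    v_cfg_text: torch.Tensor,  # velocity from the text-CFG pass (cfg_text_v_t above)
    v_cfg_img: torch.Tensor,   # velocity from the image-CFG pass (cfg_img_v_t above)
    cfg_text_scale: float = 4.0,
    cfg_img_scale: float = 1.5,
    cfg_renorm_type: str = "global",
    cfg_renorm_min: float = 0.0,
) -> torch.Tensor:
    if cfg_text_scale <= 1.0:
        return v_cond  # no CFG: use the conditioned velocity as-is
    # 1) apply text guidance, then stack image guidance on top of it
    v_text = v_cfg_text + cfg_text_scale * (v_cond - v_cfg_text)
    v_guided = v_cfg_img + cfg_img_scale * (v_text - v_cfg_img) if cfg_img_scale > 1.0 else v_text
    # 2) renormalize so the guided velocity's norm never exceeds the conditioned one
    if cfg_renorm_type == "global":
        norm_ref, norm_guided = torch.norm(v_cond), torch.norm(v_guided)
    elif cfg_renorm_type == "channel":
        norm_ref = torch.norm(v_cond, dim=-1, keepdim=True)
        norm_guided = torch.norm(v_guided, dim=-1, keepdim=True)
    else:
        raise NotImplementedError(f"{cfg_renorm_type} is not supported")
    scale = (norm_ref / (norm_guided + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
    return v_guided * scale

This mirrors the behavior selected in generate_image, where both scales fall back to 1.0 for timesteps outside cfg_interval, effectively disabling guidance there.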