I noticed that the text-to-image demo provided with JanusFlow is not working: with recent versions of transformers, past_key_values is expected to be a Cache object rather than a legacy tuple. The same problem has been reported in #137 and #77, but the workarounds proposed there do not seem to keep using the KV cache.
Based on my understanding, I implemented a version that keeps the KV cache by wrapping it in a DynamicCache. I'm still a beginner, so I'm not sure the implementation is correct; for now the output looks the same as when running without the KV cache.
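For reference, the core of the change is just converting between the legacy per-layer (key, value) tuples and a DynamicCache before feeding the cache back to the model. A minimal sketch (the names legacy_past, model and new_embeds are placeholders, not from the demo):

from transformers import DynamicCache

# legacy_past: tuple of per-layer (key, value) tensors
cache = DynamicCache.from_legacy_cache(legacy_past)
outputs = model(inputs_embeds=new_embeds, past_key_values=cache, use_cache=True)
# outputs.past_key_values is a Cache; .to_legacy_cache() converts back to tuples if needed

The full function I ended up with: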
import os

import torch
import torchvision.utils
from transformers import DynamicCache
# import path for the JanusFlow classes as in my checkout of the Janus repo; adjust if yours differs
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor


@torch.inference_mode()
def generate(
    vl_gpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    cfg_weight: float = 5.0,
    num_inference_steps: int = 30,
    batchsize: int = 5,
):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)
    # duplicate the prompt: first half is the conditional batch, second half the CFG (unconditional) batch
    tokens = torch.stack([input_ids] * 2 * batchsize).cuda()
    tokens[batchsize:, 1:] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)

    # we remove the last <bog> token and replace it with t_emb later
    inputs_embeds = inputs_embeds[:, :-1, :]
    prompt_len = inputs_embeds.shape[1]
    prompt_attention_mask = torch.ones((2 * batchsize, prompt_len), dtype=torch.int, device=vl_gpt.device)

    # run the prompt once and keep its KV cache as a Cache object for all flow steps
    prompt_outputs = vl_gpt.language_model.model(
        inputs_embeds=inputs_embeds,
        use_cache=True,
        attention_mask=prompt_attention_mask,
        past_key_values=None,
    )
    prompt_cache = DynamicCache.from_legacy_cache(prompt_outputs.past_key_values)
    # generate with rectified flow ode
    # step 1: encode with vision_gen_enc
    torch.manual_seed(42)
    z = torch.randn((batchsize, 4, 48, 48), dtype=torch.bfloat16).cuda()

    dt = 1.0 / num_inference_steps
    dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt

    # step 2: run ode
    # sequence length per step: prompt tokens + 1 time embedding + 576 latent tokens (24x24)
    attention_mask = torch.ones((2 * batchsize, inputs_embeds.shape[1] + 577)).to(vl_gpt.device)
    attention_mask[batchsize:, 1:inputs_embeds.shape[1]] = 0  # hide the padded prompt in the CFG batch
    attention_mask = attention_mask.int()
    for step in range(num_inference_steps):
        # prepare inputs for the llm
        z_input = torch.cat([z, z], dim=0)
        t = step / num_inference_steps * 1000.
        t = torch.tensor([t] * z_input.shape[0]).to(dt)
        z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
        z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
        z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
        z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
        llm_emb_dyn = torch.cat([t_emb.unsqueeze(1), z_emb], dim=1)

        # reuse the cached prompt keys/values; only the time and latent tokens are recomputed
        outputs = vl_gpt.language_model.model(
            inputs_embeds=llm_emb_dyn,
            attention_mask=attention_mask,
            use_cache=True,
            past_key_values=prompt_cache,
        )

        # crop the cache back to the prompt length so the next step does not attend to stale latent tokens
        # (newer transformers versions also provide DynamicCache.crop(), which might replace this manual slicing)
        past_key_values = []
        for kv_cache in outputs.past_key_values:
            k, v = kv_cache[0], kv_cache[1]
            past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
        prompt_cache = DynamicCache.from_legacy_cache(past_key_values)
        hidden_states = outputs.last_hidden_state

        # transform hidden_states back to v
        hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
        hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
        v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
        v_cond, v_uncond = torch.chunk(v, 2)
        v = cfg_weight * v_cond - (cfg_weight - 1.) * v_uncond
        z = z + dt * v  # Euler step of the rectified-flow ODE
    # decode the latents with the SDXL VAE (loaded at module level in the original demo) and save
    decoded_image = vae.decode(z / vae.config.scaling_factor).sample
    os.makedirs('generated_samples', exist_ok=True)
    save_path = os.path.join('generated_samples', "img2.jpg")
    torchvision.utils.save_image(decoded_image.clip_(-1.0, 1.0) * 0.5 + 0.5, save_path)
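For completeness, this is roughly how I run it, assuming the imports and the generate function above are in the same script. The loading code follows the original JanusFlow demo (model path deepseek-ai/JanusFlow-1.3B and the SDXL VAE); the demo's prompt formatting (chat template plus image-start tag) is omitted here, so treat this as a sketch rather than the exact script:

import torch
from diffusers.models import AutoencoderKL
from transformers import AutoModelForCausalLM

model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)

vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

# generate() reads this module-level vae to decode the 48x48 latents into pixels
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae = vae.to(torch.bfloat16).cuda().eval()

generate(vl_gpt, vl_chat_processor, prompt="A photo of a cat in a garden.", batchsize=1)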