KVCache in JanusFlow #218

@TtuHamg

Description

I noticed that the text-to-image demo provided by JanusFlow is no longer functioning: recent versions of transformers expect past_key_values to be passed as a Cache object rather than a legacy tuple. The same problem has been reported in related issues (#137 and #77), but the solutions proposed there do not adopt the KVCache approach.
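
To make the failure mode concrete: transformers ships DynamicCache.from_legacy_cache / to_legacy_cache to convert between the two formats. A minimal sketch with placeholder shapes (the layer count and tensor dimensions here are made up, not JanusFlow's):

import torch
from transformers import DynamicCache

# legacy format: a tuple over layers of (key, value) pairs,
# each of shape (batch, num_heads, seq_len, head_dim) -- placeholder sizes
legacy = tuple(
    (torch.zeros(1, 16, 8, 64), torch.zeros(1, 16, 8, 64)) for _ in range(24)
)
cache = DynamicCache.from_legacy_cache(legacy)  # Cache object newer transformers expect
legacy_again = cache.to_legacy_cache()          # convert back to the tuple format if needed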

Based on my understanding, I implemented a version using KVCache. I'm still a beginner, so I'm not sure whether this implementation is correct; so far the output appears identical to running without the KV cache.

import os

import torch
import torchvision
from transformers import DynamicCache

# imports below follow the official JanusFlow demo
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor


@torch.inference_mode()
def generate(
    mmgpt: MultiModalityCausalLM,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    cfg_weight: float = 5.0,
    num_inference_steps: int = 30,
    batchsize: int = 5,
):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    # duplicate the batch for CFG: first half conditional, second half
    # unconditional (everything after the first token is padded out)
    tokens = torch.stack([input_ids] * 2 * batchsize).cuda()
    tokens[batchsize:, 1:] = vl_chat_processor.pad_id
    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    # we remove the last <bog> token and replace it with t_emb later
    inputs_embeds = inputs_embeds[:, :-1, :]

    prompt_len = inputs_embeds.shape[1]
    prompt_attention_mask = torch.ones((2 * batchsize, prompt_len), dtype=torch.int, device=mmgpt.device)

    # run the prompt once and keep its KV cache; it is reused at every ODE step
    prompt_outputs = mmgpt.language_model.model(
        inputs_embeds=inputs_embeds,
        use_cache=True,
        attention_mask=prompt_attention_mask,
        past_key_values=None,
    )
    prompt_cache = DynamicCache.from_legacy_cache(prompt_outputs.past_key_values)
    
    # generate with the rectified-flow ODE
    # step 1: sample the initial latent noise and set the step size
    torch.manual_seed(42)
    z = torch.randn((batchsize, 4, 48, 48), dtype=torch.bfloat16).cuda()

    dt = 1.0 / num_inference_steps
    dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt

    # step 2: run the ODE
    # the mask covers the cached prompt plus the 577 new tokens (1 time
    # embedding + 576 latent tokens); the unconditional half must not
    # attend to the prompt tokens
    attention_mask = torch.ones((2 * batchsize, inputs_embeds.shape[1] + 577)).to(mmgpt.device)
    attention_mask[batchsize:, 1:inputs_embeds.shape[1]] = 0
    attention_mask = attention_mask.int()
    
    for step in range(num_inference_steps):
        # prepare inputs for the llm: encode the current latent with vision_gen_enc
        z_input = torch.cat([z, z], dim=0)
        t = step / num_inference_steps * 1000.
        t = torch.tensor([t] * z_input.shape[0]).to(dt)
        z_enc = mmgpt.vision_gen_enc_model(z_input, t)
        z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
        z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
        z_emb = mmgpt.vision_gen_enc_aligner(z_emb)

        # the time embedding takes the place of the removed <bog> token
        llm_emb_dyn = torch.cat([t_emb.unsqueeze(1), z_emb], dim=1)

        # only the 577 new tokens are fed in; the prompt comes from the cache
        outputs = mmgpt.language_model.model(
            inputs_embeds=llm_emb_dyn,
            attention_mask=attention_mask,
            use_cache=True,
            past_key_values=prompt_cache,
        )
        
        # truncate the returned cache back to the prompt length so the next
        # step reuses only the prompt's keys/values, not this step's latents
        past_key_values = []
        for kv_cache in outputs.past_key_values:
            k, v = kv_cache[0], kv_cache[1]
            past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
        prompt_cache = DynamicCache.from_legacy_cache(past_key_values)
        
        hidden_states = outputs.last_hidden_state

        # transform hidden_states back to v
        hidden_states = mmgpt.vision_gen_dec_aligner(mmgpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
        hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
        v = mmgpt.vision_gen_dec_model(hidden_states, hs, t_emb)
        # classifier-free guidance, then one Euler step of the ODE
        v_cond, v_uncond = torch.chunk(v, 2)
        v = cfg_weight * v_cond - (cfg_weight - 1.) * v_uncond
        z = z + dt * v

    # `vae` is assumed to be loaded globally, as in the official demo
    decoded_image = vae.decode(z / vae.config.scaling_factor).sample

    os.makedirs('generated_samples', exist_ok=True)
    save_path = os.path.join('generated_samples', "img2.jpg")
    torchvision.utils.save_image(decoded_image.clip_(-1.0, 1.0) * 0.5 + 0.5, save_path)
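
For completeness, this is roughly how I drive it. The loading code follows the official text-to-image demo (model path and SDXL VAE checkpoint as in the repo; adjust if yours differ), and the prompt is just an arbitrary example:

from diffusers.models import AutoencoderKL

model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
vl_gpt = MultiModalityCausalLM.from_pretrained(model_path).to(torch.bfloat16).cuda().eval()

# `generate` above reads `vae` as a global, matching the official demo
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(torch.bfloat16).cuda().eval()

generate(vl_gpt, vl_chat_processor, prompt="A corgi wearing sunglasses on the beach.")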
