
Conversation

@zucchini-nlp (Member) commented Jan 16, 2025

What does this PR do?

As per title, this PR adds flags in VLMs where needed, removes test skips, and makes sure VLMs are compile-compatible. For BLIP models it also adds the new cache format in OPT, which is one of the backbones. Now all official BLIP models can support static cache and thus compile.

NOTE:

  • Tests with -k compile_forward and -k static_ were run for all models and are passing
  • Regarding executorch, which I also checked: the model can be exported and run a forward pass, but generation won't work and probably needs something similar to what we do when exporting VLMs to ONNX. I still need to dig more into that in later PRs
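
For the BLIP family mentioned in the description, static cache (and therefore compile) now also works directly through generate. A minimal sketch, assuming the Salesforce/blip2-opt-2.7b checkpoint (any official BLIP-2 checkpoint should behave the same):

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

blip_id = "Salesforce/blip2-opt-2.7b"
blip_processor = AutoProcessor.from_pretrained(blip_id)
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    blip_id,
    torch_dtype=torch.float16,
    device_map="cuda:0",
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
blip_inputs = blip_processor(
    images=image, text="Question: what is in the image? Answer:", return_tensors="pt"
).to("cuda:0", torch.float16)

# Static cache compiles the forward for the decoding phase, as in the LLaVA example below
out = blip_model.generate(**blip_inputs, max_new_tokens=20, cache_implementation="static")
print(blip_processor.decode(out[0], skip_special_tokens=True))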

How to run compile and export for VLMs:

import requests
from PIL import Image

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.generation import GenerationConfig
from transformers.cache_utils import StaticCache
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)

model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype="float16", 
    device_map="cuda:0",
)
processor = AutoProcessor.from_pretrained(model_id)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What are these?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

# Run with static cache which compiles the forward in decoding phase for you
output = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(processor.decode(output[0][2:], skip_special_tokens=True))



# Try to export with `torch.export`. NOTE: TorchExportableModuleWithStaticCache is not ready for VLMs,
# and as mentioned above, VLMs might need to export 3 different modules as in ONNX: one for the text embedding,
# one for the vision backbone, and one for the LM backbone with simple token-by-token decoding
max_generation_length = 1000
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda:0",
    torch_dtype="float16",
    attn_implementation="sdpa",
    generation_config=GenerationConfig(
        use_cache=True,
        cache_implementation="static",
        max_length=max_generation_length,
        cache_config={
            "batch_size": 1,
            "max_cache_len": max_generation_length,
        },
    ),
)

# Adapted from `TorchExportableModuleWithStaticCache` with minor changes
class TorchExportableModuleWithStaticCacheForVLM(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.static_cache = StaticCache(
            config=self.model.config.get_text_config(),
            batch_size=self.model.generation_config.cache_config.batch_size,
            max_cache_len=self.model.generation_config.cache_config.max_cache_len,
            dtype=self.model.dtype,
            device=model.device,
        )
        self.is_causal = any(("CausalLM" in arch or "ConditionalGeneration" in arch) for arch in self.model.config.architectures)
        if self.is_causal:
            causal_mask = torch.triu(
                torch.full(
                    (
                        self.model.generation_config.cache_config.batch_size,
                        1,
                        self.static_cache.max_cache_len,
                        self.static_cache.max_cache_len
                    ),
                    fill_value=torch.finfo(self.model.dtype).min,
                    dtype=self.model.dtype,
                    device=model.device,
                ),
                diagonal=1,  # mask only strictly-future positions so each token can attend to itself
            )
            self.register_buffer("mask", causal_mask, persistent=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        cache_position: torch.Tensor,
        pixel_values: torch.Tensor,
    ):
        _, seqlen = input_ids.shape
        attn_mask = self.mask[:, :, cache_position, :] if self.is_causal else None
        outs = self.model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            position_ids=cache_position.unsqueeze(0),
            pixel_values=pixel_values,
            cache_position=cache_position,
            past_key_values=self.static_cache,
            use_cache=True,
        )
        return outs.logits

cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.long, device=model.device)
export_inputs = {"input_ids": inputs.input_ids, "cache_position": cache_position, "pixel_values": inputs.pixel_values}

with torch.no_grad():
    exported_program = torch.export.export(
        TorchExportableModuleWithStaticCacheForVLM(model),
        args=(),
        kwargs=export_inputs,
        strict=True,
    )

torch.export.save(exported_program, "exported_llava.pt2")
exported_program = torch.export.load("exported_llava.pt2")
out = exported_program.module().forward(
    input_ids=inputs.input_ids,
    pixel_values=inputs.pixel_values,
    cache_position=cache_position,
)

Benchmark on "llava-hf/llava-onevision-qwen2-7b-ov-hf" using the same script we use for LLMs, plus a dummy image in the inputs:

[benchmark results image]
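
The benchmark script itself is not included here; below is a minimal sketch of the kind of timing comparison it performs, reusing model_id, processor and inputs from the example above (the warm-up loop, token counts and printed format are assumptions):

import time
import torch
from transformers import LlavaForConditionalGeneration

# Reload the plain model, since the export example above overrides its generation config
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype="float16", device_map="cuda:0"
)

def timed_generate(**kwargs):
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = model.generate(**inputs, max_new_tokens=50, do_sample=False, **kwargs)
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

_, dynamic_time = timed_generate()  # dynamic cache, eager forward
for _ in range(3):  # repeat so compilation cost is amortized
    _, static_time = timed_generate(cache_implementation="static")  # static cache, compiled decode
print(f"dynamic: {dynamic_time:.2f}s | static (compiled decode): {static_time:.2f}s")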

Fixes #29891

Comment on lines 162 to 165
if past_key_value is not None:
    if not isinstance(past_key_value, EncoderDecoderCache):
        curr_past_key_value = past_key_value
    else:
Member Author

I don't know why, but the OPT model works as decoder-only while its attention is written as cross-attention (not used anywhere in the codebase). So we need to somehow keep BC while using the new DynamicCache.

As a workaround I simply added a check on the cache instance. Another possibility is to accept and return only the correct cache (self- or cross-attention), but that means all encoder-decoder models would need a change, thus breaking BC.
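
For illustration, a sketch of how the cache-instance check described above could look (the is_cross_attention flag and the EncoderDecoderCache sub-cache attributes follow existing transformers naming, but this is not a verbatim copy of the final diff):

if past_key_value is not None:
    if not isinstance(past_key_value, EncoderDecoderCache):
        # Plain DynamicCache/StaticCache passed in: use it directly
        curr_past_key_value = past_key_value
    else:
        # Legacy encoder-decoder usage: pick the matching sub-cache
        if is_cross_attention:
            curr_past_key_value = past_key_value.cross_attention_cache
        else:
            curr_past_key_value = past_key_value.self_attention_cache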

Contributor

Very much copy-paste from somewhere else, see this comment

IMO, we can make our maintenance easier and assume no encoder-decoder stuff :) But don't spend more time here, eventually this will be rewritten with modular

@zucchini-nlp zucchini-nlp changed the title [WIP] VLM: compile compatibility VLM: compile compatibility Jan 16, 2025
@zucchini-nlp zucchini-nlp changed the title VLM: compile compatibility [WIP] VLM: compile compatibility Jan 16, 2025
@zucchini-nlp zucchini-nlp changed the title [WIP] VLM: compile compatibility VLM: compile compatibility Jan 17, 2025
@zucchini-nlp (Member Author)

Ready for review. The failing test is flaky; otherwise everything is passing on my end, including slow tests for compile/StaticCache.

@zucchini-nlp zucchini-nlp requested review from gante and removed request for Rocketknight1 and molbap January 30, 2025 14:22
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Contributor) commented Jan 30, 2025

@zucchini-nlp

In the PR header we can read

all VLMs have dynamic control in prepare_inputs_for_generation and thus skip test_compile_forward which compiles the model for pre-fill phase. But the test for decoding stage compile is green therefore I'm leaving the flag as True

test_compile_forward was the old name for test_generate_compile_model_forward if I'm not mistaken, back when it also did end-to-end compilation tests. We no longer have the end-to-end compilation tests, so this part of the PR header is no longer accurate, correct?

@gante (Contributor) commented Jan 30, 2025

The failing test seems related to this PR :D

@ArthurZucker (Collaborator) left a comment

Very nice but 🔴 there are a few breaking changes so let's be careful!
And do you have some benches / perf improvements to share (making sure reduce-overhead is working, etc.)?

Comment on lines 190 to 191
query = query.reshape(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
Collaborator

yeah, though calling .contiguous() works as well
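
For reference, a rough equivalent of the .contiguous() variant mentioned here (same result as the .reshape calls in the diff above):

query = query.contiguous().view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.contiguous().view(batch_size * num_attention_heads, key_length, attn_head_size)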

@zucchini-nlp zucchini-nlp changed the title VLM: compile compatibility 🔴 VLM: compile compatibility Feb 10, 2025
@zucchini-nlp (Member Author)

[benchmark results image]

Soooo, here is the correct eval with llava-ov-7b. One thing to note is that VLMs will not benefit from torch.compile when the context length is very long, as in the case of videos or high-resolution images. That's the reason vanilla LLaVA got a speedup on the first try, while llava-ov needed a few runs before I noticed how many input tokens we had.

I will run make fixup and merge this, because VLMs with fewer tokens per image do get speedups.

@zucchini-nlp zucchini-nlp merged commit 0c78ef6 into huggingface:main Feb 14, 2025
25 checks passed
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Feb 18, 2025
Since the latest transformers release of v4.49.0, X-LoRA tests are
broken. The PR that caused it was:

huggingface/transformers#35724

For the time being, let's skip the X-LoRA tests if this transformers
version is detected and also advise users against using X-LoRA with this
transformers version.
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Feb 18, 2025
X-LoRA tests started failing after this transformers PR:

huggingface/transformers#35724

The solution appears to be to disable caching completely when calling
generate on the X-LoRA model. This also makes some previously xfail-ing
tests pass.

I tested this locally with transformers checked out before and after the
mentioned PR and the tests pass in both circumstances. I also tested
changing the base model from "facebook/opt-125m" to
"trl-internal-testing/tiny-random-LlamaForCausalLM" and the tests passed
with both.
BenjaminBossan added a commit to huggingface/peft that referenced this pull request Feb 18, 2025
X-LoRA tests started failing after this transformers PR:

huggingface/transformers#35724

The solution appears to be to disable caching completely when calling
generate on the X-LoRA model. This also makes some previously xfail-ing
tests pass.

I tested this locally with transformers checked out before and after the
mentioned PR and the tests pass in both circumstances. I also tested
changing the base model from "facebook/opt-125m" to
"trl-internal-testing/tiny-random-LlamaForCausalLM" and the tests passed
with both.

Also, mark X-LoRA save_load_function test as flaky.
It was marked as xfail beforehand, but it is in fact just flaky.
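
A minimal sketch of the workaround the commit describes (the xlora_model and inputs names are hypothetical):

# Disable the KV cache entirely when calling generate on an X-LoRA model
output = xlora_model.generate(**inputs, max_new_tokens=20, use_cache=False)
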
zucchini-nlp added a commit to zucchini-nlp/transformers that referenced this pull request Feb 21, 2025
* llavas

* add more models

* fix `compile_forward` test for all models

* fix copies

* make style

* also doesn't support cache class

* fix some tests

* not copied from

* ci green?

* fix tests

* fix copies

* fix tests

* check with `numel` and remove `item`

* fix copies

* fix copies

* Update src/transformers/models/cohere2/modeling_cohere2.py

Co-authored-by: Arthur <[email protected]>

* opt remove cross attn

* gemma2

* fixup

* fixup

* fix newly added test

* maybe fixed?

* green please?

---------

Co-authored-by: Arthur <[email protected]>
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
X-LoRA tests started failing after this transformers PR:

huggingface/transformers#35724

The solution appears to be to disable caching completely when calling
generate on the X-LoRA model. This also makes some previously xfail-ing
tests pass.

I tested this locally with transformers checked out before and after the
mentioned PR and the tests pass in both circumstances. I also tested
changing the base model from "facebook/opt-125m" to
"trl-internal-testing/tiny-random-LlamaForCausalLM" and the tests passed
with both.

Also, mark X-LoRA save_load_function test as flaky.
It was marked as xfail beforehand, but it is in fact just flaky.
Successfully merging this pull request may close these issues:

LLaVA torch.compile implementation