Merged
Showing changes from 13 of the 29 commits below.

Commits
5d6c5e1
llavas
zucchini-nlp Jan 15, 2025
b500dcf
add more models
zucchini-nlp Jan 16, 2025
b56b40c
fix `compile_forward` test for all models
zucchini-nlp Jan 16, 2025
040a83c
fix copies
zucchini-nlp Jan 16, 2025
4c8e6ab
make style
zucchini-nlp Jan 16, 2025
b72d845
also doesn't support cache class
zucchini-nlp Jan 16, 2025
70a0510
fix some tests
zucchini-nlp Jan 16, 2025
17b0c8f
not copied from
zucchini-nlp Jan 16, 2025
8ddee32
ci green?
zucchini-nlp Jan 17, 2025
370c9d2
fix tests
zucchini-nlp Jan 17, 2025
91d268d
Merge remote-tracking branch 'upstream/main' into compile-llava-enable
zucchini-nlp Jan 30, 2025
fcc6454
fix copies
zucchini-nlp Jan 30, 2025
41b50d8
fix tests
zucchini-nlp Jan 30, 2025
2b602ba
check with `numel` and remove `item`
zucchini-nlp Feb 10, 2025
4a3ff89
merge main
zucchini-nlp Feb 10, 2025
4e9cd52
fix copies
zucchini-nlp Feb 10, 2025
1776f0f
fix copies
zucchini-nlp Feb 10, 2025
e906616
Merge remote-tracking branch 'upstream/main' into compile-llava-enable
zucchini-nlp Feb 10, 2025
2232f62
Update src/transformers/models/cohere2/modeling_cohere2.py
zucchini-nlp Feb 13, 2025
f84242e
merge main
zucchini-nlp Feb 13, 2025
e089e34
opt remove cross attn
zucchini-nlp Feb 13, 2025
210bb5f
gemma2
zucchini-nlp Feb 13, 2025
7271490
fixup
zucchini-nlp Feb 13, 2025
496fc05
Merge branch 'main' into compile-llava-enable
zucchini-nlp Feb 13, 2025
45ad329
fixup
zucchini-nlp Feb 13, 2025
7a79bac
Merge branch 'main' into compile-llava-enable
zucchini-nlp Feb 14, 2025
2f219eb
fix newly added test
zucchini-nlp Feb 14, 2025
0cf1cfe
maybe fixed?
zucchini-nlp Feb 14, 2025
eccc5fa
green please?
zucchini-nlp Feb 14, 2025
3 changes: 3 additions & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -2013,6 +2013,9 @@ def forward(
class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
config_class = Blip2Config
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)

def __init__(self, config: Blip2Config):
super().__init__(config)
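The three flags above advertise that `Blip2ForConditionalGeneration` can run generation with a pre-allocated (static) KV cache, which is the prerequisite for compiling the decode step. A hypothetical usage sketch; the checkpoint name, dtype, and compile/generation settings are illustrative assumptions, not part of this diff:

```python
# Hypothetical sketch of what _supports_static_cache unlocks: static KV cache + torch.compile.
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to("cuda")

# Compiling the forward pass only pays off when the KV cache has a fixed shape,
# which is exactly what the static cache provides.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

image = Image.new("RGB", (224, 224))  # placeholder image
inputs = processor(
    images=image, text="Question: what is shown? Answer:", return_tensors="pt"
).to("cuda", torch.float16)

# "static" pre-allocates the KV cache so every decode step sees the same tensor shapes.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(processor.batch_decode(out, skip_special_tokens=True))
```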
14 changes: 8 additions & 6 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -41,6 +41,7 @@
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -1283,12 +1284,13 @@ def forward(

if pixel_values is not None:
image_tokens = self.get_image_tokens(pixel_values)
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if n_image_tokens_in_text != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
if not is_torchdynamo_compiling():
n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if n_image_tokens_in_text != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
)
special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
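The new `is_torchdynamo_compiling()` guard exists because `.sum().item()` followed by a Python `if` is data-dependent control flow: it forces a graph break under `torch.compile` (and fails outright with `fullgraph=True`), so the validation only runs eagerly. A minimal standalone sketch of the same pattern; the function and argument names are illustrative:

```python
# Compile-safe validation: run the data-dependent check only in eager mode.
from transformers.utils import is_torchdynamo_compiling

def check_image_token_count(input_ids, image_tokens, image_token_id):
    if not is_torchdynamo_compiling():
        # .item() pulls a Python int out of the graph, which dynamo cannot keep symbolic.
        n_tokens_in_text = (input_ids == image_token_id).sum().item()
        n_features = image_tokens.shape[0] * image_tokens.shape[1]
        if n_tokens_in_text != n_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens {n_tokens_in_text}, features {n_features}"
            )
```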
4 changes: 2 additions & 2 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -25,7 +25,7 @@
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
@@ -700,7 +700,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, HybridCache) or isinstance(past_key_values, StaticCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
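Both `HybridCache` and `StaticCache` pre-allocate their key/value buffers, so `get_max_cache_shape()` yields a fixed `target_length` and the causal mask keeps a constant shape across decode steps under `torch.compile`. A self-contained sketch of just this dispatch, assuming the `transformers` cache classes are importable; the wrapper function itself is illustrative:

```python
# Sketch of the target-length dispatch above; mirrors the hunk, not the full method.
from transformers.cache_utils import HybridCache, StaticCache

def resolve_target_length(past_key_values, attention_mask, input_tensor):
    if isinstance(past_key_values, (HybridCache, StaticCache)):
        # Fixed-size caches report their preallocated length, so the mask shape
        # is identical at every decoding step.
        return past_key_values.get_max_cache_shape()
    # Otherwise the mask length follows the (dynamic) inputs.
    return attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
```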
4 changes: 2 additions & 2 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -25,7 +25,7 @@
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
@@ -712,7 +712,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, HybridCache) or isinstance(past_key_values, StaticCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
4 changes: 2 additions & 2 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -20,7 +20,7 @@
import torch.utils.checkpoint

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
@@ -545,7 +545,7 @@ def _update_causal_mask(

dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
if isinstance(past_key_values, HybridCache) or isinstance(past_key_values, StaticCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
@@ -187,8 +187,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)

query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
query = query.reshape(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
Contributor

Out of curiosity, what drove the change?

zucchini-nlp (Member, Author), Jan 30, 2025

They were not contiguous when compiling; it probably wasn't noticed because the model sees very little usage.

Collaborator

Yeah, though calling .contiguous() works as well.

zucchini-nlp (Member, Author), Feb 10, 2025

OK, I don't think it makes a huge difference which one we use; we can call .contiguous() just after RoPE.


# [batch_size * num_heads, q_length, kv_length]
attn_scores = torch.zeros(
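A minimal sketch of the behaviour discussed in the review thread above: after a `transpose`, the query/key tensors are no longer contiguous, so `.view()` cannot merge the batch and head dimensions, while `.reshape()` (or an explicit `.contiguous()`) copies when needed. Shapes are illustrative:

```python
import torch

batch_size, num_heads, q_len, head_size = 2, 4, 8, 16
# (batch, seq, heads, head_size) -> (batch, heads, seq, head_size): non-contiguous after transpose
query = torch.randn(batch_size, q_len, num_heads, head_size).transpose(1, 2)

try:
    query.view(batch_size * num_heads, q_len, head_size)
except RuntimeError as e:
    print("view failed:", e)  # view needs memory layout compatible with the new shape

# Either of these works; reshape copies only when it has to.
q1 = query.reshape(batch_size * num_heads, q_len, head_size)
q2 = query.contiguous().view(batch_size * num_heads, q_len, head_size)
assert torch.equal(q1, q2)
```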
40 changes: 14 additions & 26 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -1109,6 +1109,7 @@ def forward(
router_logits=all_router_logits,
)

# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
@@ -1117,13 +1118,8 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None

@@ -1144,7 +1140,6 @@ def _update_causal_mask(
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
@@ -1155,25 +1150,17 @@
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
@@ -1183,6 +1170,7 @@
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask
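For reference, the mask the shared helper builds can be reconstructed from the inline code this hunk removes; a self-contained sketch, with the standalone function wrapper being illustrative:

```python
# Reconstruction of the 4D causal mask built above, based on the removed inline code.
import torch

def build_4d_causal_mask(attention_mask, sequence_length, target_length, dtype, device, cache_position, batch_size):
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Mask out positions past the current cache position (this also covers the decode step, sequence_length == 1).
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for the in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask == 0, min_dtype
        )
    return causal_mask
```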
3 changes: 3 additions & 0 deletions src/transformers/models/instructblip/modeling_instructblip.py
@@ -1287,6 +1287,9 @@ def forward(
class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
config_class = InstructBlipConfig
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)

def __init__(self, config: InstructBlipConfig):
super().__init__(config)
@@ -1281,6 +1281,9 @@ def forward(
class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin):
config_class = InstructBlipVideoConfig
main_input_name = "pixel_values"
_supports_cache_class = True
_supports_static_cache = True
_supports_quantized_cache = False # not all LM backbones support it (e.g. T5)

def __init__(self, config: InstructBlipVideoConfig):
super().__init__(config)
4 changes: 0 additions & 4 deletions src/transformers/models/llava/configuration_llava.py
@@ -37,8 +37,6 @@ class LlavaConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@@ -83,7 +81,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
@@ -92,7 +89,6 @@
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length
99 changes: 10 additions & 89 deletions src/transformers/models/llava/modeling_llava.py
@@ -28,6 +28,7 @@
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
)
@@ -136,6 +137,8 @@ class LlavaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_static_cache = True

def _init_weights(self, module):
# important: this ported version of Llava isn't meant for training from scratch - only
@@ -321,89 +324,6 @@ def get_image_features(
image_features = self.multi_modal_projector(selected_image_feature)
return image_features

def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
if labels is not None:
final_labels = torch.full(
(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)

# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
if labels is not None:
final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
if left_padding:
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
else:
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
image_to_overwrite &= padding_mask

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)

final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]

final_embedding[batch_indices, indices_to_mask] = 0

if labels is None:
final_labels = None

return final_embedding, final_attention_mask, final_labels, position_ids

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -498,12 +418,13 @@ def forward(
image_sizes=image_sizes,
)

n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
if not is_torchdynamo_compiling():
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
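With `_merge_input_ids_with_image_features` removed, image embeddings are dropped into pre-expanded `<image>` placeholder positions via `masked_scatter`, as the new code path above shows. A minimal self-contained sketch of that merge; tensor sizes and values are illustrative:

```python
import torch

batch, seq_len, hidden = 1, 6, 4
image_token_index = 32000

input_ids = torch.tensor([[1, 32000, 32000, 32000, 5, 6]])  # three image placeholder slots
inputs_embeds = torch.randn(batch, seq_len, hidden)
image_features = torch.randn(1, 3, hidden)  # (n_images, n_patches, hidden)

# Boolean mask covering every hidden dimension of every image-token position.
special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)

# Scatter the flattened image features into the text embeddings in place of the placeholders.
inputs_embeds = inputs_embeds.masked_scatter(
    special_image_mask, image_features.to(inputs_embeds.dtype)
)

assert inputs_embeds.shape == (batch, seq_len, hidden)
```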
@@ -36,8 +36,6 @@ class LlavaNextConfig(PretrainedConfig):
The config object or dictionary of the vision backbone.
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
The config object or dictionary of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
@@ -88,7 +86,6 @@ def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
@@ -99,7 +96,6 @@
multimodal_projector_bias=True,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.image_seq_length = image_seq_length