
Commit 19d58d3

ArthurZucker, zucchini-nlp, pavel, qubvel, and molbap authored
Add MLLama (huggingface#33703)
* current changes
* nit
* Add cross_attenttion_mask to processor
* multi-image fixed
* Add cross_attenttion_mask to processor
* cross attn works in all cases
* WIP refactoring function for image processor
* WIP refactoring image processor functions
* Refactor preprocess to use global loops instead of list nested list comps
* Docstrings
* Add channels unification
* fix dtype issues
* Update docsrings and format
* Consistent max_image_tiles
* current script
* updates
* Add convert to rgb
* Add image processor tests
* updates!
* update
* god damn it I am dumb sometimes
* Precompute aspect ratios
* now this works, full match
* fix 😉
* nits
* style
* fix model and conversion
* nit
* nit
* kinda works
* hack for sdpa non-contiguous bias
* nits here and there
* latest c hanges
* merge?
* run forward
* Add aspect_ratio_mask
* vision attention mask
* update script and config variable names
* nit
* nits
* be able to load
* style
* nits
* there
* nits
* make forward run
* small update
* enable generation multi-turn
* nit
* nit
* Clean up a bit for errors and typos
* A bit more constant fixes
* 90B keys and shapes match
* Fix for 11B model
* Fixup, remove debug part
* Docs
* Make max_aspect_ratio_id to be minimal
* Update image processing code to match new implementation
* Adjust conversion for final checkpoint state
* Change dim in repeat_interleave (accordig to meta code)
* tmp fix for num_tiles
* Fix for conversion (gate<->up, q/k_proj rope permute)
* nits
* codestyle
* Vision encoder fixes
* pass cross attn mask further
* Refactor aspect ratio mask
* Disable text-only generation
* Fix cross attention layers order, remove q/k norm rotation for cross atention layers
* Refactor gated position embeddings
* fix bugs but needs test with new weights
* rope scaling should be llama3
* Fix rope scaling name
* Remove debug for linear layer
* fix copies
* Make mask prepare private func
* Remove linear patch embed
* Make precomputed embeddings as nn.Embedding module
* MllamaPrecomputedAspectRatioEmbedding with config init
* Remove unused self.output_dim
* nit, intermediate layers
* Rename ln and pos_embed
* vision_chunk_size -> image_size
* return_intermediate -> intermediate_layers_indices
* vision_input_dim -> hidden_size
* Fix copied from statements
* fix most tests
* Fix more copied from
* layer_id->layer_idx
* Comment
* Fix tests for processor
* Copied from for _prepare_4d_causal_attention_mask_with_cache_position
* Style fix
* Add MllamaForCausalLM
* WIP fixing tests
* Remove duplicated layers
* Remove dummy file
* Fix style
* Fix consistency
* Fix some TODOs
* fix language_model instantiation, add docstring
* Move docstring, remove todos for precomputed embeds (we cannot init them properly)
* Add initial docstrings
* Fix
* fix some tests
* lets skip these
* nits, remove print, style
* Add one more copied from
* Improve test message
* Make validate func private
* Fix dummy objects
* Refactor `data_format` a bit + add comment
* typos/nits
  Co-authored-by: Pablo Montalvo <[email protected]>
* fix dummy objects and imports
* Add chat template config json
* remove num_kv_heads from vision attention
* fix
* move some commits and add more tests
* fix test
* Remove `update_key_name` from modeling utils
* remove num-kv-heads again
* some prelimiary docs
* Update chat template + tests
* nit, conversion script max_num_tiles from params
* Fix warning for text-only generation
* Update conversion script for instruct models
* Update chat template in converstion + test
* add tests for CausalLM model
* model_max_length, avoid null chat_template
* Refactor conversion script
* Fix forward
* Fix integration tests
* Refactor vision config + docs
* Fix default
* Refactor text config
* Doc fixes
* Remove unused args, fix docs example
* Squashed commit of the following:

  commit b51ce5a2efffbecdefbf6fc92ee87372ec9d8830
  Author: qubvel <[email protected]>
  Date: Wed Sep 18 13:39:15 2024 +0000

      Move model + add output hidden states and output attentions

* Fix num_channels
* Add mllama text and mllama vision models
* Fixing repo consistency
* Style fix
* Fixing repo consistency
* Fixing unused config params
* Fix failed tests after refactoring
* hidden_activation -> hidden_act for text mlp
* Remove from_pretrained from sub-configs
* Apply suggestions from code review
  Co-authored-by: Arthur <[email protected]>
* Update src/transformers/models/mllama/convert_mllama_weights_to_hf.py
  Co-authored-by: Arthur <[email protected]>
* Reuse lambda in conversion script
* Remove run.py
* Update docs/source/en/model_doc/mllama.md
  Co-authored-by: Arthur <[email protected]>
* Update src/transformers/models/mllama/processing_mllama.py
  Co-authored-by: Arthur <[email protected]>
* Remove unused LlamaTokenizerFast
* Fix logging
* Refactor gating
* Remove cycle for collecting intermediate states
* Refactor text-only check, add integration test for text-only
* Revert from pretrained to configs
* Fix example
* Add auto `bos_token` adding in processor
* Fix tips
* Update src/transformers/models/auto/tokenization_auto.py
  Co-authored-by: Arthur <[email protected]>
* Enable supports_gradient_checkpointing model flag
* add eager/sdpa options
* don't skip attn tests and bring back GC skips (did i really remove those?)
* Fix signature, but get error with None gradient
* Fix output attention tests
* Disable GC back
* Change no split modules
* Fix dropout
* Style
* Add Mllama to sdpa list
* Add post init for vision model
* Refine config for MllamaForCausalLMModelTest and skipped tests for CausalLM model
* if skipped, say it, don't pass
* Clean vision tester config
* Doc for args
* Update tests/models/mllama/test_modeling_mllama.py
  Co-authored-by: Arthur <[email protected]>
* Add cross_attention_mask to test
* typehint
* Remove todo
* Enable gradient checkpointing
* Docstring
* Style
* Fixing and skipping some tests for new cache
* Mark flaky test
* Skip `test_sdpa_can_compile_dynamic` test
* Fixing some offload tests
* Add direct GenerationMixin inheritance
* Remove unused code
* Add initializer_range to vision config
* update the test to make sure we show if split
* fix gc?
* Fix repo consistency
* Undo modeling utils debug changes
* Fix link
* mllama -> Mllama
* [mllama] -> [Mllama]
* Enable compile test for CausalLM model (text-only)
* Fix TextModel prefix
* Update doc
* Docs for forward, type hints, and vision model prefix
* make sure to reset
* fix init
* small script refactor and styling
* nit
* updates!
* some nits
* Interpolate embeddings for 560 size and update integration tests
* nit
* does not suppor static cache!
* update
* fix
* nit2
* this?
* Fix conversion
* Style
* 4x memory improvement with image cache AFAIK
* Token decorator for tests
* Skip failing tests
* update processor errors
* fix split issues
* style
* weird
* style
* fix failing tests
* update
* nit fixing the whisper tests
* fix path
* update

---------

Co-authored-by: raushan <[email protected]>
Co-authored-by: pavel <[email protected]>
Co-authored-by: qubvel <[email protected]>
Co-authored-by: Pablo Montalvo <[email protected]>
Co-authored-by: ydshieh <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 94f18cf commit 19d58d3

31 files changed (+6183, -98 lines)

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -860,6 +860,8 @@
         title: MatCha
       - local: model_doc/mgp-str
         title: MGP-STR
+      - local: model_doc/mllama
+        title: mllama
       - local: model_doc/nougat
         title: Nougat
       - local: model_doc/omdet-turbo
```

docs/source/en/index.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -214,6 +214,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [Mimi](model_doc/mimi) ||||
 | [Mistral](model_doc/mistral) ||||
 | [Mixtral](model_doc/mixtral) ||||
+| [Mllama](model_doc/mllama) ||||
 | [mLUKE](model_doc/mluke) ||||
 | [MMS](model_doc/mms) ||||
 | [MobileBERT](model_doc/mobilebert) ||||
```

docs/source/en/model_doc/mllama.md

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@

<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Mllama

## Overview

The Llama 3.2-Vision collection of multimodal large language models (LLMs) comprises pretrained and instruction-tuned image-reasoning generative models in 11B and 90B sizes (text + images in / text out). The Llama 3.2-Vision instruction-tuned models are optimized for visual recognition, image reasoning, captioning, and answering general questions about an image.

**Model Architecture:** Llama 3.2-Vision is built on top of the Llama 3.1 text-only model, which is an auto-regressive language model that uses an optimized transformer architecture. The tuned versions use supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align with human preferences for helpfulness and safety. To support image recognition tasks, the Llama 3.2-Vision model uses a separately trained vision adapter that integrates with the pre-trained Llama 3.1 language model. The adapter consists of a series of cross-attention layers that feed image encoder representations into the core LLM.
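To make the adapter idea concrete, below is a minimal, illustrative sketch of a gated cross-attention block: text hidden states attend to image-encoder states, and a zero-initialized gate blends the result into the residual stream so the language model is unchanged at initialization. The class name, dimensions, and module layout are invented for illustration and do not mirror the actual Mllama cross-attention layers.

```python
import torch
import torch.nn as nn


class ToyGatedCrossAttentionBlock(nn.Module):
    """Illustrative only: not the Mllama implementation."""

    def __init__(self, hidden_size: int = 512, num_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.gate = nn.Parameter(torch.zeros(1))  # tanh(0) == 0, so the block is a no-op at init
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, text_states: torch.Tensor, image_states: torch.Tensor) -> torch.Tensor:
        # Text tokens (queries) attend to image-encoder tokens (keys/values).
        attn_out, _ = self.cross_attn(self.norm(text_states), image_states, image_states)
        return text_states + torch.tanh(self.gate) * attn_out


block = ToyGatedCrossAttentionBlock()
text = torch.randn(1, 16, 512)     # (batch, text_seq_len, hidden_size)
image = torch.randn(1, 1601, 512)  # (batch, image_tokens, hidden_size); sizes are arbitrary
print(block(text, image).shape)    # torch.Size([1, 16, 512])
```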
## Usage Tips

- For image+text and text inputs use `MllamaForConditionalGeneration`.
- For text-only inputs use `MllamaForCausalLM` for generation to avoid loading the vision tower (see the text-only sketch after the usage examples below).
- Each sample can contain multiple images, and the number of images can vary between samples. The processor will pad the inputs to the maximum number of images across samples and to a maximum number of tiles within each image.
- The text passed to the processor should have the `"<|image|>"` tokens where the images should be inserted.
- The processor has its own `apply_chat_template` method to convert chat messages to text that can then be passed as text to the processor.

## Usage Example

#### Instruct model
```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

messages = [
    [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "What does the image show?"}
            ]
        }
    ],
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)

url = "https://llava-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=25)
print(processor.decode(output[0]))
```

#### Base model
```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "<|image|>If I had to write a haiku for this one"
url = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
print(processor.decode(output[0], skip_special_tokens=True))
```

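As noted in the usage tips, text-only generation does not need the vision tower. The snippet below is a sketch rather than a verified example: it assumes the checkpoint's language-model weights load directly into `MllamaForCausalLM`, and the prompt is arbitrary.

```python
import torch
from transformers import AutoTokenizer, MllamaForCausalLM

model_id = "meta-llama/Llama-3.2-11B-Vision"  # assumption: text weights load without the vision tower
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MllamaForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

inputs = tokenizer("If I had to write a haiku about autumn, it would be:", return_tensors="pt").to(model.device)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```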
## MllamaConfig

[[autodoc]] MllamaConfig

## MllamaProcessor

[[autodoc]] MllamaProcessor

## MllamaImageProcessor

[[autodoc]] MllamaImageProcessor

## MllamaForConditionalGeneration

[[autodoc]] MllamaForConditionalGeneration
    - forward

## MllamaForCausalLM

[[autodoc]] MllamaForCausalLM
    - forward

## MllamaTextModel

[[autodoc]] MllamaTextModel
    - forward

## MllamaVisionModel

[[autodoc]] MllamaVisionModel
    - forward

docs/source/en/perf_infer_gpu_one.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -236,6 +236,7 @@ For now, Transformers supports SDPA inference and training for the following arc
 * [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model)
 * [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
 * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
+* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
 * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
 * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
 * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
```
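Since Mllama now appears in the SDPA list above, the implementation can be requested explicitly through the standard `attn_implementation` argument of `from_pretrained`. A brief sketch (checkpoint name taken from the model docs; output not verified):

```python
import torch
from transformers import MllamaForConditionalGeneration

# Ask for PyTorch scaled-dot-product attention explicitly.
model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)
print(model.config._attn_implementation)  # "sdpa"
```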

src/transformers/__init__.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -577,6 +577,10 @@
     "models.mimi": ["MimiConfig"],
     "models.mistral": ["MistralConfig"],
     "models.mixtral": ["MixtralConfig"],
+    "models.mllama": [
+        "MllamaConfig",
+        "MllamaProcessor",
+    ],
     "models.mluke": [],
     "models.mobilebert": [
         "MobileBertConfig",
@@ -1199,6 +1203,7 @@
     )
     _import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
     _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
+    _import_structure["models.mllama"].extend(["MllamaImageProcessor"])
     _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
     _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
     _import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"])
@@ -2704,6 +2709,16 @@
             "MixtralPreTrainedModel",
         ]
     )
+    _import_structure["models.mllama"].extend(
+        [
+            "MllamaForCausalLM",
+            "MllamaForConditionalGeneration",
+            "MllamaPreTrainedModel",
+            "MllamaProcessor",
+            "MllamaTextModel",
+            "MllamaVisionModel",
+        ]
+    )
     _import_structure["models.mobilebert"].extend(
         [
             "MobileBertForMaskedLM",
@@ -5377,6 +5392,10 @@
     )
     from .models.mistral import MistralConfig
     from .models.mixtral import MixtralConfig
+    from .models.mllama import (
+        MllamaConfig,
+        MllamaProcessor,
+    )
     from .models.mobilebert import (
         MobileBertConfig,
         MobileBertTokenizer,
@@ -6037,6 +6056,7 @@
         MaskFormerFeatureExtractor,
         MaskFormerImageProcessor,
     )
+    from .models.mllama import MllamaImageProcessor
     from .models.mobilenet_v1 import (
         MobileNetV1FeatureExtractor,
         MobileNetV1ImageProcessor,
@@ -7270,6 +7290,14 @@
         MixtralModel,
         MixtralPreTrainedModel,
     )
+    from .models.mllama import (
+        MllamaForCausalLM,
+        MllamaForConditionalGeneration,
+        MllamaPreTrainedModel,
+        MllamaProcessor,
+        MllamaTextModel,
+        MllamaVisionModel,
+    )
     from .models.mobilebert import (
         MobileBertForMaskedLM,
         MobileBertForMultipleChoice,
```
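With the import structure updated as above, the new classes resolve from the top-level `transformers` namespace. A quick, weight-free smoke test (assuming a build that includes this commit):

```python
from transformers import MllamaConfig, MllamaForConditionalGeneration, MllamaProcessor

# Instantiating the config downloads nothing and confirms the registration works.
config = MllamaConfig()
print(config.model_type)  # "mllama"
print(MllamaForConditionalGeneration.__name__, MllamaProcessor.__name__)
```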

src/transformers/cache_utils.py

Lines changed: 53 additions & 28 deletions
```diff
@@ -80,10 +80,12 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.key_cache[layer_idx] != []:
+                device = self.key_cache[layer_idx].device
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.value_cache[layer_idx] != []:
+                device = self.value_cache[layer_idx].device
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

     @property
     def seen_tokens(self):
@@ -358,10 +360,14 @@ class DynamicCache(Cache):
         ```
         """

-    def __init__(self) -> None:
+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
         super().__init__()
-        self.key_cache: List[torch.Tensor] = []
-        self.value_cache: List[torch.Tensor] = []
+        if num_hidden_layers is None:
+            self.key_cache: List[torch.Tensor] = []
+            self.value_cache: List[torch.Tensor] = []
+        else:
+            self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
+            self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
@@ -420,6 +426,11 @@ def update(
         if len(self.key_cache) <= layer_idx:
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
+        # content on layer cache can be a tensor and checking not tensor causes errors
+        # so we explicitly check for the empty list
+        elif self.key_cache[layer_idx] == []:
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
         else:
             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
@@ -429,7 +440,7 @@ def update(
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # TODO: deprecate this function in favor of `cache_position`
-        if len(self.key_cache) <= layer_idx:
+        if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
             return 0
         return self.key_cache[layer_idx].shape[-2]

@@ -446,10 +457,12 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         return legacy_cache

     @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
+    ) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
         backward compatibility."""
-        cache = cls()
+        cache = cls(num_hidden_layers)
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
                 key_states, value_states = past_key_values[layer_idx]
@@ -468,30 +481,34 @@ def crop(self, max_length: int):

         self._seen_tokens = max_length
         for idx in range(len(self.key_cache)):
-            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
-            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
+            if self.key_cache[idx] != []:
+                self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+                self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

-    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
+    def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]:
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         out = []
         for i in range(0, full_batch_size, split_size):
-            current_split = DynamicCache()
+            current_split = DynamicCache(num_hidden_layers)
             current_split._seen_tokens = self._seen_tokens
             current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
             current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
             out.append(current_split)
         return out

     @classmethod
-    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
+    def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
-        cache = cls()
+        cache = cls(num_hidden_layers)
         for idx in range(len(splits[0])):
-            layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
-            layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
-            cache.update(layer_keys, layer_values, idx)
+            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
+            value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
+            if key_cache != []:
+                layer_keys = torch.cat(key_cache, dim=0)
+                layer_values = torch.cat(value_cache, dim=0)
+                cache.update(layer_keys, layer_values, idx)
         return cache

     def batch_repeat_interleave(self, repeats: int):
@@ -1391,10 +1408,13 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:

     @classmethod
     def from_legacy_cache(
-        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
-        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
+        cache = cls(
+            self_attention_cache=DynamicCache(num_hidden_layers),
+            cross_attention_cache=DynamicCache(num_hidden_layers),
+        )
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
                 key_states, value_states = past_key_values[layer_idx][:2]
@@ -1407,7 +1427,10 @@ def from_legacy_cache(

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.self_attention_cache.key_cache) <= layer_idx:
+        # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
+        if self.self_attention_cache.key_cache == []:
+            return 0
+        if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
             return 0
         return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

@@ -1448,24 +1471,26 @@ def crop(self, maximum_length: int):
         self.check_dynamic_cache(self.crop.__name__)
         self.self_attention_cache.crop(maximum_length)

-    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
+    def batch_split(
+        self, full_batch_size: int, split_size: int, num_hidden_layers: int
+    ) -> "List[EncoderDecoderCache]":
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         self.check_dynamic_cache(self.batch_split.__name__)
-        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
-        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
+        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
+        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)

         out = []
         for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
             out.append(EncoderDecoderCache(self_attn, cross_attn))
         return out

     @classmethod
-    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
+    def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
-        self_attention_cache = DynamicCache()
-        cross_attention_cache = DynamicCache()
+        self_attention_cache = DynamicCache(num_hidden_layers)
+        cross_attention_cache = DynamicCache(num_hidden_layers)
         for idx in range(len(splits[0])):
             layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
             layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
```
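The `DynamicCache` changes above let a cache be pre-sized to a known number of layers, with per-layer slots kept as empty lists until a layer actually writes to them, presumably so that Mllama's cross-attention layers (which do not cache on every step) can coexist with the self-attention layers. A minimal sketch of the resulting behaviour, assuming a `transformers` build that includes this commit; the tensor shapes are arbitrary:

```python
import torch
from transformers.cache_utils import DynamicCache

# Pre-allocate one (initially empty) slot per layer.
cache = DynamicCache(num_hidden_layers=4)
print(cache.get_seq_length(0))  # 0 -> an empty slot reports a length of zero

# Write key/value states for layer 0 only, as a model forward pass would.
key = torch.zeros(1, 8, 5, 64)    # (batch, num_heads, seq_len, head_dim)
value = torch.zeros(1, 8, 5, 64)
cache.update(key, value, layer_idx=0)

print(cache.get_seq_length(0))  # 5 -> layer 0 now holds 5 cached positions
print(cache.get_seq_length(1))  # 0 -> layer 1's slot is still the empty list
```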
