Description
System Info
- transformers version: 4.50.3
- Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
- Python version: 3.10.16
- Huggingface_hub version: 0.30.1
- Safetensors version: 0.5.3
- Accelerate version: 1.6.0
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (GPU?): 2.5.1+cu118 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: yes
- GPU type: NVIDIA RTX A6000
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Code snippet for reproduction:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-9b-it")
model = AutoModelForCausalLM.from_pretrained("google/recurrentgemma-9b-it", device_map="cuda", torch_dtype=torch.float16)
input_text = "Write me a poem about Machine Learning." * 300 # This string is 2402 tokens long, which is larger than 2048, the sliding window attention width
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
Error message:
Traceback (most recent call last):
File "/data2/assaf/tmp/test_rg.py", line 13, in <module>
outputs = model.generate(**input_ids, max_new_tokens=20)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/generation/utils.py", line 2326, in generate
result = self._sample(
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/generation/utils.py", line 3289, in _sample
outputs = model_forward(**model_inputs, return_dict=True)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 852, in forward
outputs = self.model(
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 717, in forward
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 764, in _update_causal_mask
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
RuntimeError: The size of tensor a (2048) must match the size of tensor b (2402) at non-singleton dimension 3
Expected behavior
Generation should complete regardless of prompt length. Instead:
If the prompt is longer than the sliding window attention width (2048 tokens), as it is in the script above, generation crashes with the error message above.
If the prompt is shorter than the sliding window width (e.g. replace * 300 with * 200), the script runs fine.
The bug was observed with transformers v4.50.3.
It does not reproduce on earlier transformers versions, such as v4.42.4.
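One possible reading of the size mismatch in the error message, sketched with plain shape arithmetic rather than the library's actual code: slicing a dimension of size 2048 with `[:2402]` still yields 2048 elements, which cannot broadcast against a 2402-wide attention mask. The helper names below (`sliced_len`, `broadcastable`, `padding_mask_ok`) are hypothetical and exist only for this illustration. The value 1600 is an arbitrary prompt length below the window, not a measured token count.

```python
def sliced_len(dim_size, stop):
    """Length of a dimension of size `dim_size` after slicing it with [:stop] (stop >= 0)."""
    return min(dim_size, stop)


def broadcastable(a, b):
    """True if two dimension sizes are compatible under broadcasting rules."""
    return a == b or a == 1 or b == 1


def padding_mask_ok(causal_last_dim, attention_len):
    """Mimic the shapes on the failing line (hypothetical helper, shapes only):

        causal_mask[..., :mask_length] * attention_mask[:, None, None, :]

    where mask_length is the attention mask's last dimension.
    """
    mask_length = attention_len
    return broadcastable(sliced_len(causal_last_dim, mask_length), attention_len)


# Values from the report: window width 2048, prompt length 2402.
print(padding_mask_ok(2048, 2402))  # False: 2048 vs 2402, as in the RuntimeError
# Arbitrary prompt length below the window (assumption, not a measured value).
print(padding_mask_ok(2048, 1600))  # True: the slice shrinks the mask to 1600
```

If this reading is right, the slice only protects against prompts shorter than the causal mask's last dimension, which would match the observation that shorter prompts run fine.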