
RecurrentGemma crashes during inference for inputs longer than sliding window width #37219

@assafbk

System Info

  • transformers version: 4.50.3
  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
  • Python version: 3.10.16
  • Huggingface_hub version: 0.30.1
  • Safetensors version: 0.5.3
  • Accelerate version: 1.6.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.5.1+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA RTX A6000

Who can help?

@ArthurZucker, @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Code snippet for reproduction:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-9b-it")
model = AutoModelForCausalLM.from_pretrained("google/recurrentgemma-9b-it", device_map="cuda", torch_dtype=torch.float16)

# This string tokenizes to 2402 tokens, more than the 2048-token sliding window attention width
input_text = "Write me a poem about Machine Learning." * 300
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
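
To confirm the prompt actually exceeds the window before calling generate, the two lengths can be printed (a small sketch; reading the width from attention_window_size on the config is an assumption, with 2048 as a fallback):

# Sketch: compare the prompt length against the sliding window width.
window = getattr(model.config, "attention_window_size", 2048)  # attribute name assumed
seq_len = input_ids["input_ids"].shape[-1]
print(f"prompt tokens: {seq_len}, sliding window width: {window}")  # 2402 vs 2048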

Error message:

Traceback (most recent call last):
  File "/data2/assaf/tmp/test_rg.py", line 13, in <module>
    outputs = model.generate(**input_ids, max_new_tokens=20)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/generation/utils.py", line 2326, in generate
    result = self._sample(
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/generation/utils.py", line 3289, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 852, in forward
    outputs = self.model(
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 717, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 764, in _update_causal_mask
    padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
RuntimeError: The size of tensor a (2048) must match the size of tensor b (2402) at non-singleton dimension 3
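
The shape mismatch can be reproduced in isolation with the shapes inferred from the traceback (a minimal sketch of the failing broadcast, not the actual model code):

import torch

# The sliced causal mask is capped at the sliding window width (2048 columns),
# while the padding mask still spans the full 2402-token prompt.
causal_mask = torch.zeros(1, 1, 2402, 2048)
attention_mask = torch.ones(1, 2402)

# Raises: RuntimeError: The size of tensor a (2048) must match the size of
# tensor b (2402) at non-singleton dimension 3
padding_mask = causal_mask.eq(0.0) * attention_mask[:, None, None, :].eq(0.0)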

Expected behavior

If the sequence is longer than the sliding window width (as in the script above), generation crashes with the error message shown.

If the sequence is shorter than the sliding window width (e.g. replace *300 with *200), the script runs fine.

The bug was observed in transformers v4.50.3. It does not reproduce on earlier transformers versions, such as v4.42.4.
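
Until this is fixed, one possible workaround on v4.50.3 is to truncate the prompt to the window width before generating (a sketch, not a real fix: it discards the truncated context, and attention_window_size as the config attribute is an assumption):

# Workaround sketch: keep only the last `window` tokens so the prompt fits
# inside the sliding window. Earlier context is lost.
window = getattr(model.config, "attention_window_size", 2048)  # attribute name assumed
input_ids["input_ids"] = input_ids["input_ids"][:, -window:]
input_ids["attention_mask"] = input_ids["attention_mask"][:, -window:]
outputs = model.generate(**input_ids, max_new_tokens=20)

Alternatively, pinning transformers to an earlier version such as v4.42.4, where the snippet runs fine, sidesteps the issue entirely.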
