Commit 55f4fdd

manueldeprada authored and zucchini-nlp committed
Fixes huggingface#37219 : RecurrentGemma crashes for inputs longer than sliding window length (huggingface#37613)
* fix: RecurrentGemma crashes during inference for inputs longer than sliding window width
* fix recurrentgemma tests; add long test bigger than context window
1 parent 5aa6cc0 commit 55f4fdd

2 files changed: 26 additions, 13 deletions

src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

Lines changed: 2 additions & 0 deletions
@@ -767,6 +767,8 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
         if attention_mask is not None:
             causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
             if attention_mask.dim() == 2:
+                # Crop the attention mask to the target length.
+                attention_mask = attention_mask[:, -target_length:]
                 mask_length = attention_mask.shape[-1]
                 padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                 causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
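
The effect of the added crop can be illustrated with a minimal pure-Python sketch. This is not the library code: `merge_padding`, `MIN`, and the list-based masks are hypothetical stand-ins for the tensor operations in `_update_causal_mask`, and the window of 4 positions is an arbitrary value chosen for illustration.

```python
MIN = float("-inf")  # stand-in for the dtype minimum used as "blocked" in the real mask


def merge_padding(causal_block, padding_row, target_length):
    """Merge a 1D padding mask (1 = token, 0 = pad) into an additive causal mask.

    causal_block: list of rows of width target_length (0.0 = attend, MIN = blocked).
    padding_row:  padding mask for one sequence; may be longer than target_length.
    """
    # Mirrors the added line: keep only the most recent target_length positions.
    padding_row = padding_row[-target_length:]
    mask_length = len(padding_row)

    merged = [row[:] for row in causal_block]  # copy before the in-place edit
    for row in merged:
        for k in range(mask_length):
            # Block a position only if it was attendable and is padding.
            if row[k] == 0.0 and padding_row[k] == 0:
                row[k] = MIN
    return merged


# Window of 4 cached positions, but a 6-token left-padded prompt.
causal_block = [[0.0, 0.0, 0.0, 0.0]]
padding_row = [0, 0, 0, 1, 1, 1]
result = merge_padding(causal_block, padding_row, target_length=4)
print(result)  # [[-inf, 0.0, 0.0, 0.0]]
```

Without the crop, `mask_length` would be 6 while each mask row only has 4 columns, so the loop in this sketch would raise an `IndexError`. In the tensor version the analogous failure is presumably a shape mismatch between the sliced causal mask and the longer padding mask, which matches the crash described in the commit message.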

tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Lines changed: 24 additions & 13 deletions
@@ -278,11 +278,12 @@ def test_initialization(self):
 @slow
 class RecurrentGemmaIntegrationTest(unittest.TestCase):
     input_text = ["Hello I am doing", "Hi today"]
+    input_long_text = ['<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'] # fmt: skip
     model_id = "google/recurrentgemma-2b"

     @require_read_token
     def test_2b_generate(self):
-        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a very good day for you. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do. You will be able to do all the work you want to do.'] # fmt: skip
+        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today is a new app that allows you to make money by watching videos.\n\nThe app is very simple to use and you can earn money by watching videos.\n\nThe app is available for both Android and iOS devices and you can download it from the Google Play Store or the App Store.\n\nOnce you have downloaded the app'] # fmt: skip
         model = AutoModelForCausalLM.from_pretrained(self.model_id, low_cpu_mem_usage=True).to(torch_device)

         tokenizer = AutoTokenizer.from_pretrained(self.model_id)
@@ -296,7 +297,7 @@ def test_2b_generate(self):
         self.assertEqual(output_text, EXPECTED_TEXTS)

         tokenizer.padding_side = "left"
-        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I am going to share with you the best <strong><em>free online video editing software</em></strong>.\n\n<h2><strong>Best Free Online Video Editing Software</strong></h2>\n\n<strong>1.</strong> <strong>Wondershare Filmora</strong>\n\nWondershare Filmora is a free online video editing software that is used to edit videos.'] # fmt: skip
+        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking for some information on the topic. I am looking for some information on the impact of the internet on the society. I am looking for some information on the impact of the internet on the society. I am looking for some', 'Hi today I’m going to show you how to make a simple and easy to make a <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY</strong> <strong>DIY'] # fmt: skip

         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
         output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
@@ -316,7 +317,7 @@ def test_2b_generate(self):
     @require_read_token
     def test_2b_sample(self):
         set_seed(0)
-        EXPECTED_TEXT = ['Where is Paris ?\n\nAnswer this question "yes" or "no": Could a person pass out in subzero temperatures?\n\nFor the sentence below, underline the pronoun in parentheses that agrees with its antecedent.\n\nExample 1. Mary and Pam will have the opportunity to prove (herself, $\\underline{\\text{themselves}}$) at the concert.\n\nThe waiters and the manager at the restaurant will do <em>(his, their)</em> best to assist you.\n\nA vocabulary word appears in italics in the short passage below. Think about how the word is used. Then write a definition for the vocabulary word.\n\nAfter a one-hour $'] # fmt: skip
+        EXPECTED_TEXT = ['Where is Paris ?\n\nChoose the word or phrase that is closest in meaning to the word in capital letters.\n\nREDEEM\n(A) sort out\n(B) think over\n(C) turn in\n(D) take back\n\nWrite the correct word in the space next to each definition. Use each word only once.\n\nto badly damage\n\nOn the lines provided below, write <em>P</em> if the underlined word group is a phrase and <em>NP</em> if it is not a phrase. Example $\\underline{\\text{P}}$ 1. We have finally discovered the secret $\\underline{\\text{of delicious pizza. }}$'] # fmt: skip
         model = AutoModelForCausalLM.from_pretrained(self.model_id).to(torch_device)

         tokenizer = AutoTokenizer.from_pretrained(self.model_id)
@@ -329,13 +330,13 @@ def test_2b_sample(self):
     @require_bitsandbytes
     @require_read_token
     def test_model_2b_8bit(self):
-        EXPECTED_TEXTS = ['<bos>Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "<bos>Hi today<pad><pad> I'm going to show you how to make a simple and easy to use <strong><em><u>"] # fmt: skip
+        EXPECTED_TEXTS = ['Hello I am doing a project on the topic of "The impact of the internet on the society" and I am looking', "Hi today I'm going to show you how to make a simple and easy to make a simple and easy"] # fmt: skip

         model = AutoModelForCausalLM.from_pretrained(
             "gg-hf/recurrent-gemma-2b-hf", device_map={"": torch_device}, load_in_8bit=True, torch_dtype=torch.bfloat16
         )

-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
         inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)

         output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
@@ -345,18 +346,28 @@ def test_model_2b_8bit(self):

     @require_read_token
     def test_long_context(self):
-        input_text = [
-            '<bos><s>Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col.'
-        ]
-        EXPECTED_GENERATION = [
-            ' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." "We are not aware of any video footage that could have been taken on board the plane," Delannoy said. "We are not aware of any video footage that could'
-        ]
+        EXPECTED_GENERATION = [' Jean-Paul Delannoy told CNN that the BEA is "not aware of any video footage that could have been taken on board the plane." He added that the BEA is "not aware of any video footage that could have been taken on board the plane." The BEA is the French equivalent of the National Transportation Safety Board'] # fmt: skip

         model = AutoModelForCausalLM.from_pretrained(
             self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
         ).to(torch_device)
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        inputs = tokenizer(input_text, return_tensors="pt").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
+        inputs = tokenizer(self.input_long_text, return_tensors="pt").to(torch_device)
+        output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
+        output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
+        print(output_text)
+        self.assertEqual(output_text, EXPECTED_GENERATION)
+
+    @require_read_token
+    def test_longer_than_window(self):
+        EXPECTED_GENERATION = [" Robin's comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the"] # fmt: skip
+
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16
+        ).to(torch_device)
+        model.config.attention_window_size = 256 # Make the attention window size shorter than the current prompt
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left")
+        inputs = tokenizer(self.input_long_text, return_tensors="pt").to(torch_device)
         output = model.generate(**inputs, max_new_tokens=64, do_sample=False)
         output_text = tokenizer.batch_decode(output[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
         self.assertEqual(output_text, EXPECTED_GENERATION)
