Skip to content

Commit 872aafa

Browse files
authored
[Tokenizer] Fix chat template for Gemma when answer is contained within question (#9462)
1 parent 5ceb930 commit 872aafa

2 files changed

Lines changed: 57 additions & 0 deletions

File tree

paddlenlp/transformers/gemma/tokenizer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import os
17+
import re
1718
from shutil import copyfile
1819
from typing import Any, Dict, List, Optional, Tuple
1920

@@ -310,3 +311,33 @@ def create_token_type_ids_from_sequences(
310311
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
311312

312313
return output
314+
315+
def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
316+
regex_pattern = "|".join(map(re.escape, split_s))
317+
rendered_messages = self.chat_template.render(
318+
messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map
319+
)
320+
pattern = re.compile(r"(?:%s)" % regex_pattern)
321+
split_positions = [match.span() for match in pattern.finditer(rendered_messages)]
322+
323+
filtered_positions = []
324+
for start, end in split_positions:
325+
# Find the last occurrence of '<start_of_turn>' before the split index
326+
last_start = rendered_messages.rfind("<start_of_turn>", 0, start)
327+
if last_start == -1:
328+
continue # Skip if '<start_of_turn>' is not found
329+
model_start = last_start + len("<start_of_turn>")
330+
331+
# Get the text following 'model_start' and check if it starts with 'model'
332+
following_text = rendered_messages[model_start:].lstrip()
333+
if following_text.startswith("model"):
334+
filtered_positions.append((start, end))
335+
non_learnable_parts = []
336+
last_end = 0
337+
for start, end in filtered_positions:
338+
non_learnable_parts.append(rendered_messages[last_end:start])
339+
last_end = end
340+
remaining_part = rendered_messages[last_end:]
341+
if remaining_part:
342+
non_learnable_parts.append(remaining_part)
343+
return non_learnable_parts

tests/transformers/gemma/test_tokenizer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,29 @@ def test_add_special_tokens(self):
223223
self.assertEqual(encoded, input_encoded + special_token_id)
224224
decoded = tokenizer.decode(encoded, skip_special_tokens=True)
225225
self.assertTrue(special_token not in decoded)
226+
227+
def test_extract_non_learnable_parts(self):
    """Check that ``encode_chat_inputs`` splits each round into a prompt part
    and an answer part, including when the answer text is contained verbatim
    inside the question (the regression this feature fixes).

    Models whose tokenizer ships no chat template are skipped.
    """
    models_with_templates = ["google/gemma-2b-it", "google/gemma-7b-it"]
    dummy_conversations = [
        ["Q.", "A."],
        ["Q.A.", "A."],  # answer string appears inside the question
        ["Q?", "A!"],
    ]
    decode_outputs = [
        ["<bos><start_of_turn>user\nQ.<end_of_turn>\n<start_of_turn>model\n", "A.<end_of_turn>\n"],
        ["<start_of_turn>user\nQ.A.<end_of_turn>\n<start_of_turn>model\n", "A.<end_of_turn>\n"],
        ["<start_of_turn>user\nQ?<end_of_turn>\n<start_of_turn>model\n", "A!<end_of_turn>\n"],
    ]
    context_data = {"is_training": True}
    for model_id in models_with_templates:
        tokenizer = GemmaTokenizer.from_pretrained(model_id)
        if tokenizer.chat_template is None:
            continue  # nothing to test without a chat template
        conversation_result: list[tuple[list[int], list[int]]] = tokenizer.encode_chat_inputs(
            dummy_conversations,
            context_data=context_data,
        )
        for idx, round_ids in enumerate(conversation_result["conversations"]):
            # assertEqual (not the removed assertEquals alias); round_ids avoids
            # shadowing the builtin round().
            self.assertEqual(tokenizer.decode(round_ids[0]), decode_outputs[idx][0])
            self.assertEqual(tokenizer.decode(round_ids[1]), decode_outputs[idx][1])

0 commit comments

Comments
 (0)