|
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | 16 | import os |
| 17 | +import re |
17 | 18 | from shutil import copyfile |
18 | 19 | from typing import Any, Dict, List, Optional, Tuple |
19 | 20 |
|
@@ -310,3 +311,33 @@ def create_token_type_ids_from_sequences( |
310 | 311 | output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) |
311 | 312 |
|
312 | 313 | return output |
| 314 | + |
def _extract_non_learnable_parts(self, origin_msg: List[Dict[str, str]], split_s: List[str]):
    """Render the conversation and split it at separators inside model turns.

    The chat template is rendered with ``add_generation_prompt=False`` and
    the tokenizer's special tokens. Every occurrence of a string from
    ``split_s`` is a candidate split point, but it is only used when the
    closest preceding ``<start_of_turn>`` marker is followed (after optional
    whitespace) by ``model``. Accepted separators are removed from the
    output.

    Args:
        origin_msg: Messages passed to the chat template as ``messages``.
        split_s: Literal separator strings (regex-escaped before matching).
            Empty strings are ignored, because an empty alternative would
            match at every position of the rendered text.

    Returns:
        The segments of the rendered text between accepted split points, in
        order. A trailing empty segment is omitted. When no usable separator
        is given, the rendered text is returned as a single segment (or an
        empty list if the rendered text is empty).
    """
    turn_marker = "<start_of_turn>"
    rendered_messages = self.chat_template.render(
        messages=origin_msg, add_generation_prompt=False, **self.special_tokens_map
    )

    # An empty separator would turn the pattern into one that matches the
    # empty string everywhere, shredding model turns into single characters.
    separators = [sep for sep in split_s if sep]
    if not separators:
        return [rendered_messages] if rendered_messages else []

    pattern = re.compile("|".join(map(re.escape, separators)))
    text_len = len(rendered_messages)

    non_learnable_parts = []
    last_end = 0
    for match in pattern.finditer(rendered_messages):
        start, end = match.span()

        # Locate the turn marker that opens the turn containing this match.
        turn_start = rendered_messages.rfind(turn_marker, 0, start)
        if turn_start == -1:
            continue  # The match is not inside any turn.

        # Skip whitespace after the marker by index instead of copying the
        # whole tail of the string for every match.
        role_pos = turn_start + len(turn_marker)
        while role_pos < text_len and rendered_messages[role_pos].isspace():
            role_pos += 1

        if rendered_messages.startswith("model", role_pos):
            non_learnable_parts.append(rendered_messages[last_end:start])
            last_end = end

    remaining_part = rendered_messages[last_end:]
    if remaining_part:
        non_learnable_parts.append(remaining_part)
    return non_learnable_parts
0 commit comments