Skip to content

Commit ab36686

Browse files
committed
Fix preprocessing
1 parent fb3cd1d commit ab36686

File tree

2 files changed

+11
-17
lines changed

2 files changed

+11
-17
lines changed

specforge/data/parse.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -41,12 +41,12 @@ class GeneralParser(Parser):
4141
def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
4242
super().__init__(tokenizer, chat_template)
4343
self.system_prompt = chat_template.system_prompt
44-
self.user_message_separator = (
45-
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
46-
)
47-
self.assistant_message_separator = (
48-
f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
49-
)
44+
if chat_template.end_of_turn_token:
45+
self.user_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}"
46+
self.assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}"
47+
else:
48+
self.user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}"
49+
self.assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}"
5050

5151
def parse(
5252
self, conversation: "Conversation", max_length: int, preformatted: bool = False

specforge/data/preprocessing.py

Lines changed: 5 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -71,20 +71,14 @@ def _apply_loss_mask_from_chat_template(
7171
"""
7272
loss_mask = torch.zeros(len(offsets), dtype=torch.long)
7373

74-
if chat_template.end_of_turn_token is not None:
74+
if chat_template.end_of_turn_token:
7575
user_message_separator = (
76-
f"{chat_template.end_of_turn_token}{chat_template.user_header}"
77-
)
78-
assistant_message_separator = (
79-
f"{chat_template.end_of_turn_token}{chat_template.assistant_header}"
76+
f"{chat_template.end_of_turn_token or ''}{chat_template.user_header or ''}"
8077
)
78+
assistant_message_separator = f"{chat_template.end_of_turn_token or ''}{chat_template.assistant_header or ''}"
8179
else:
82-
user_message_separator = (
83-
f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header}"
84-
)
85-
assistant_message_separator = (
86-
f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header}"
87-
)
80+
user_message_separator = f"{chat_template.end_of_assistant_token or ''}{chat_template.user_header or ''}"
81+
assistant_message_separator = f"{chat_template.end_of_user_token or ''}{chat_template.assistant_header or ''}"
8882

8983
# Find spans of assistant responses using regex
9084
assistant_pattern = (

0 commit comments

Comments
 (0)