File tree Expand file tree Collapse file tree 2 files changed +11
-17
lines changed Expand file tree Collapse file tree 2 files changed +11
-17
lines changed Original file line number Diff line number Diff line change @@ -41,12 +41,12 @@ class GeneralParser(Parser):
4141 def __init__ (self , tokenizer : PreTrainedTokenizer , chat_template : ChatTemplate ):
4242 super ().__init__ (tokenizer , chat_template )
4343 self .system_prompt = chat_template .system_prompt
44- self . user_message_separator = (
45- f"{ chat_template .end_of_turn_token } { chat_template .user_header } "
46- )
47- self . assistant_message_separator = (
48- f"{ chat_template .end_of_turn_token } { chat_template .assistant_header } "
49- )
44+ if chat_template . end_of_turn_token :
45+ self . user_message_separator = f"{ chat_template .end_of_turn_token or '' } { chat_template .user_header or '' } "
46+ self . assistant_message_separator = f" { chat_template . end_of_turn_token or '' } { chat_template . assistant_header or '' } "
47+ else :
48+ self . user_message_separator = f"{ chat_template .end_of_assistant_token or '' } { chat_template .user_header or '' } "
49+ self . assistant_message_separator = f" { chat_template . end_of_user_token or '' } { chat_template . assistant_header or '' } "
5050
5151 def parse (
5252 self , conversation : "Conversation" , max_length : int , preformatted : bool = False
Original file line number Diff line number Diff line change @@ -71,20 +71,14 @@ def _apply_loss_mask_from_chat_template(
7171 """
7272 loss_mask = torch .zeros (len (offsets ), dtype = torch .long )
7373
74- if chat_template .end_of_turn_token is not None :
74+ if chat_template .end_of_turn_token :
7575 user_message_separator = (
76- f"{ chat_template .end_of_turn_token } { chat_template .user_header } "
77- )
78- assistant_message_separator = (
79- f"{ chat_template .end_of_turn_token } { chat_template .assistant_header } "
76+ f"{ chat_template .end_of_turn_token or '' } { chat_template .user_header or '' } "
8077 )
78+ assistant_message_separator = f"{ chat_template .end_of_turn_token or '' } { chat_template .assistant_header or '' } "
8179 else :
82- user_message_separator = (
83- f"{ chat_template .end_of_assistant_token or '' } { chat_template .user_header } "
84- )
85- assistant_message_separator = (
86- f"{ chat_template .end_of_user_token or '' } { chat_template .assistant_header } "
87- )
80+ user_message_separator = f"{ chat_template .end_of_assistant_token or '' } { chat_template .user_header or '' } "
81+ assistant_message_separator = f"{ chat_template .end_of_user_token or '' } { chat_template .assistant_header or '' } "
8882
8983 # Find spans of assistant responses using regex
9084 assistant_pattern = (
You can’t perform that action at this time.
0 commit comments