@@ -45,35 +45,34 @@ def str_presenter(dumper, data):
4545yaml .representer .SafeRepresenter .add_representer (str , str_presenter )
4646
4747
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
    '''
    Given a Jinja template, reverse-engineers the prefix and the suffix for
    an assistant message (if impersonate=False) or a user message
    (if impersonate=True).

    Two consecutive messages with the same role and recognizable placeholder
    contents are rendered; the text found around the placeholders is then
    split into the part that closes a message (suffix) and the part that
    opens the next one (prefix).

    Parameters:
        renderer: callable invoked as renderer(messages=[...]) that returns
            the rendered prompt as a string.
        impersonate: when True the probe messages use the "user" role,
            otherwise the "assistant" role.
        strip_trailing_spaces: when True, trailing ' ' characters are removed
            from the prefix (the suffix is never stripped).

    Returns:
        A (prefix, suffix) tuple of strings.

    Raises:
        IndexError: if the rendered prompt does not contain the placeholders
            (i.e. the template does not emit the message contents).
    '''

    role = "user" if impersonate else "assistant"

    # The placeholders keep the "user-message" wording for both roles; only
    # the "role" field changes between the two modes.
    first_marker = "<<|user-message-1|>>"
    second_marker = "<<|user-message-2|>>"

    messages = [
        {"role": role, "content": first_marker},
        {"role": role, "content": second_marker},
    ]

    prompt = renderer(messages=messages)

    # Text between the two placeholders: end of message 1 + start of message 2.
    suffix_plus_prefix = prompt.split(first_marker)[1].split(second_marker)[0]

    # Everything rendered after the last placeholder is taken as the suffix.
    suffix = prompt.split(second_marker)[1]

    # Dropping len(suffix) characters from the front leaves the prefix.
    prefix = suffix_plus_prefix[len(suffix):]

    if strip_trailing_spaces:
        prefix = prefix.rstrip(' ')

    return prefix, suffix
7776
7877
7978def generate_chat_prompt (user_input , state , ** kwargs ):
@@ -126,12 +125,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
126125 messages .append ({"role" : "user" , "content" : user_input })
127126
128127 def remove_extra_bos (prompt ):
129- if hasattr (shared .tokenizer , 'bos_token_id' ):
130- bos_tokens = [shared .tokenizer .decode (shared .tokenizer .bos_token_id )]
131- else :
132- bos_tokens = ['<s>' , '<|startoftext|>' , '<BOS_TOKEN>' ]
133-
134- for bos_token in bos_tokens :
128+ for bos_token in ['<s>' , '<|startoftext|>' , '<BOS_TOKEN>' , '<|endoftext|>' ]:
135129 while prompt .startswith (bos_token ):
136130 prompt = prompt [len (bos_token ):]
137131
@@ -143,9 +137,6 @@ def make_prompt(messages):
143137 else :
144138 prompt = renderer (messages = messages )
145139
146- prefix_user , prefix_assistant , suffix = extract_message_prefix_suffix (renderer , strip_trailing_spaces = not _continue )
147- prefix = prefix_user if impersonate else prefix_assistant
148-
149140 if state ['mode' ] == 'chat-instruct' :
150141 outer_messages = []
151142 if state ['custom_system_message' ].strip () != '' :
@@ -157,25 +148,29 @@ def make_prompt(messages):
157148 command = command .replace ('<|prompt|>' , prompt )
158149 command = replace_character_names (command , state ['name1' ], state ['name2' ])
159150
160-
161151 if _continue :
152+ prefix = get_generation_prompt (renderer , impersonate = impersonate , strip_trailing_spaces = False )[0 ]
162153 prefix += messages [- 1 ]["content" ]
163- elif not impersonate :
164- prefix = apply_extensions ('bot_prefix' , prefix , state )
154+ else :
155+ prefix = get_generation_prompt (renderer , impersonate = impersonate )[0 ]
156+ if not impersonate :
157+ prefix = apply_extensions ('bot_prefix' , prefix , state )
165158
166159 outer_messages .append ({"role" : "user" , "content" : command })
167160 outer_messages .append ({"role" : "assistant" , "content" : prefix })
168161
169162 prompt = instruction_template .render (messages = outer_messages )
163+ suffix = get_generation_prompt (instruct_renderer , impersonate = False )[1 ]
170164 if len (suffix ) > 0 :
171165 prompt = prompt [:- len (suffix )]
172166
173167 else :
174-
175168 if _continue :
169+ suffix = get_generation_prompt (renderer , impersonate = impersonate )[1 ]
176170 if len (suffix ) > 0 :
177171 prompt = prompt [:- len (suffix )]
178172 else :
173+ prefix = get_generation_prompt (renderer , impersonate = impersonate )[0 ]
179174 if state ['mode' ] == 'chat' and not impersonate :
180175 prefix = apply_extensions ('bot_prefix' , prefix , state )
181176
@@ -254,11 +249,15 @@ def get_stopping_strings(state):
254249 renderers .append (renderer )
255250
256251 for renderer in renderers :
257- prefix_user , prefix_assistant , suffix = extract_message_prefix_suffix (renderer )
258-
259- for item in [suffix + prefix_assistant , suffix + prefix_user , suffix ]:
260- stopping_strings .append (item )
261- stopping_strings .append (item .rstrip ())
252+ prefix_bot , suffix_bot = get_generation_prompt (renderer , impersonate = False )
253+ prefix_user , suffix_user = get_generation_prompt (renderer , impersonate = True )
254+
255+ stopping_strings += [
256+ suffix_user + prefix_bot ,
257+ suffix_user + prefix_user ,
258+ suffix_bot + prefix_bot ,
259+ suffix_bot + prefix_user ,
260+ ]
262261
263262 if 'stopping_strings' in state and isinstance (state ['stopping_strings' ], list ):
264263 stopping_strings += state .pop ('stopping_strings' )
0 commit comments