Skip to content

Commit ad54d52

Browse files
committed
Revert "Fix stopping strings for llama-3 and phi (#6043)"
This reverts commit 5499bc9.
1 parent 5499bc9 commit ad54d52

File tree

1 file changed

+37
-38
lines changed

1 file changed

+37
-38
lines changed

modules/chat.py

Lines changed: 37 additions & 38 deletions
Original file line number · Diff line number · Diff line change
@@ -45,35 +45,34 @@ def str_presenter(dumper, data):
4545
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
4646

4747

48-
def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True):
48+
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
4949
'''
50-
Given a Jinja template, extracts the prefix and suffix for
51-
an assistant message and a user message. It assumes that they
52-
share the same suffix.
50+
Given a Jinja template, reverse-engineers the prefix and the suffix for
51+
an assistant message (if impersonate=False) or an user message
52+
(if impersonate=True)
5353
'''
5454

55-
messages = [
56-
{"role": "user", "content": "<<|user-message-1|>>"},
57-
{"role": "assistant", "content": "<<|assistant-message-1|>>"},
58-
{"role": "user", "content": "<<|user-message-2|>>"},
59-
{"role": "assistant", "content": "<<|assistant-message-2|>>"},
60-
]
55+
if impersonate:
56+
messages = [
57+
{"role": "user", "content": "<<|user-message-1|>>"},
58+
{"role": "user", "content": "<<|user-message-2|>>"},
59+
]
60+
else:
61+
messages = [
62+
{"role": "assistant", "content": "<<|user-message-1|>>"},
63+
{"role": "assistant", "content": "<<|user-message-2|>>"},
64+
]
6165

6266
prompt = renderer(messages=messages)
63-
unwanted_suffix = renderer(messages=[])
64-
65-
suffix = prompt.split('<<|assistant-message-2|>>')[1]
66-
if unwanted_suffix != '':
67-
suffix = suffix[:-len(unwanted_suffix)]
6867

69-
prefix_user = prompt.split('<<|assistant-message-1|>>')[1].split('<<|user-message-2|>>')[0][len(suffix):]
70-
prefix_assistant = prompt.split('<<|user-message-1|>>')[1].split('<<|assistant-message-1|>>')[0][len(suffix):]
68+
suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
69+
suffix = prompt.split("<<|user-message-2|>>")[1]
70+
prefix = suffix_plus_prefix[len(suffix):]
7171

7272
if strip_trailing_spaces:
73-
prefix_user = prefix_user.rstrip(' ')
74-
prefix_assistant = prefix_assistant.rstrip(' ')
73+
prefix = prefix.rstrip(' ')
7574

76-
return prefix_user, prefix_assistant, suffix
75+
return prefix, suffix
7776

7877

7978
def generate_chat_prompt(user_input, state, **kwargs):
@@ -126,12 +125,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
126125
messages.append({"role": "user", "content": user_input})
127126

128127
def remove_extra_bos(prompt):
129-
if hasattr(shared.tokenizer, 'bos_token_id'):
130-
bos_tokens = [shared.tokenizer.decode(shared.tokenizer.bos_token_id)]
131-
else:
132-
bos_tokens = ['<s>', '<|startoftext|>', '<BOS_TOKEN>']
133-
134-
for bos_token in bos_tokens:
128+
for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
135129
while prompt.startswith(bos_token):
136130
prompt = prompt[len(bos_token):]
137131

@@ -143,9 +137,6 @@ def make_prompt(messages):
143137
else:
144138
prompt = renderer(messages=messages)
145139

146-
prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue)
147-
prefix = prefix_user if impersonate else prefix_assistant
148-
149140
if state['mode'] == 'chat-instruct':
150141
outer_messages = []
151142
if state['custom_system_message'].strip() != '':
@@ -157,25 +148,29 @@ def make_prompt(messages):
157148
command = command.replace('<|prompt|>', prompt)
158149
command = replace_character_names(command, state['name1'], state['name2'])
159150

160-
161151
if _continue:
152+
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
162153
prefix += messages[-1]["content"]
163-
elif not impersonate:
164-
prefix = apply_extensions('bot_prefix', prefix, state)
154+
else:
155+
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
156+
if not impersonate:
157+
prefix = apply_extensions('bot_prefix', prefix, state)
165158

166159
outer_messages.append({"role": "user", "content": command})
167160
outer_messages.append({"role": "assistant", "content": prefix})
168161

169162
prompt = instruction_template.render(messages=outer_messages)
163+
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
170164
if len(suffix) > 0:
171165
prompt = prompt[:-len(suffix)]
172166

173167
else:
174-
175168
if _continue:
169+
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
176170
if len(suffix) > 0:
177171
prompt = prompt[:-len(suffix)]
178172
else:
173+
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
179174
if state['mode'] == 'chat' and not impersonate:
180175
prefix = apply_extensions('bot_prefix', prefix, state)
181176

@@ -254,11 +249,15 @@ def get_stopping_strings(state):
254249
renderers.append(renderer)
255250

256251
for renderer in renderers:
257-
prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer)
258-
259-
for item in [suffix + prefix_assistant, suffix + prefix_user, suffix]:
260-
stopping_strings.append(item)
261-
stopping_strings.append(item.rstrip())
252+
prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
253+
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
254+
255+
stopping_strings += [
256+
suffix_user + prefix_bot,
257+
suffix_user + prefix_user,
258+
suffix_bot + prefix_bot,
259+
suffix_bot + prefix_user,
260+
]
262261

263262
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
264263
stopping_strings += state.pop('stopping_strings')

0 commit comments

Comments (0)