From 0730c191dacbab83d90866f73848a8b62d903738 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 22 May 2024 08:18:02 -0700 Subject: [PATCH 1/4] Fix stopping strings for llama-3 and phi --- modules/chat.py | 70 ++++++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 43f5466bd8..b2d3232cf7 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -45,34 +45,35 @@ def str_presenter(dumper, data): yaml.representer.SafeRepresenter.add_representer(str, str_presenter) -def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True): +def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True): ''' - Given a Jinja template, reverse-engineers the prefix and the suffix for - an assistant message (if impersonate=False) or an user message - (if impersonate=True) + Given a Jinja template, extracts the prefix and suffix for + an assistant message or a user message. It assumes that they + share the same suffix. ''' - if impersonate: - messages = [ - {"role": "user", "content": "<<|user-message-1|>>"}, - {"role": "user", "content": "<<|user-message-2|>>"}, - ] - else: - messages = [ - {"role": "assistant", "content": "<<|user-message-1|>>"}, - {"role": "assistant", "content": "<<|user-message-2|>>"}, - ] + messages = [ + {"role": "user", "content": f"<<|user-message-1|>>"}, + {"role": "assistant", "content": f"<<|assistant-message-1|>>"}, + {"role": "user", "content": f"<<|user-message-2|>>"}, + {"role": "assistant", "content": f"<<|assistant-message-2|>>"}, + ] prompt = renderer(messages=messages) + unwanted_suffix = renderer(messages=[]) + + suffix = prompt.split('<<|assistant-message-2|>>')[1] + if unwanted_suffix != '': + suffix = suffix[:-len(unwanted_suffix)] - suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0] - suffix = prompt.split("<<|user-message-2|>>")[1] - prefix = suffix_plus_prefix[len(suffix):] + prefix_user = prompt.split('<<|assistant-message-1|>>')[1].split('<<|user-message-2|>>')[0][len(suffix):] + prefix_assistant = prompt.split('<<|user-message-1|>>')[1].split('<<|assistant-message-1|>>')[0][len(suffix):] if strip_trailing_spaces: - prefix = prefix.rstrip(' ') + prefix_user = prefix_user.rstrip(' ') + prefix_assistant = prefix_assistant.rstrip(' ') - return prefix, suffix + return prefix_user, prefix_assistant, suffix def generate_chat_prompt(user_input, state, **kwargs): @@ -148,29 +149,30 @@ def make_prompt(messages): command = command.replace('<|prompt|>', prompt) command = replace_character_names(command, state['name1'], state['name2']) + prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue) + prefix = prefix_user if impersonate else prefix_assistant + if _continue: - prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0] prefix += messages[-1]["content"] - else: - prefix = get_generation_prompt(renderer, impersonate=impersonate)[0] - if not impersonate: - prefix = apply_extensions('bot_prefix', prefix, state) + elif not impersonate: + prefix = apply_extensions('bot_prefix', prefix, state) outer_messages.append({"role": "user", "content": command}) outer_messages.append({"role": "assistant", "content": prefix}) prompt = instruction_template.render(messages=outer_messages) - suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1] + suffix = extract_message_prefix_suffix(instruct_renderer, message_role="assistant")[2] if len(suffix) > 0: prompt = prompt[:-len(suffix)] else: + prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue) + if _continue: - suffix = get_generation_prompt(renderer, impersonate=impersonate)[1] if len(suffix) > 0: prompt = prompt[:-len(suffix)] else: - prefix = get_generation_prompt(renderer, impersonate=impersonate)[0] + prefix = prefix_user if impersonate else prefix_assistant if state['mode'] == 'chat' and not impersonate: prefix = apply_extensions('bot_prefix', prefix, state) @@ -249,15 +251,11 @@ def get_stopping_strings(state): renderers.append(renderer) for renderer in renderers: - prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False) - prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True) - - stopping_strings += [ - suffix_user + prefix_bot, - suffix_user + prefix_user, - suffix_bot + prefix_bot, - suffix_bot + prefix_user, - ] + prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer) + + for item in [suffix + prefix_assistant, suffix + prefix_user, suffix]: + stopping_strings.append(item) + stopping_strings.append(item.rstrip()) if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): stopping_strings += state.pop('stopping_strings') From ee6728640b7a38d159cda76287b7ea2da122a75c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 22 May 2024 08:24:16 -0700 Subject: [PATCH 2/4] Minor fixes --- modules/chat.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index b2d3232cf7..9fb47ec815 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -53,10 +53,10 @@ def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True): ''' messages = [ - {"role": "user", "content": f"<<|user-message-1|>>"}, - {"role": "assistant", "content": f"<<|assistant-message-1|>>"}, - {"role": "user", "content": f"<<|user-message-2|>>"}, - {"role": "assistant", "content": f"<<|assistant-message-2|>>"}, + {"role": "user", "content": "<<|user-message-1|>>"}, + {"role": "assistant", "content": "<<|assistant-message-1|>>"}, + {"role": "user", "content": "<<|user-message-2|>>"}, + {"role": "assistant", "content": "<<|assistant-message-2|>>"}, ] prompt = renderer(messages=messages) @@ -126,9 +126,9 @@ def generate_chat_prompt(user_input, state, **kwargs): messages.append({"role": "user", "content": user_input}) def remove_extra_bos(prompt): - for bos_token in ['', '<|startoftext|>', '', '<|endoftext|>']: - while prompt.startswith(bos_token): - prompt = prompt[len(bos_token):] + bos_token = shared.tokenizer.decode(shared.tokenizer.bos_token_id) + while prompt.startswith(bos_token): + prompt = prompt[len(bos_token):] return prompt @@ -161,7 +161,6 @@ def make_prompt(messages): outer_messages.append({"role": "assistant", "content": prefix}) prompt = instruction_template.render(messages=outer_messages) - suffix = extract_message_prefix_suffix(instruct_renderer, message_role="assistant")[2] if len(suffix) > 0: prompt = prompt[:-len(suffix)] From ce89dbc5ff3ce99a4a7887c5d2e3942e00ed4242 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 22 May 2024 09:42:33 -0700 Subject: [PATCH 3/4] Account for llama.cpp loader --- modules/chat.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 9fb47ec815..fde546495e 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -48,7 +48,7 @@ def str_presenter(dumper, data): def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True): ''' Given a Jinja template, extracts the prefix and suffix for - an assistant message or a user message. It assumes that they + an assistant message and a user message. It assumes that they share the same suffix. ''' @@ -126,9 +126,14 @@ def generate_chat_prompt(user_input, state, **kwargs): messages.append({"role": "user", "content": user_input}) def remove_extra_bos(prompt): - bos_token = shared.tokenizer.decode(shared.tokenizer.bos_token_id) - while prompt.startswith(bos_token): - prompt = prompt[len(bos_token):] + if hasattr(shared.tokenizer, 'bos_token_id'): + bos_tokens = [shared.tokenizer.decode(shared.tokenizer.bos_token_id)] + else: + bos_tokens = ['', '<|startoftext|>', ''] + + for bos_token in bos_tokens: + while prompt.startswith(bos_token): + prompt = prompt[len(bos_token):] return prompt From bfb9d7fc019a937fb2b549c5c436ca16e50acfe1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 22 May 2024 09:44:14 -0700 Subject: [PATCH 4/4] Simplify --- modules/chat.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index fde546495e..6a388a0425 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -143,6 +143,9 @@ def make_prompt(messages): else: prompt = renderer(messages=messages) + prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue) + prefix = prefix_user if impersonate else prefix_assistant + if state['mode'] == 'chat-instruct': outer_messages = [] if state['custom_system_message'].strip() != '': @@ -154,8 +157,6 @@ def make_prompt(messages): command = command.replace('<|prompt|>', prompt) command = replace_character_names(command, state['name1'], state['name2']) - prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue) - prefix = prefix_user if impersonate else prefix_assistant if _continue: prefix += messages[-1]["content"] @@ -170,13 +171,11 @@ def make_prompt(messages): prompt = prompt[:-len(suffix)] else: - prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue) if _continue: if len(suffix) > 0: prompt = prompt[:-len(suffix)] else: - prefix = prefix_user if impersonate else prefix_assistant if state['mode'] == 'chat' and not impersonate: prefix = apply_extensions('bot_prefix', prefix, state)