Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 38 additions & 37 deletions modules/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,34 +45,35 @@ def str_presenter(dumper, data):
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)


def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True):
'''
Given a Jinja template, reverse-engineers the prefix and the suffix for
an assistant message (if impersonate=False) or an user message
(if impersonate=True)
Given a Jinja template, extracts the prefix and suffix for
an assistant message and a user message. It assumes that they
share the same suffix.
'''

if impersonate:
messages = [
{"role": "user", "content": "<<|user-message-1|>>"},
{"role": "user", "content": "<<|user-message-2|>>"},
]
else:
messages = [
{"role": "assistant", "content": "<<|user-message-1|>>"},
{"role": "assistant", "content": "<<|user-message-2|>>"},
]
messages = [
{"role": "user", "content": "<<|user-message-1|>>"},
{"role": "assistant", "content": "<<|assistant-message-1|>>"},
{"role": "user", "content": "<<|user-message-2|>>"},
{"role": "assistant", "content": "<<|assistant-message-2|>>"},
]

prompt = renderer(messages=messages)
unwanted_suffix = renderer(messages=[])

suffix = prompt.split('<<|assistant-message-2|>>')[1]
if unwanted_suffix != '':
suffix = suffix[:-len(unwanted_suffix)]

suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
suffix = prompt.split("<<|user-message-2|>>")[1]
prefix = suffix_plus_prefix[len(suffix):]
prefix_user = prompt.split('<<|assistant-message-1|>>')[1].split('<<|user-message-2|>>')[0][len(suffix):]
prefix_assistant = prompt.split('<<|user-message-1|>>')[1].split('<<|assistant-message-1|>>')[0][len(suffix):]

if strip_trailing_spaces:
prefix = prefix.rstrip(' ')
prefix_user = prefix_user.rstrip(' ')
prefix_assistant = prefix_assistant.rstrip(' ')

return prefix, suffix
return prefix_user, prefix_assistant, suffix


def generate_chat_prompt(user_input, state, **kwargs):
Expand Down Expand Up @@ -125,7 +126,12 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.append({"role": "user", "content": user_input})

def remove_extra_bos(prompt):
for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
if hasattr(shared.tokenizer, 'bos_token_id'):
bos_tokens = [shared.tokenizer.decode(shared.tokenizer.bos_token_id)]
else:
bos_tokens = ['<s>', '<|startoftext|>', '<BOS_TOKEN>']

for bos_token in bos_tokens:
while prompt.startswith(bos_token):
prompt = prompt[len(bos_token):]

Expand All @@ -137,6 +143,9 @@ def make_prompt(messages):
else:
prompt = renderer(messages=messages)

prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue)
prefix = prefix_user if impersonate else prefix_assistant

if state['mode'] == 'chat-instruct':
outer_messages = []
if state['custom_system_message'].strip() != '':
Expand All @@ -148,29 +157,25 @@ def make_prompt(messages):
command = command.replace('<|prompt|>', prompt)
command = replace_character_names(command, state['name1'], state['name2'])


if _continue:
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
prefix += messages[-1]["content"]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
elif not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)

outer_messages.append({"role": "user", "content": command})
outer_messages.append({"role": "assistant", "content": prefix})

prompt = instruction_template.render(messages=outer_messages)
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]

else:

if _continue:
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
if len(suffix) > 0:
prompt = prompt[:-len(suffix)]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)

Expand Down Expand Up @@ -249,15 +254,11 @@ def get_stopping_strings(state):
renderers.append(renderer)

for renderer in renderers:
prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)

stopping_strings += [
suffix_user + prefix_bot,
suffix_user + prefix_user,
suffix_bot + prefix_bot,
suffix_bot + prefix_user,
]
prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer)

for item in [suffix + prefix_assistant, suffix + prefix_user, suffix]:
stopping_strings.append(item)
stopping_strings.append(item.rstrip())

if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
stopping_strings += state.pop('stopping_strings')
Expand Down