Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/web_demo/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
gradio==5.11.0
gradio==5.18.0
sentencepiece==0.1.99
transformers==4.41.2
transformers_stream_generator==0.0.5
Expand Down
62 changes: 35 additions & 27 deletions examples/web_demo/web_demo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def boolean_string(string):


parser = argparse.ArgumentParser()
parser.add_argument("-u", "--url", type=str, default="http://local:8000/v1", help="base url")
parser.add_argument("-u", "--url", type=str, default="http://localhost:8000/v1", help="base url")
parser.add_argument("-m", "--model", type=str, default="xft", help="model name")
parser.add_argument("-t", "--token", type=str, default="EMPTY", help="api key")
parser.add_argument("-i", "--ip", type=str, default="0.0.0.0", help="gradio server ip")
Expand All @@ -25,7 +25,7 @@ def clean_input():


def reset():
return [], []
return []


class ChatDemo:
Expand All @@ -43,7 +43,7 @@ def launch(self, server_name="0.0.0.0", server_port=7860, share=False):
with gr.Blocks() as demo:
self.html_func()

chatbot = gr.Chatbot()
chatbot = gr.Chatbot(type="messages")
with gr.Row():
with gr.Column(scale=2):
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=1, container=False)
Expand All @@ -52,53 +52,61 @@ def launch(self, server_name="0.0.0.0", server_port=7860, share=False):
with gr.Column(scale=1):
emptyBtn = gr.Button("Clear History")

history = gr.State([])
submitBtn.click(
self.predict,
[user_input, chatbot, history],
[chatbot, history],
[user_input, chatbot],
[chatbot],
show_progress=True,
)
submitBtn.click(clean_input, [], [user_input])
emptyBtn.click(reset, [], [chatbot, history], show_progress=True)
emptyBtn.click(reset, [], [chatbot], show_progress=True)

demo.queue().launch(server_name=server_name, server_port=server_port, share=share, inbrowser=True)

# Fold one streamed text chunk into the last chatbot entry (tuple format).
#
# The last entry of `chatbot` is indexed as a (query, reply) pair: the reply
# text collected so far is extended with `chunk` and the entry is rewritten.
# Returns `chatbot` (mutated in place) together with a new list made of
# `history` followed by the updated (query, reply) pair.
def post_process_generation(self, chunk, chatbot, query, history):
# Element 1 of the last entry is the reply text accumulated so far.
response = chatbot[-1][1] + chunk
# `+` builds a fresh list, so the `history` object passed in is not modified.
new_history = history + [(query, response)]
# Replace the whole entry rather than editing it in place.
chatbot[-1] = (query, response)
return chatbot, new_history

def create_chat_input(self, query, history):
def post_process_generation(self, chunk, chatbot):
    """Fold one streamed text chunk into the last chatbot message.

    ``<think>`` and ``</think>`` are treated as markers and are never added
    to a message's text:

    * ``<think>`` tags the last message with a ``metadata`` title.
    * ``</think>`` closes a tagged message. If that message collected no
      text, the tag is removed so the same message receives what follows;
      otherwise a new, empty assistant message is appended to receive it.
      A ``</think>`` arriving when the last message has no ``metadata`` key
      is simply dropped.

    Any other chunk is appended to the ``content`` of the last message.

    Args:
        chunk: Text fragment from the stream. A marker is recognised only
            when the fragment equals it exactly.
        chatbot: List of message dicts carrying a ``"content"`` key; it is
            mutated in place. An empty list first receives an empty
            assistant message so there is something to append to.

    Returns:
        The same ``chatbot`` list that was passed in.
    """
    # NOTE(review): the nesting below was reconstructed from a flattened
    # listing -- confirm the if/else pairing against the committed file.
    if not chatbot:
        chatbot.append({"role": "assistant", "content": ""})
    last = chatbot[-1]

    if chunk == "<think>":
        last["metadata"] = {"title": "💭思考过程"}
        chunk = ""
    elif chunk == "</think>":
        if "metadata" in last:
            if last["content"].strip():
                # Text was collected under the tag: continue in a new message.
                chatbot.append({"role": "assistant", "content": ""})
            elif last["metadata"] is not None:
                # Nothing was collected: drop the tag and reuse this message.
                del last["metadata"]
        chunk = ""

    # Index again instead of using `last`: a message may have been appended.
    chatbot[-1]["content"] += chunk
    return chatbot

def create_chat_input(self, chatbot):
msgs = []
if history is None:
history = []
for user_msg, model_msg in history:
msgs.append({"role": "user", "content": user_msg})
msgs.append({"role": "assistant", "content": model_msg})
msgs.append({"role": "user", "content": query})
for msg in chatbot:
if "metadata" not in msg or msg["metadata"] is None:
msgs.append({"role": msg["role"], "content": msg["content"]})
if len(msgs) > 8:
msgs = msgs[-8:]
return msgs

def predict(self, query, chatbot, history):
chatbot.append((query, ""))
def predict(self, query, chatbot):
chatbot.append({"role": "user", "content": query})
chatbot.append({"role": "assistant", "content": ""})

completion = self.client.chat.completions.create(
model=self.model,
messages=self.create_chat_input(query, history),
max_tokens=2048,
messages=self.create_chat_input(chatbot),
max_tokens=8192,
stream=True,
temperature=1.0,
temperature=0.6,
extra_body={"top_k": 20, "top_p": 0.8, "repetition_penalty": 1.1},
)

for chunk in completion:
if chunk.choices[0].delta.content is not None:
yield self.post_process_generation(chunk.choices[0].delta.content, chatbot, query, history)
yield self.post_process_generation(chunk.choices[0].delta.content, chatbot)


if __name__ == "__main__":
args = parser.parse_args()
demo = ChatDemo(args.url, args.model, args.token)

demo.launch(args.ip, args.port, args.share)
demo.launch(args.ip, args.port, args.share)