Skip to content

Commit 3256626

Browse files
committed
+ Add thinking process for demo
Signed-off-by: Wenhuan Huang <wenhuan.huang@intel.com>
1 parent c00c729 commit 3256626

1 file changed

Lines changed: 33 additions & 27 deletions

File tree

examples/web_demo/web_demo_api.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def boolean_string(string):
1212

1313

1414
parser = argparse.ArgumentParser()
15-
parser.add_argument("-u", "--url", type=str, default="http://local:8000/v1", help="base url")
15+
parser.add_argument("-u", "--url", type=str, default="http://localhost:8000/v1", help="base url")
1616
parser.add_argument("-m", "--model", type=str, default="xft", help="model name")
1717
parser.add_argument("-t", "--token", type=str, default="EMPTY", help="api key")
1818
parser.add_argument("-i", "--ip", type=str, default="0.0.0.0", help="gradio server ip")
@@ -25,7 +25,7 @@ def clean_input():
2525

2626

2727
def reset():
28-
return [], []
28+
return []
2929

3030

3131
class ChatDemo:
@@ -43,7 +43,7 @@ def launch(self, server_name="0.0.0.0", server_port=7860, share=False):
4343
with gr.Blocks() as demo:
4444
self.html_func()
4545

46-
chatbot = gr.Chatbot()
46+
chatbot = gr.Chatbot(type="messages")
4747
with gr.Row():
4848
with gr.Column(scale=2):
4949
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=1, container=False)
@@ -52,53 +52,59 @@ def launch(self, server_name="0.0.0.0", server_port=7860, share=False):
5252
with gr.Column(scale=1):
5353
emptyBtn = gr.Button("Clear History")
5454

55-
history = gr.State([])
5655
submitBtn.click(
5756
self.predict,
58-
[user_input, chatbot, history],
59-
[chatbot, history],
57+
[user_input, chatbot],
58+
[chatbot],
6059
show_progress=True,
6160
)
6261
submitBtn.click(clean_input, [], [user_input])
63-
emptyBtn.click(reset, [], [chatbot, history], show_progress=True)
62+
emptyBtn.click(reset, [], [chatbot], show_progress=True)
6463

6564
demo.queue().launch(server_name=server_name, server_port=server_port, share=share, inbrowser=True)
6665

67-
def post_process_generation(self, chunk, chatbot, query, history):
68-
response = chatbot[-1][1] + chunk
69-
new_history = history + [(query, response)]
70-
chatbot[-1] = (query, response)
71-
return chatbot, new_history
72-
73-
def create_chat_input(self, query, history):
66+
def post_process_generation(self, chunk, chatbot):
67+
if chunk == "<think>":
68+
chunk=""
69+
chatbot[-1]["metadata"]= {"title": "💭思考过程"}
70+
elif chunk == "</think>":
71+
if len(chatbot[-1]["content"].strip()) == 0:
72+
del chatbot[-1]["metadata"]
73+
else:
74+
chatbot.append({"role": "assistant", "content": ""})
75+
chunk=""
76+
chatbot[-1]["content"] += chunk
77+
return chatbot
78+
79+
def create_chat_input(self, chatbot):
7480
msgs = []
75-
if history is None:
76-
history = []
77-
for user_msg, model_msg in history:
78-
msgs.append({"role": "user", "content": user_msg})
79-
msgs.append({"role": "assistant", "content": model_msg})
80-
msgs.append({"role": "user", "content": query})
81+
for msg in chatbot:
82+
if "metadata" not in msg or msg["metadata"] is None:
83+
msgs.append({"role": msg["role"], "content": msg["content"]})
84+
if len(msgs) > 8:
85+
msgs = msgs[-8:]
8186
return msgs
8287

83-
def predict(self, query, chatbot, history):
84-
chatbot.append((query, ""))
88+
def predict(self, query, chatbot):
89+
chatbot.append({"role": "user", "content": query})
90+
chatbot.append({"role": "assistant", "content": ""})
8591

8692
completion = self.client.chat.completions.create(
8793
model=self.model,
88-
messages=self.create_chat_input(query, history),
89-
max_tokens=2048,
94+
messages=self.create_chat_input(chatbot),
95+
max_tokens=8192,
9096
stream=True,
91-
temperature=1.0,
97+
temperature=0.6,
9298
extra_body={"top_k": 20, "top_p": 0.8, "repetition_penalty": 1.1},
9399
)
94100

95101
for chunk in completion:
96102
if chunk.choices[0].delta.content is not None:
97-
yield self.post_process_generation(chunk.choices[0].delta.content, chatbot, query, history)
103+
yield self.post_process_generation(chunk.choices[0].delta.content, chatbot)
98104

99105

100106
if __name__ == "__main__":
101107
args = parser.parse_args()
102108
demo = ChatDemo(args.url, args.model, args.token)
103109

104-
demo.launch(args.ip, args.port, args.share)
110+
demo.launch(args.ip, args.port, args.share)

0 commit comments

Comments
 (0)