@@ -12,7 +12,7 @@ def boolean_string(string):
1212
1313
1414parser = argparse .ArgumentParser ()
15- parser .add_argument ("-u" , "--url" , type = str , default = "http://local :8000/v1" , help = "base url" )
15+ parser .add_argument ("-u" , "--url" , type = str , default = "http://localhost :8000/v1" , help = "base url" )
1616parser .add_argument ("-m" , "--model" , type = str , default = "xft" , help = "model name" )
1717parser .add_argument ("-t" , "--token" , type = str , default = "EMPTY" , help = "api key" )
1818parser .add_argument ("-i" , "--ip" , type = str , default = "0.0.0.0" , help = "gradio server ip" )
@@ -25,7 +25,7 @@ def clean_input():
2525
2626
2727def reset ():
28- return [], []
28+ return []
2929
3030
3131class ChatDemo :
@@ -43,7 +43,7 @@ def launch(self, server_name="0.0.0.0", server_port=7860, share=False):
4343 with gr .Blocks () as demo :
4444 self .html_func ()
4545
46- chatbot = gr .Chatbot ()
46+ chatbot = gr .Chatbot (type = "messages" )
4747 with gr .Row ():
4848 with gr .Column (scale = 2 ):
4949 user_input = gr .Textbox (show_label = False , placeholder = "Input..." , lines = 1 , container = False )
@@ -52,53 +52,59 @@ def launch(self, server_name="0.0.0.0", server_port=7860, share=False):
5252 with gr .Column (scale = 1 ):
5353 emptyBtn = gr .Button ("Clear History" )
5454
55- history = gr .State ([])
5655 submitBtn .click (
5756 self .predict ,
58- [user_input , chatbot , history ],
59- [chatbot , history ],
57+ [user_input , chatbot ],
58+ [chatbot ],
6059 show_progress = True ,
6160 )
6261 submitBtn .click (clean_input , [], [user_input ])
63- emptyBtn .click (reset , [], [chatbot , history ], show_progress = True )
62+ emptyBtn .click (reset , [], [chatbot ], show_progress = True )
6463
6564 demo .queue ().launch (server_name = server_name , server_port = server_port , share = share , inbrowser = True )
6665
67- def post_process_generation (self , chunk , chatbot , query , history ):
68- response = chatbot [- 1 ][1 ] + chunk
69- new_history = history + [(query , response )]
70- chatbot [- 1 ] = (query , response )
71- return chatbot , new_history
72-
73- def create_chat_input (self , query , history ):
66+ def post_process_generation (self , chunk , chatbot ):
67+ if chunk == "<think>" :
68+ chunk = ""
69+ chatbot [- 1 ]["metadata" ]= {"title" : "💭思考过程" }
70+ elif chunk == "</think>" :
71+ if len (chatbot [- 1 ]["content" ].strip ()) == 0 :
72+ del chatbot [- 1 ]["metadata" ]
73+ else :
74+ chatbot .append ({"role" : "assistant" , "content" : "" })
75+ chunk = ""
76+ chatbot [- 1 ]["content" ] += chunk
77+ return chatbot
78+
79+ def create_chat_input (self , chatbot ):
7480 msgs = []
75- if history is None :
76- history = []
77- for user_msg , model_msg in history :
78- msgs .append ({"role" : "user" , "content" : user_msg })
79- msgs .append ({"role" : "assistant" , "content" : model_msg })
80- msgs .append ({"role" : "user" , "content" : query })
81+ for msg in chatbot :
82+ if "metadata" not in msg or msg ["metadata" ] is None :
83+ msgs .append ({"role" : msg ["role" ], "content" : msg ["content" ]})
84+ if len (msgs ) > 8 :
85+ msgs = msgs [- 8 :]
8186 return msgs
8287
83- def predict (self , query , chatbot , history ):
84- chatbot .append ((query , "" ))
88+ def predict (self , query , chatbot ):
89+ chatbot .append ({"role" : "user" , "content" : query })
90+ chatbot .append ({"role" : "assistant" , "content" : "" })
8591
8692 completion = self .client .chat .completions .create (
8793 model = self .model ,
88- messages = self .create_chat_input (query , history ),
89- max_tokens = 2048 ,
94+ messages = self .create_chat_input (chatbot ),
95+ max_tokens = 8192 ,
9096 stream = True ,
91- temperature = 1.0 ,
97+ temperature = 0.6 ,
9298 extra_body = {"top_k" : 20 , "top_p" : 0.8 , "repetition_penalty" : 1.1 },
9399 )
94100
95101 for chunk in completion :
96102 if chunk .choices [0 ].delta .content is not None :
97- yield self .post_process_generation (chunk .choices [0 ].delta .content , chatbot , query , history )
103+ yield self .post_process_generation (chunk .choices [0 ].delta .content , chatbot )
98104
99105
100106if __name__ == "__main__" :
101107 args = parser .parse_args ()
102108 demo = ChatDemo (args .url , args .model , args .token )
103109
104- demo .launch (args .ip , args .port , args .share )
110+ demo .launch (args .ip , args .port , args .share )
0 commit comments