Skip to content

Commit b586449

Browse files
authored
1 parent e9b43c7 commit b586449

File tree

2 files changed

+179
-9
lines changed

2 files changed

+179
-9
lines changed

code/rag/run_rag.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,56 @@
1010
rag_pipeline = RagPipeline(chunk_size=500, chunk_overlap=50, use_tools=False)
1111

1212

13+
def get_endpoint_config():
    """Return the current endpoint settings as a flat list for the UI fields.

    Order matches the textbox components wired in the layout:
    LLM model/key/url, then embeddings model/key/url. Missing keys
    fall back to the empty string so the textboxes render blank.
    """
    cfg = rag_pipeline.get_endpoint_config()
    field_keys = (
        "llm_model_name",
        "llm_api_key",
        "llm_url",
        "embeddings_model_name",
        "embeddings_api_key",
        "embeddings_url",
    )
    return [cfg.get(key, "") for key in field_keys]
23+
24+
25+
def set_endpoint_config(
    llm_model_name,
    llm_api_key,
    llm_url,
    embeddings_model_name,
    embeddings_api_key,
    embeddings_url,
):
    """Apply new endpoint settings from the UI and refresh dependent state.

    Pushes the six values into the pipeline; on failure, surfaces the
    exception in the error markdown component instead of crashing the UI.
    Always re-reads the (possibly unchanged) config back and clears the
    document list, since a successful update rebuilds the vector store.

    Returns a 9-tuple matching the `save_btn.click` outputs:
    the six config field values, the cleared doc list, the file table,
    and a `gr.update` for the error component.
    """
    import html

    config = {
        "llm_model_name": llm_model_name,
        "llm_api_key": llm_api_key,
        "llm_url": llm_url,
        "embeddings_model_name": embeddings_model_name,
        "embeddings_api_key": embeddings_api_key,
        "embeddings_url": embeddings_url,
    }
    try:
        rag_pipeline.set_endpoint_config(config)
        error_msg = ""
        error_visible = False
    except Exception as e:
        # Escape the exception text: it may contain user-supplied input
        # (URLs, model names) and is rendered as HTML by the Markdown widget.
        error_msg = (
            "<span style='color:red; font-weight:bold;'>"
            f"Error: {html.escape(str(e))}</span>"
        )
        error_visible = True
    config_values = get_endpoint_config()
    doc_list, file_table = clear_document_list()
    return (
        *config_values,
        doc_list,
        file_table,
        gr.update(value=error_msg, visible=error_visible),
    )
56+
57+
58+
def toggle_api_key_visibility(visible, value):
    """Re-render an API-key textbox as plain text or masked, keeping its value.

    `visible=True` shows the key in clear text; otherwise it stays a
    password field. (Renamed the local from `type` to avoid shadowing
    the builtin.)
    """
    if visible:
        field_type = "text"
    else:
        field_type = "password"
    return gr.Textbox(label="API Key", type=field_type, value=value)
61+
62+
1363
def clear_history():
1464
new_id = uuid4()
1565
print(f"New thread_id: {new_id}")
@@ -95,6 +145,84 @@ def add_document(new_docs, doc_list):
95145
with gr.Column(scale=1):
96146
clear_doc_button = gr.ClearButton(value="Clear all documents")
97147

148+
# --- Endpoint configuration panel (collapsed by default) ---
with gr.Accordion("Endpoint Configuration", open=False):
    llm_model_name = gr.Textbox(label="LLM Model Name")
    llm_url = gr.Textbox(label="LLM URL")
    with gr.Row():
        # Masked by default; the checkbox beside it toggles visibility.
        llm_api_key = gr.Textbox(
            placeholder="LLM API Key",
            type="password",
            scale=4,
            show_label=False,
        )
        llm_api_key_visible = gr.Checkbox(
            label="Show LLM API Key",
            value=False,
            scale=1,
        )
    embeddings_model_name = gr.Textbox(label="Embeddings Model Name")
    embeddings_url = gr.Textbox(label="Embeddings URL")
    with gr.Row():
        embeddings_api_key = gr.Textbox(
            label="Embeddings API Key",
            type="password",
            scale=4,
            show_label=False,
        )
        embeddings_api_key_visible = gr.Checkbox(
            label="Show Embeddings API Key",
            value=False,
            scale=1,
        )
    save_btn = gr.Button("Save")
    # Hidden until a save attempt fails; then shows the error markup.
    config_error = gr.Markdown(value="", visible=False)

# Prefill the six config textboxes from the pipeline on page load.
demo.load(
    get_endpoint_config,
    inputs=None,
    outputs=[
        llm_model_name,
        llm_api_key,
        llm_url,
        embeddings_model_name,
        embeddings_api_key,
        embeddings_url,
    ],
)

# Save: push the six fields into the pipeline, then refresh the fields,
# clear the document list (vector store was rebuilt), and update the
# error component. Output order must match set_endpoint_config's return.
save_btn.click(
    set_endpoint_config,
    inputs=[
        llm_model_name,
        llm_api_key,
        llm_url,
        embeddings_model_name,
        embeddings_api_key,
        embeddings_url,
    ],
    outputs=[
        llm_model_name,
        llm_api_key,
        llm_url,
        embeddings_model_name,
        embeddings_api_key,
        embeddings_url,
        doc_list,
        file_table.dataset,
        config_error,
    ],
)
# Checkbox handlers re-render each key textbox as text/password in place.
llm_api_key_visible.change(
    toggle_api_key_visibility,
    inputs=[llm_api_key_visible, llm_api_key],
    outputs=llm_api_key,
)
embeddings_api_key_visible.change(
    toggle_api_key_visibility,
    inputs=[embeddings_api_key_visible, embeddings_api_key],
    outputs=embeddings_api_key,
)
98226
chatbot.clear(clear_history, outputs=[uuid_state, chatbot])
99227

100228
chat_msg = chat_input.submit(

code/rag/src/rag_pipeline.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ def __init__(
5959
self.vector_store: InMemoryVectorStore
6060
self.prompt: PromptTemplate
6161
self.graph: StateGraph
62+
self.llm_model_name: str
63+
self.llm_api_key: str
64+
self.llm_url: str
65+
self.embeddings_model_name: str
66+
self.embeddings_api_key: str
67+
self.embeddings_url: str
68+
69+
self._set_endpoint_config()
6270
self._set_models()
6371
self._set_vector_store()
6472
self._set_graph()
@@ -94,23 +102,57 @@ def _check_env(self) -> None:
94102
"Please set the EMBEDDINGS_MODEL_NAME environment variable."
95103
)
96104

97-
def _set_models(self) -> None:
105+
def get_endpoint_config(self) -> dict:
    """Return a snapshot of the six endpoint settings as a plain dict."""
    keys = (
        "llm_model_name",
        "llm_api_key",
        "llm_url",
        "embeddings_model_name",
        "embeddings_api_key",
        "embeddings_url",
    )
    return {key: getattr(self, key) for key in keys}
114+
115+
def set_endpoint_config(self, config: dict) -> None:
    """Update endpoint settings, then rebuild the models and vector store.

    Args:
        config: mapping of setting name -> value; only the six known
            endpoint keys are accepted.

    Raises:
        ValueError: if any key is not a recognized setting. All keys are
            validated BEFORE any attribute is set, so an invalid config
            leaves the instance unmodified (the original set attributes
            while validating, leaving partial state on failure).
    """
    allowed = {
        "llm_model_name",
        "llm_api_key",
        "llm_url",
        "embeddings_model_name",
        "embeddings_api_key",
        "embeddings_url",
    }
    for key in config:
        if key not in allowed:
            raise ValueError(f"Invalid config key: {key}")
    for key, value in config.items():
        setattr(self, key, value)
    # Re-create the LLM/embeddings clients and the vector store so they
    # point at the new endpoints.
    self._set_models()
    self._set_vector_store()
129+
130+
def _set_endpoint_config(
131+
self,
132+
) -> None:
98133
self._check_env()
134+
self.llm_model_name = os.getenv("LLM_MODEL_NAME")
135+
self.llm_api_key = os.getenv("LLM_API_KEY")
136+
self.llm_url = os.getenv("LLM_URL")
137+
self.embeddings_model_name = os.getenv("EMBEDDINGS_MODEL_NAME")
138+
self.embeddings_api_key = os.getenv("EMBEDDINGS_API_KEY")
139+
self.embeddings_url = os.getenv("EMBEDDINGS_URL")
99140

100-
config = AutoConfig.from_pretrained(os.getenv("EMBEDDINGS_MODEL_NAME"))
141+
def _set_models(self) -> None:
142+
config = AutoConfig.from_pretrained(self.embeddings_model_name)
101143
assert self.chunk_size <= config.max_position_embeddings
102144

103145
llm = ChatOpenAI(
104-
model_name=os.getenv("LLM_MODEL_NAME"),
105-
openai_api_key=os.getenv("LLM_API_KEY"),
106-
openai_api_base=os.getenv("LLM_URL") + "/v1",
146+
model_name=self.llm_model_name,
147+
openai_api_key=self.llm_api_key,
148+
openai_api_base=self.llm_url + "/v1",
107149
)
108150

109151
embeddings = OpenAIEmbeddings(
110-
model=os.getenv("EMBEDDINGS_MODEL_NAME"),
111-
deployment=os.getenv("EMBEDDINGS_MODEL_NAME"),
112-
openai_api_key=os.getenv("EMBEDDINGS_API_KEY"),
113-
openai_api_base=os.getenv("EMBEDDINGS_URL") + "/v1",
152+
model=self.embeddings_model_name,
153+
deployment=self.embeddings_model_name,
154+
openai_api_key=self.embeddings_api_key,
155+
openai_api_base=self.embeddings_url + "/v1",
114156
tiktoken_enabled=False,
115157
)
116158
self.llm = llm

0 commit comments

Comments
 (0)