Skip to content

Commit 8a85e06

Browse files
committed
seperate llm invoke into a standalone function for tracing
Signed-off-by: Tsai, Louie <louie.tsai@intel.com>
1 parent 7a51d64 commit 8a85e06

4 files changed

Lines changed: 86 additions & 20 deletions

File tree

comps/agent/src/integrations/strategy/planexec/planner.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,15 @@ class grade(BaseModel):
6969
output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True)
7070
self.chain = plan_check_prompt | llm | output_parser
7171

72+
@opea_telemetry
73+
def __llm_invoke__(self, state):
74+
scored_result = self.chain.invoke(state)
75+
return scored_result
76+
7277
@opea_telemetry
7378
def __call__(self, state):
7479
# print("---CALL PlanStepChecker---")
75-
scored_result = self.chain.invoke(state)
80+
scored_result = self.__llm_invoke__(state)
7681
score = scored_result.binary_score
7782
print(f"Task is {state['context']}, Score is {score}")
7883
if score.startswith("yes"):
@@ -93,6 +98,11 @@ def __init__(self, llm, plan_checker=None, is_vllm=False):
9398
self.llm = planner_prompt | llm | output_parser
9499
self.plan_checker = plan_checker
95100

101+
@opea_telemetry
102+
def __llm_invoke__(self, messages):
103+
plan = self.llm.invoke(messages)
104+
return plan
105+
96106
@opea_telemetry
97107
def __call__(self, state):
98108
print("---CALL Planner---")
@@ -102,7 +112,7 @@ def __call__(self, state):
102112
while not success:
103113
while not success:
104114
try:
105-
plan = self.llm.invoke({"messages": [("user", state["messages"][-1].content)]})
115+
plan = self.__llm_invoke__({"messages": [("user", state["messages"][-1].content)]})
106116
print("Generated plan: ", plan)
107117
success = True
108118
except OutputParserException as e:
@@ -168,14 +178,19 @@ def __init__(self, llm, is_vllm=False):
168178
output_parser = PydanticToolsParser(tools=[Response], first_tool_only=True)
169179
self.llm = answer_make_prompt | llm | output_parser
170180

181+
@opea_telemetry
182+
def __llm_invoke__(self, state):
183+
output = self.llm.invoke(state)
184+
return output
185+
171186
@opea_telemetry
172187
def __call__(self, state):
173188
print("---CALL AnswerMaker---")
174189
success = False
175190
# sometime, LLM will not provide accurate steps per ask, try more than one time until success
176191
while not success:
177192
try:
178-
output = self.llm.invoke(state)
193+
output = self.__llm_invoke__(state)
179194
print("Generated response: ", output.response)
180195
success = True
181196
except OutputParserException as e:
@@ -205,10 +220,15 @@ class grade(BaseModel):
205220
output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True)
206221
self.chain = answer_check_prompt | llm | output_parser
207222

223+
@opea_telemetry
224+
def __llm_invoke__(self, state):
225+
output = self.chain.invoke(state)
226+
return output
227+
208228
@opea_telemetry
209229
def __call__(self, state):
210230
print("---CALL FinalAnswerChecker---")
211-
scored_result = self.chain.invoke(state)
231+
scored_result = self.__llm_invoke__(state)
212232
score = scored_result.binary_score
213233
print(f"Answer is {state['response']}, Grade of good response is {score}")
214234
if score.startswith("yes"):
@@ -225,14 +245,19 @@ def __init__(self, llm, answer_checker=None):
225245
self.llm = replanner_prompt | llm | output_parser
226246
self.answer_checker = answer_checker
227247

248+
@opea_telemetry
249+
def __llm_invoke__(self, state):
250+
output = self.llm.invoke(state)
251+
return output
252+
228253
@opea_telemetry
229254
def __call__(self, state):
230255
print("---CALL Replanner---")
231256
success = False
232257
# sometime, LLM will not provide accurate steps per ask, try more than one time until success
233258
while not success:
234259
try:
235-
output = self.llm.invoke(state)
260+
output = self.__llm_invoke__(state)
236261
success = True
237262
print("Replan: ", output)
238263
except OutputParserException as e:

comps/agent/src/integrations/strategy/ragagent/planner.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,17 @@ class QueryWriter:
4343
def __init__(self, llm, tools):
4444
self.llm = llm.bind_tools(tools)
4545

46+
@opea_telemetry
47+
def __llm_invoke__(self, messages):
48+
response = self.llm.invoke(messages)
49+
return response
50+
4651
@opea_telemetry
4752
def __call__(self, state):
4853
print("---CALL QueryWriter---")
4954
messages = state["messages"]
5055

51-
response = self.llm.invoke(messages)
56+
response = self.__llm_invoke__(messages)
5257
# We return a list, because this will get added to the existing list
5358
return {"messages": [response], "output": response}
5459

@@ -195,6 +200,11 @@ def __init__(self, args, tools):
195200
self.tools = tools
196201
self.chain = prompt | llm | output_parser
197202

203+
@opea_telemetry
204+
def __llm_invoke__(self, question, history, feedback):
205+
response = self.chain.invoke({"question": question, "history": history, "feedback": feedback})
206+
return response
207+
198208
@opea_telemetry
199209
def __call__(self, state):
200210
from .utils import assemble_history, convert_json_to_tool_call
@@ -206,7 +216,7 @@ def __call__(self, state):
206216
history = assemble_history(messages)
207217
feedback = instruction
208218

209-
response = self.chain.invoke({"question": question, "history": history, "feedback": feedback})
219+
response = self.__llm_invoke__(question, history, feedback)
210220
print("Response from query writer llm: ", response)
211221

212222
############ allow multiple tool calls in one AI message ############
@@ -244,6 +254,11 @@ def __init__(self, args):
244254
llm = setup_chat_model(args)
245255
self.chain = prompt | llm
246256

257+
@opea_telemetry
258+
def __llm_invoke__(self, question, docs):
259+
scored_result = self.chain.invoke({"question": question, "context": docs})
260+
return scored_result
261+
247262
@opea_telemetry
248263
def __call__(self, state) -> Literal["generate", "rewrite"]:
249264
from .utils import aggregate_docs
@@ -255,7 +270,7 @@ def __call__(self, state) -> Literal["generate", "rewrite"]:
255270
docs = aggregate_docs(messages)
256271
print("@@@@ Docs: ", docs)
257272

258-
scored_result = self.chain.invoke({"question": question, "context": docs})
273+
scored_result = self.__llm_invoke__(question, docs)
259274

260275
score = scored_result.content
261276
print("@@@@ Score: ", score)
@@ -287,6 +302,11 @@ def __init__(self, args):
287302
llm = setup_chat_model(args)
288303
self.rag_chain = prompt | llm
289304

305+
@opea_telemetry
306+
def __llm_invoke__(self, docs, question, query_time):
307+
response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time})
308+
return response
309+
290310
@opea_telemetry
291311
def __call__(self, state):
292312
from .utils import aggregate_docs
@@ -299,7 +319,7 @@ def __call__(self, state):
299319
question = messages[0].content
300320
docs = aggregate_docs(messages)
301321

302-
response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time})
322+
response = self.__llm_invoke__(docs, question, query_time)
303323
print("@@@@ Used this doc for generation:\n", docs)
304324
print("@@@@ Generated response: ", response)
305325
return {"messages": [response], "output": response}

comps/agent/src/integrations/strategy/react/planner.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,14 @@ def __init__(self, tools, args, store=None, **kwargs):
213213
self.memory_type = args.memory_type
214214
self.store = store
215215

216+
@opea_telemetry
217+
def __llm_invoke__(self, query, history, tools_descriptions, thread_history):
218+
# invoke chain: raw output from llm
219+
response = self.chain.invoke(
220+
{"input": query, "history": history, "tools": tools_descriptions, "thread_history": thread_history}
221+
)
222+
return response
223+
216224
@opea_telemetry
217225
def __call__(self, state, config):
218226

@@ -245,9 +253,7 @@ def __call__(self, state, config):
245253
print("@@@ Tools description: ", tools_descriptions)
246254

247255
# invoke chain: raw output from llm
248-
response = self.chain.invoke(
249-
{"input": query, "history": history, "tools": tools_descriptions, "thread_history": thread_history}
250-
)
256+
response = self.__llm_invoke__(query, history, tools_descriptions, thread_history)
251257
response = response.content
252258

253259
# parse tool calls or answers from raw output: result is a list

comps/agent/src/integrations/strategy/sqlagent/planner.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ def __init__(self, args, tools):
6060
self.column_embeddings = self.embed_model.encode(self.values_descriptions)
6161
print("Done embedding column descriptions")
6262

63+
@opea_telemetry
64+
def __llm_invoke__(self, prompt):
65+
output = self.chain.invoke(prompt)
66+
return output
67+
6368
@opea_telemetry
6469
def __call__(self, state):
6570
print("----------Call Agent Node----------")
@@ -88,7 +93,7 @@ def __call__(self, state):
8893
history=history,
8994
)
9095

91-
output = self.chain.invoke(prompt)
96+
output = self.__llm_invoke__(prompt)
9297
output = self.output_parser.parse(
9398
output.content, history, table_schema, hints, question, state["messages"]
9499
) # text: str, history: str, db_schema: str, hint: str
@@ -195,6 +200,11 @@ def __init__(self, args, llm, tools):
195200
self.embed_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
196201
self.column_embeddings = self.embed_model.encode(self.values_descriptions)
197202

203+
@opea_telemetry
204+
def __llm_invoke__(self, chain, state):
205+
response = chain.invoke(state)
206+
return response
207+
198208
@opea_telemetry
199209
def __call__(self, state):
200210
print("----------Call Agent Node----------")
@@ -216,7 +226,7 @@ def __call__(self, state):
216226
)
217227

218228
chain = state_modifier_runnable | self.llm
219-
response = chain.invoke(state)
229+
response = self.__llm_invoke__(chain, state)
220230

221231
return {"messages": [response], "hint": hints}
222232

@@ -248,12 +258,7 @@ def get_sql_query_and_result(self, state):
248258
return query, result
249259

250260
@opea_telemetry
251-
def __call__(self, state):
252-
print("----------Call Query Fixer Node----------")
253-
table_schema, _ = get_table_schema(self.args.db_path)
254-
question = state["messages"][0].content
255-
hint = state["hint"]
256-
query, result = self.get_sql_query_and_result(state)
261+
def __llm_invoke__(self, table_schema, question, hint, query, result):
257262
response = self.chain.invoke(
258263
{
259264
"DATABASE_SCHEMA": table_schema,
@@ -263,6 +268,16 @@ def __call__(self, state):
263268
"RESULT": result,
264269
}
265270
)
271+
return response
272+
273+
@opea_telemetry
274+
def __call__(self, state):
275+
print("----------Call Query Fixer Node----------")
276+
table_schema, _ = get_table_schema(self.args.db_path)
277+
question = state["messages"][0].content
278+
hint = state["hint"]
279+
query, result = self.get_sql_query_and_result(state)
280+
response = self.__llm_invoke__(table_schema, question, hint, query, result)
266281
# print("@@@@@ Query fixer output:\n", response.content)
267282
return {"messages": [response]}
268283

0 commit comments

Comments
 (0)