Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions comps/agent/src/integrations/strategy/react/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from ...storage.persistence_redis import RedisPersistence
from ...utils import filter_tools, has_multi_tool_inputs, tool_renderer
from ..base_agent import BaseAgent
from .prompt import REACT_SYS_MESSAGE, hwchase17_react_prompt


class ReActAgentwithLangchain(BaseAgent):
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
from .prompt import hwchase17_react_prompt

prompt = hwchase17_react_prompt
if has_multi_tool_inputs(self.tools_descriptions):
raise ValueError("Only supports single input tools when using strategy == react_langchain")
Expand Down Expand Up @@ -86,7 +87,12 @@ async def stream_generator(self, query, config, thread_id=None):
class ReActAgentwithLanggraph(BaseAgent):
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)

if kwargs.get("custom_prompt") is not None:
print("***Custom prompt is provided.")
REACT_SYS_MESSAGE = kwargs.get("custom_prompt").REACT_SYS_MESSAGE
else:
print("*** Using default prompt.")
from .prompt import REACT_SYS_MESSAGE
tools = self.tools_descriptions
print("REACT_SYS_MESSAGE: ", REACT_SYS_MESSAGE)

Expand Down Expand Up @@ -174,10 +180,18 @@ class ReActAgentNodeLlama:
A workaround for open-source llm served by TGI-gaudi.
"""

def __init__(self, tools, args, store=None):
from .prompt import REACT_AGENT_LLAMA_PROMPT
def __init__(self, tools, args, store=None, **kwargs):
from .utils import ReActLlamaOutputParser

if kwargs.get("custom_prompt") is not None:
print("***Custom prompt is provided.")
REACT_AGENT_LLAMA_PROMPT = kwargs.get("custom_prompt").REACT_AGENT_LLAMA_PROMPT
else:
print("*** Using default prompt.")
from .prompt import REACT_AGENT_LLAMA_PROMPT

print("***Prompt template:\n", REACT_AGENT_LLAMA_PROMPT)

output_parser = ReActLlamaOutputParser()
prompt = PromptTemplate(
template=REACT_AGENT_LLAMA_PROMPT,
Expand Down Expand Up @@ -244,6 +258,8 @@ def __call__(self, state, config):
ai_message = AIMessage(content=response, tool_calls=tool_calls)
elif "answer" in output[0]:
ai_message = AIMessage(content=str(output[0]["answer"]))
else:
ai_message = AIMessage(content=response)
else:
ai_message = AIMessage(content=response)

Expand All @@ -254,7 +270,7 @@ class ReActAgentLlama(BaseAgent):
def __init__(self, args, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)

agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args, store=self.store)
agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args, store=self.store, **kwargs)
tool_node = ToolNode(self.tools_descriptions)

workflow = StateGraph(AgentState)
Expand Down
50 changes: 10 additions & 40 deletions comps/agent/src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,13 @@
def test_agent_local(args):
from integrations.agent import instantiate_agent

if args.q == 0:
df = pd.DataFrame({"query": ["What is the Intel OPEA Project?"]})
elif args.q == 1:
df = pd.DataFrame({"query": ["what is the trade volume for Microsoft today?"]})
elif args.q == 2:
df = pd.DataFrame({"query": ["what is the hometown of Year 2023 Australia open winner?"]})

agent = instantiate_agent(args, strategy=args.strategy)
app = agent.app
agent = instantiate_agent(args)

config = {"recursion_limit": args.recursion_limit}

traces = []
success = 0
for _, row in df.iterrows():
print("Query: ", row["query"])
initial_state = {"messages": [{"role": "user", "content": row["query"]}]}
try:
trace = {"query": row["query"], "trace": []}
for event in app.stream(initial_state, config=config):
trace["trace"].append(event)
for k, v in event.items():
print("{}: {}".format(k, v))

traces.append(trace)
success += 1
except Exception as e:
print(str(e), str(traceback.format_exc()))
traces.append({"query": row["query"], "trace": str(e)})

print("-" * 50)
query = "What is OPEA project?"

df["trace"] = traces
df.to_csv(os.path.join(args.filedir, args.output), index=False)
print(f"succeed: {success}/{len(df)}")
# run_agent(agent, config, query)


def test_agent_http(args):
Expand Down Expand Up @@ -158,15 +130,12 @@ def test_ut(args):
def run_agent(agent, config, input_message):
initial_state = agent.prepare_initial_state(input_message)

try:
for s in agent.app.stream(initial_state, config=config, stream_mode="values"):
message = s["messages"][-1]
message.pretty_print()
for s in agent.app.stream(initial_state, config=config, stream_mode="values"):
message = s["messages"][-1]
message.pretty_print()

last_message = s["messages"][-1]
print("******Response: ", last_message.content)
except Exception as e:
print(str(e))
last_message = s["messages"][-1]
print("******Response: ", last_message.content)


def stream_generator(agent, config, input_message):
Expand Down Expand Up @@ -309,4 +278,5 @@ def test_memory(args):
# else:
# print("Please specify the test type")

test_memory(args)
# test_memory(args)
test_agent_local(args)
44 changes: 41 additions & 3 deletions comps/agent/src/tools/custom_prompt.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,49 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# System message for the LangGraph ReAct agent. Plain text (no format
# placeholders); consumed as-is by the agent's prompt construction.
REACT_SYS_MESSAGE = """CUSTOM PROMPT
Decompose the user request into a series of simple tasks when necessary and solve the problem step by step.
When you cannot get the answer at first, do not give up. Reflect on the info you have from the tools and try to solve the problem in a different way.
Please follow these guidelines when formulating your answer:
1. If the question contains a false premise or assumption, answer “invalid question”.
2. If you are uncertain or do not know the answer, respond with “I don't know”.
3. Give concise, factual and relevant answers.
"""

# Prompt template for the Llama-served ReAct agent. Rendered with
# langchain's PromptTemplate, so single braces are placeholders
# ({tools}, {thread_history}, {history}, {input}) and doubled braces
# ({{...}}) are literal JSON braces in the emitted prompt.
REACT_AGENT_LLAMA_PROMPT = """FINANCIAL ANALYST ASSISTANT
You are a helpful assistant engaged in multi-turn conversations with Financial analysts.
You have access to the following two tools:
{tools}

**Procedure:**
1. Read the question carefully. Divide the question into sub-questions and conquer sub-questions one by one.
2. If there is execution history, read it carefully and reason about the information gathered so far and decide if you can answer the question or if you need to call more tools.

**Output format:**
You should output your thought process. Finish thinking first. Output tool calls or your answer at the end.
When making tool calls, you should use the following format:
TOOL CALL: {{"tool": "tool1", "args": {{"arg1": "value1", "arg2": "value2", ...}}}}

If you can answer the question, provide the answer in the following format:
FINAL ANSWER: {{"answer": "your answer here"}}


======= Conversations with user in previous turns =======
{thread_history}
======= End of previous conversations =======

======= Your execution History in this turn =========
{history}
======= End of execution history ==========

**Tips:**
* You may need to do multi-hop calculations and call tools multiple times to get an answer.
* Do not assume any financial figures. Always rely on the tools to get the factual information.
* If you need a certain financial figure, search for the figure instead of the financial statement name.
* If you did not get the answer at first, do not give up. Reflect on the steps that you have taken and try a different way. Think out of the box. Your hard work will be rewarded.
* Give concise, factual and relevant answers.
* If the user question is too ambiguous, ask for clarification.

Now take a deep breath and think step by step to answer user's question in this turn.
USER MESSAGE: {input}
"""