Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions comps/agent/src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@

db_client = None

logger.info("========initiating agent============")
logger.info("======== args ============")
logger.info(f"args: {args}")
agent_inst = instantiate_agent(args)


class AgentCompletionRequest(ChatCompletionRequest):
Expand Down Expand Up @@ -104,6 +103,8 @@ async def llm_generate(input: AgentCompletionRequest):
if args.with_memory:
config["configurable"] = {"thread_id": input.thread_id}

agent_inst = await instantiate_agent(args)

if logflag:
logger.info(type(agent_inst))

Expand Down Expand Up @@ -184,10 +185,10 @@ class CreateAssistant(CreateAssistantsRequest):
port=args.port,
)
@opea_telemetry
def create_assistants(input: CreateAssistant):
async def create_assistants(input: CreateAssistant):
# 1. initialize the agent
print("@@@ Initializing agent with config: ", input.agent_config)
agent_inst = instantiate_agent(input.agent_config)
agent_inst = await instantiate_agent(input.agent_config)
assistant_id = agent_inst.id
created_at = int(datetime.now().timestamp())
with assistants_global_kv as g_assistants:
Expand Down
87 changes: 54 additions & 33 deletions comps/agent/src/integrations/agent.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,72 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from .storage.persistence_redis import RedisPersistence
from .tools import get_mcp_tools, get_tools_descriptions
from .utils import load_python_prompt

agent = None

def instantiate_agent(args):

async def instantiate_agent(args):
global agent
strategy = args.strategy
with_memory = args.with_memory

if args.custom_prompt is not None:
print(f">>>>>> custom_prompt enabled, {args.custom_prompt}")
custom_prompt = load_python_prompt(args.custom_prompt)
else:
custom_prompt = None
# initialize tools
base_tools = get_tools_descriptions(getattr(args, "tools", None))
mcp_tools = await get_mcp_tools(args.mcp_sse_server_url) if getattr(args, "mcp_sse_server_url", None) else []
all_tools = base_tools + mcp_tools

if agent is None:

if args.custom_prompt is not None:
print(f">>>>>> custom_prompt enabled, {args.custom_prompt}")
custom_prompt = load_python_prompt(args.custom_prompt)
else:
custom_prompt = None

if strategy == "react_langchain":
from .strategy.react import ReActAgentwithLangchain

if strategy == "react_langchain":
from .strategy.react import ReActAgentwithLangchain
agent = ReActAgentwithLangchain(
args, tools_descriptions=all_tools, with_memory=with_memory, custom_prompt=custom_prompt
)
elif strategy == "react_langgraph":
from .strategy.react import ReActAgentwithLanggraph

return ReActAgentwithLangchain(args, with_memory, custom_prompt=custom_prompt)
elif strategy == "react_langgraph":
from .strategy.react import ReActAgentwithLanggraph
agent = ReActAgentwithLanggraph(
args, tools_descriptions=all_tools, with_memory=with_memory, custom_prompt=custom_prompt
)
elif strategy == "react_llama":
print("Initializing ReAct Agent with LLAMA")
from .strategy.react import ReActAgentLlama

return ReActAgentwithLanggraph(args, with_memory, custom_prompt=custom_prompt)
elif strategy == "react_llama":
print("Initializing ReAct Agent with LLAMA")
from .strategy.react import ReActAgentLlama
agent = ReActAgentLlama(args, tools_descriptions=all_tools, custom_prompt=custom_prompt)
elif strategy == "plan_execute":
from .strategy.planexec import PlanExecuteAgentWithLangGraph

return ReActAgentLlama(args, custom_prompt=custom_prompt)
elif strategy == "plan_execute":
from .strategy.planexec import PlanExecuteAgentWithLangGraph
agent = PlanExecuteAgentWithLangGraph(
args, tools_descriptions=all_tools, with_memory=with_memory, custom_prompt=custom_prompt
)

return PlanExecuteAgentWithLangGraph(args, with_memory, custom_prompt=custom_prompt)
elif strategy == "rag_agent" or strategy == "rag_agent_llama":
print("Initializing RAG Agent")
from .strategy.ragagent import RAGAgent

elif strategy == "rag_agent" or strategy == "rag_agent_llama":
print("Initializing RAG Agent")
from .strategy.ragagent import RAGAgent
agent = RAGAgent(args, tools_descriptions=all_tools, with_memory=with_memory, custom_prompt=custom_prompt)
elif strategy == "sql_agent_llama":
print("Initializing SQL Agent Llama")
from .strategy.sqlagent import SQLAgentLlama

return RAGAgent(args, with_memory, custom_prompt=custom_prompt)
elif strategy == "sql_agent_llama":
print("Initializing SQL Agent Llama")
from .strategy.sqlagent import SQLAgentLlama
agent = SQLAgentLlama(
args, tools_descriptions=all_tools, with_memory=with_memory, custom_prompt=custom_prompt
)
elif strategy == "sql_agent":
print("Initializing SQL Agent")
from .strategy.sqlagent import SQLAgent

return SQLAgentLlama(args, with_memory, custom_prompt=custom_prompt)
elif strategy == "sql_agent":
print("Initializing SQL Agent")
from .strategy.sqlagent import SQLAgent
agent = SQLAgent(args, tools_descriptions=all_tools, with_memory=with_memory, custom_prompt=custom_prompt)
else:
raise ValueError(f"Agent strategy: {strategy} not supported!")

return SQLAgent(args, with_memory, custom_prompt=custom_prompt)
else:
raise ValueError(f"Agent strategy: {strategy} not supported!")
return agent
6 changes: 6 additions & 0 deletions comps/agent/src/integrations/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@
if os.environ.get("tools") is not None:
env_config += ["--tools", os.environ["tools"]]

if os.environ.get("mcp_sse_server_url") is not None:
env_config += ["--mcp_sse_server_url", os.environ["mcp_sse_server_url"]]

if os.environ.get("mcp_sse_server_api_key") is not None:
env_config += ["--mcp_sse_server_api_key", os.environ["mcp_sse_server_api_key"]]

if os.environ.get("stream") is not None:
env_config += ["--stream", os.environ["stream"]]

Expand Down
6 changes: 3 additions & 3 deletions comps/agent/src/integrations/strategy/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
from comps.cores.telemetry.opea_telemetry import opea_telemetry, tracer

from ..storage.persistence_redis import RedisPersistence
from ..tools import get_tools_descriptions
from ..utils import adapt_custom_prompt, setup_chat_model


class BaseAgent:
@opea_telemetry
def __init__(self, args, local_vars=None, **kwargs) -> None:
def __init__(self, args, tools_descriptions=None, local_vars=None, **kwargs) -> None:
self.llm = setup_chat_model(args)
self.tools_descriptions = get_tools_descriptions(args.tools)
self.tools_descriptions = tools_descriptions or []
self.app = None
self.id = f"assistant_{self.__class__.__name__}_{uuid4()}"

self.args = args
adapt_custom_prompt(local_vars, kwargs.get("custom_prompt"))
print("Registered tools: ", self.tools_descriptions)
Expand Down
1 change: 0 additions & 1 deletion comps/agent/src/integrations/strategy/planexec/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ This strategy is a practise provided with [LangGraph](https://github.com/langcha
4. CompletionChecker:

Judge on Replanner output

- option plan_executor: Goto "Executor"
- option END: Complete the query with Final answer.

Expand Down
4 changes: 2 additions & 2 deletions comps/agent/src/integrations/strategy/planexec/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ def __call__(self, state):

class PlanExecuteAgentWithLangGraph(BaseAgent):
@opea_telemetry
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
def __init__(self, args, tools_descriptions=None, with_memory=False, **kwargs):
super().__init__(args, tools_descriptions, local_vars=globals(), **kwargs)

# Define Node
plan_checker = PlanStepChecker(self.llm, is_vllm=self.is_vllm)
Expand Down
3 changes: 0 additions & 3 deletions comps/agent/src/integrations/strategy/ragagent/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@ This agent strategy includes steps listed below:

1. QueryWriter
This is an llm with tool calling capability, it decides if tool calls are needed to answer the user query or it can answer with llm's parametric knowledge.

- Yes: Rephrase the query in the form of a tool call to the Retriever tool, and send the rephrased query to 'Retriever'. The rephrasing is important as user queries may be not be clear and simply using user query may not retrieve relevant documents.
- No: Complete the query with Final answer

2. Retriever:

- Get related documents from a retrieval tool, then send the documents to 'DocumentGrader'. Note: The retrieval tool here is broad-sense, which can be a text retriever over a proprietary knowledge base, a websearch API, knowledge graph API, SQL database API etc.

3. DocumentGrader
Judge retrieved info relevance with respect to the user query

- Yes: Go to TextGenerator
- No: Go back to QueryWriter to rewrite query.

Expand Down
4 changes: 2 additions & 2 deletions comps/agent/src/integrations/strategy/ragagent/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def create(cls, tools_descriptions):

class RAGAgent(BaseAgent):
@opea_telemetry
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
def __init__(self, args, tools_descriptions=None, with_memory=False, **kwargs):
super().__init__(args, tools_descriptions, local_vars=globals(), **kwargs)

# Define Nodes
if args.strategy == "rag_agent":
Expand Down
12 changes: 6 additions & 6 deletions comps/agent/src/integrations/strategy/react/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

class ReActAgentwithLangchain(BaseAgent):
@opea_telemetry
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
def __init__(self, args, tools_descriptions=None, with_memory=False, **kwargs):
super().__init__(args, tools_descriptions, local_vars=globals(), **kwargs)
from .prompt import hwchase17_react_prompt

prompt = hwchase17_react_prompt
Expand Down Expand Up @@ -90,8 +90,8 @@ async def stream_generator(self, query, config, thread_id=None):

class ReActAgentwithLanggraph(BaseAgent):
@opea_telemetry
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
def __init__(self, args, tools_descriptions=None, with_memory=False, **kwargs):
super().__init__(args, tools_descriptions, local_vars=globals(), **kwargs)
if kwargs.get("custom_prompt") is not None:
print("***Custom prompt is provided.")
REACT_SYS_MESSAGE = kwargs.get("custom_prompt").REACT_SYS_MESSAGE
Expand Down Expand Up @@ -279,8 +279,8 @@ def __call__(self, state, config):

class ReActAgentLlama(BaseAgent):
@opea_telemetry
def __init__(self, args, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
def __init__(self, args, tools_descriptions=None, **kwargs):
super().__init__(args, tools_descriptions, local_vars=globals(), **kwargs)

agent = ReActAgentNodeLlama(tools=self.tools_descriptions, args=args, store=self.store, **kwargs)
tool_node = ToolNode(self.tools_descriptions)
Expand Down
8 changes: 4 additions & 4 deletions comps/agent/src/integrations/strategy/sqlagent/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ class SQLAgentLlama(BaseAgent):
# # db_name and db_path
# # use_hints, hints_file
@opea_telemetry
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
def __init__(self, args, tools_descriptions=None, with_memory=False, **kwargs):
super().__init__(args, tools_descriptions, local_vars=globals(), **kwargs)
# note: here tools only include user defined tools
# we need to add the sql query tool as well
print("@@@@ user defined tools: ", self.tools_descriptions)
Expand Down Expand Up @@ -284,8 +284,8 @@ def __call__(self, state):

class SQLAgent(BaseAgent):
@opea_telemetry
def __init__(self, args, with_memory=False, **kwargs):
super().__init__(args, local_vars=globals(), **kwargs)
def __init__(self, args, tools_descriptions=None, with_memory=False, **kwargs):
super().__init__(args, tools_descriptions, local_vars=globals(), **kwargs)

sql_tool = get_sql_query_tool(args.db_path)
tools = self.tools_descriptions + [sql_tool]
Expand Down
18 changes: 18 additions & 0 deletions comps/agent/src/integrations/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,21 @@ def get_tools_descriptions(file_dir_path: str):
else:
pass
return tools


async def get_mcp_tools(mcp_sse_server_url):
    """Fetch the tool list exposed by an MCP server over SSE.

    Args:
        mcp_sse_server_url: URL of the MCP server's SSE endpoint.

    Returns:
        A list of LangChain-compatible tools advertised by the server
        (empty if the server exposes none).
    """
    # Imported lazily so the agent can start without the MCP adapter
    # installed when no MCP server URL is configured.
    from langchain_mcp_adapters.client import MultiServerMCPClient

    # NOTE(review): the server is registered under the fixed name "math";
    # presumably the key is arbitrary for a single-server setup — confirm
    # and consider a descriptive name (e.g. "opea_mcp").
    client = MultiServerMCPClient(
        {
            "math": {
                "url": mcp_sse_server_url,
                "transport": "sse",
            }
        }
    )

    # Network call: query the server for its advertised tools.
    # (The original pre-initialized `mcp_tools = []` and immediately
    # overwrote it; the dead assignment is removed here.)
    return await client.get_tools()
2 changes: 2 additions & 0 deletions comps/agent/src/integrations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def get_args():
parser.add_argument("--strategy", type=str, default="react_langchain")
parser.add_argument("--role_description", type=str, default="LLM enhanced agent")
parser.add_argument("--tools", type=str, default=None, help="path to the tools file")
parser.add_argument("--mcp_sse_server_url", type=str, default=None, help="OPEA MCP SSE server URL")
parser.add_argument("--mcp_sse_server_api_key", type=str, default=None, help="OPEA MCP SSE server API key")
parser.add_argument("--recursion_limit", type=int, default=5)
parser.add_argument("--require_human_feedback", action="store_true", help="If this agent requires human feedback")
parser.add_argument("--debug", action="store_true", help="Test with endpoint mode")
Expand Down
2 changes: 2 additions & 0 deletions comps/agent/src/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ langchain
#used by tools
langchain-google-community
langchain-huggingface
langchain-mcp-adapters
langchain-openai
langchain-redis
langchain_community
langchainhub
langgraph
langsmith
mcp
mysql-connector-python
numpy

Expand Down
Loading
Loading