Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
80db2af
set hf_hub to 0.24.0
minmin-intel Aug 12, 2024
c99489d
add docgrader to agent strategy openai llm code passed
minmin-intel Aug 12, 2024
d6c347e
add nonstreaming output for agent
minmin-intel Aug 12, 2024
6b9b9fa
add react langgraph and tests
minmin-intel Aug 13, 2024
781c8e8
fix react langchain bug
minmin-intel Aug 13, 2024
d7469f1
Merge branch 'main' into agent-comp-dev
minmin-intel Aug 13, 2024
87a8080
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2024
0500725
fix test script
minmin-intel Aug 13, 2024
d213892
fix bug in test script
minmin-intel Aug 13, 2024
69129f1
update readme and rm old agentic-rag strategy
minmin-intel Aug 14, 2024
9539d20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 14, 2024
829d53c
Merge branch 'main' into agent-comp-dev
minmin-intel Aug 14, 2024
40672ff
Merge branch 'main' into agent-comp-dev
minmin-intel Aug 15, 2024
55853dd
update test and docgrader readme
minmin-intel Aug 15, 2024
48cc642
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2024
bc5e2ac
fix bug in test
minmin-intel Aug 15, 2024
5467f5d
Merge branch 'agent-comp-dev' of https://github.com/minmin-intel/GenA…
minmin-intel Aug 15, 2024
8840344
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2024
bad9d14
update test
minmin-intel Aug 15, 2024
8d8001a
Merge branch 'agent-comp-dev' of https://github.com/minmin-intel/GenA…
minmin-intel Aug 15, 2024
95bafa4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2024
a0f2c61
Merge branch 'main' into agent-comp-dev
minmin-intel Aug 16, 2024
0c379a5
update rag agent strategy name and update readme
minmin-intel Aug 16, 2024
1eaf009
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2024
a8d6e58
update test
minmin-intel Aug 16, 2024
7844227
Merge branch 'agent-comp-dev' of https://github.com/minmin-intel/GenA…
minmin-intel Aug 16, 2024
5c32f58
Merge branch 'main' into agent-comp-dev
minmin-intel Aug 19, 2024
60ee995
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions comps/agent/langchain/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
comps_path = os.path.join(cur_path, "../../../")
sys.path.append(comps_path)

from comps import LLMParamsDoc, ServiceType, opea_microservices, register_microservice
from comps import GeneratedDoc, LLMParamsDoc, ServiceType, opea_microservices, register_microservice
from comps.agent.langchain.src.agent import instantiate_agent
from comps.agent.langchain.src.utils import get_args

Expand All @@ -27,20 +27,26 @@
port=args.port,
input_datatype=LLMParamsDoc,
)
def llm_generate(input: LLMParamsDoc):
async def llm_generate(input: LLMParamsDoc):
# 1. initialize the agent
print("args: ", args)
input.streaming = args.streaming
config = {"recursion_limit": args.recursion_limit}
agent_inst = instantiate_agent(args, args.strategy)
print(type(agent_inst))

# 2. prepare the input for the agent
if input.streaming:
print("-----------STREAMING-------------")
return StreamingResponse(agent_inst.stream_generator(input.query, config), media_type="text/event-stream")

else:
# TODO: add support for non-streaming mode
return StreamingResponse(agent_inst.stream_generator(input.query, config), media_type="text/event-stream")
print("-----------NOT STREAMING-------------")
response = await agent_inst.non_streaming_run(input.query, config)
print("-----------Response-------------")
print(response)
return GeneratedDoc(text=response, prompt=input.query)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion comps/agent/langchain/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ docarray[full]
#used by tools
duckduckgo-search
fastapi
huggingface_hub
huggingface_hub==0.24.0
langchain #==0.1.12
langchain-huggingface
langchain-openai
Expand Down
16 changes: 11 additions & 5 deletions comps/agent/langchain/src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
# SPDX-License-Identifier: Apache-2.0


def instantiate_agent(args, strategy="react"):
if strategy == "react":
def instantiate_agent(args, strategy="react_langchain"):
if strategy == "react_langchain":
from .strategy.react import ReActAgentwithLangchain

return ReActAgentwithLangchain(args)
elif strategy == "react_langgraph":
from .strategy.react import ReActAgentwithLanggraph

return ReActAgentwithLanggraph(args)
elif strategy == "plan_execute":
from .strategy.planexec import PlanExecuteAgentWithLangGraph

Expand All @@ -15,7 +19,9 @@ def instantiate_agent(args, strategy="react"):
from .strategy.agentic_rag import RAGAgentwithLanggraph

return RAGAgentwithLanggraph(args)
else:
from .strategy.base_agent import BaseAgent, BaseAgentState
elif strategy == "docgrader":
from .strategy.docgrader import RAGAgentDocGraderV1

return BaseAgent(args)
return RAGAgentDocGraderV1(args)
else:
raise ValueError(f"Agent strategy: {strategy} not supported!")
3 changes: 3 additions & 0 deletions comps/agent/langchain/src/strategy/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ def compile(self):

def execute(self, state: dict):
    """Execute the agent on the given state.

    Placeholder in the base class — does nothing; concrete agent
    strategies may override.
    """
    pass

def non_streaming_run(self, query, config):
    """Run the agent to completion for `query` and return the final answer.

    Not implemented in the base class; concrete strategies
    (e.g. RAGAgentDocGraderV1) must override this.
    """
    raise NotImplementedError
25 changes: 25 additions & 0 deletions comps/agent/langchain/src/strategy/docgrader/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Agentic Rag

This strategy is a practice example provided with [LangGraph](https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag).
This agent strategy includes the steps listed below:

1. RagAgent
   Decide whether this query needs extra help

- Yes: Go to 'Retriever'
- No: Complete the query with the final answer

2. Retriever:

- Get relevant info from tools, then go to 'DocumentGrader'

3. DocumentGrader
   Judge the relevance of the retrieved info based on the query

- Yes: Complete the query with the final answer
- No: Go to 'Rewriter'

4. Rewriter
   Rewrite the query and go to 'RagAgent'

![Agentic Rag Workflow](https://blog.langchain.dev/content/images/size/w1000/2024/02/image-16.png)
4 changes: 4 additions & 0 deletions comps/agent/langchain/src/strategy/docgrader/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .planner import RAGAgentDocGraderV1
246 changes: 246 additions & 0 deletions comps/agent/langchain/src/strategy/docgrader/planner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Annotated, Any, Literal, Sequence, TypedDict

from langchain.output_parsers import PydanticOutputParser
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from ..base_agent import BaseAgent
from .prompt import DOC_GRADER_PROMPT, RAGv1_PROMPT

instruction = "Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again."
MAX_RETRY = 3


class AgentStateV1(TypedDict):
    """Shared LangGraph state passed between the docgrader workflow nodes."""

    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Latest node output (agent response or generated answer text).
    output: str
    # Routing decision from DocumentGraderV1: "generate" or "rewrite".
    doc_score: str
    # Query timestamp interpolated into the RAG prompt.
    # NOTE(review): prepare_initial_state does not populate this key — confirm
    # that callers supply it before the generate node runs.
    query_time: str


class RagAgent:
    """Invokes the agent model to generate a response based on the current state.

    Given the question, the model decides whether to call the retriever tool
    or simply end with an answer.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the agent response appended to messages

    Raises:
        ValueError: if `llm_endpoint` is neither a HuggingFaceEndpoint nor a
            ChatOpenAI instance.
    """

    def __init__(self, llm_endpoint, model_id, tools):
        # Bind the available tools to the chat model so the LLM can emit tool calls.
        if isinstance(llm_endpoint, HuggingFaceEndpoint):
            self.llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id).bind_tools(tools)
        elif isinstance(llm_endpoint, ChatOpenAI):
            self.llm = llm_endpoint.bind_tools(tools)
        else:
            # Previously an unsupported endpoint type left self.llm unset and
            # failed later with AttributeError on first invocation; fail fast
            # with a clear message instead.
            raise ValueError(f"Unsupported llm endpoint type: {type(llm_endpoint)}")

    def __call__(self, state):
        print("---CALL RagAgent---")
        messages = state["messages"]

        response = self.llm.invoke(messages)
        # We return a list, because this will get added to the existing list
        return {"messages": [response], "output": response}


class Retriever:
    """Factory for the graph's tool-execution node."""

    @classmethod
    def create(cls, tools_descriptions):
        """Build a ToolNode that runs the given tools when the agent emits tool calls."""
        node = ToolNode(tools_descriptions)
        return node


class DocumentGraderV1:
    """Determines whether the retrieved documents are relevant to the question.

    Acts as a LangGraph node: reads the latest retrieved doc and the original
    query from state, asks the LLM for a binary relevance verdict, and returns
    a state update with the routing decision in ``doc_score`` ("generate" when
    relevant, "rewrite" plus a rewrite-instruction message otherwise).

    Args:
        state (messages): The current state

    Returns:
        dict: State update carrying the routing decision

    Raises:
        ValueError: if `llm_endpoint` is neither a HuggingFaceEndpoint nor a
            ChatOpenAI instance.
    """

    def __init__(self, llm_endpoint, model_id=None):
        class grade(BaseModel):
            """Binary score for relevance check."""

            binary_score: str = Field(description="Relevance score 'yes' or 'no'")

        # Prompt
        prompt = PromptTemplate(
            template=DOC_GRADER_PROMPT,
            input_variables=["context", "question"],
        )

        if isinstance(llm_endpoint, HuggingFaceEndpoint):
            llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id).bind_tools([grade])
        elif isinstance(llm_endpoint, ChatOpenAI):
            llm = llm_endpoint.bind_tools([grade])
        else:
            # Previously an unsupported endpoint left `llm` unbound and crashed
            # with UnboundLocalError on the next line; fail fast instead.
            raise ValueError(f"Unsupported llm endpoint type: {type(llm_endpoint)}")
        output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True)
        self.chain = prompt | llm | output_parser

    def __call__(self, state) -> dict:
        # NOTE: the return annotation was Literal["generate", "rewrite"], but the
        # node actually returns a state-update dict; annotation corrected.
        print("---CALL DocumentGrader---")
        messages = state["messages"]
        last_message = messages[-1]  # the latest retrieved doc

        question = messages[0].content  # the original query
        docs = last_message.content

        scored_result = self.chain.invoke({"question": question, "context": docs})

        # With first_tool_only=True the parser yields None when the model makes
        # no tool call; treat that as "not relevant" instead of crashing on
        # `.binary_score` below.
        if scored_result is None:
            print("---DECISION: NO GRADE RETURNED, TREAT AS NOT RELEVANT---")
            return {"messages": [HumanMessage(content=instruction)], "doc_score": "rewrite"}

        score = scored_result.binary_score

        # Case-insensitive: some models answer "Yes"/"No" despite the prompt.
        if score.lower().startswith("yes"):
            print("---DECISION: DOCS RELEVANT---")
            return {"doc_score": "generate"}

        else:
            print(f"---DECISION: DOCS NOT RELEVANT, score is {score}---")

            return {"messages": [HumanMessage(content=instruction)], "doc_score": "rewrite"}


class TextGeneratorV1:
    """Generate the final answer from the latest retrieved document.

    Args:
        state (messages): The current state

    Returns:
        dict: State update with the generated answer in messages/output

    Raises:
        ValueError: if no ToolMessage (retrieved doc) is present in state.
    """

    def __init__(self, llm_endpoint, model_id=None):
        # model_id is unused here; kept for signature parity with the other nodes.
        prompt = RAGv1_PROMPT
        self.rag_chain = prompt | llm_endpoint | StrOutputParser()

    def __call__(self, state):
        print("---GENERATE---")
        messages = state["messages"]
        question = messages[0].content  # the original user query
        # prepare_initial_state does not set query_time, so a plain indexed
        # access could KeyError; default to an empty string.
        query_time = state.get("query_time", "")

        # Find the latest retrieved doc, which is a ToolMessage.
        last_message = None
        for m in reversed(messages):
            if isinstance(m, ToolMessage):
                last_message = m
                break
        if last_message is None:
            # Previously this path raised an opaque NameError on `last_message`;
            # raise a clear error instead.
            raise ValueError("TextGeneratorV1 requires a ToolMessage with retrieved docs in state")

        docs = last_message.content

        # Run the RAG chain: prompt -> llm -> string output.
        response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time})
        print("@@@@ Used this doc for generation:\n", docs)
        print("@@@@ Generated response: ", response)
        return {"messages": [response], "output": response}


class RAGAgentDocGraderV1(BaseAgent):
    """RAG agent with a document-grading retry loop, built on LangGraph.

    Graph topology: START -> agent -> (tool call?) -> retrieve -> doc_grader
    -> generate -> END, with doc_grader looping back to agent (query rewrite)
    until the docs are accepted or MAX_RETRY rewrite attempts have been made.
    """

    def __init__(self, args):
        super().__init__(args)

        # Define Nodes
        document_grader = DocumentGraderV1(self.llm_endpoint, args.model)
        rag_agent = RagAgent(self.llm_endpoint, args.model, self.tools_descriptions)
        text_generator = TextGeneratorV1(self.llm_endpoint)
        retriever = Retriever.create(self.tools_descriptions)

        # Define graph
        workflow = StateGraph(AgentStateV1)

        # Define the nodes we will cycle between
        workflow.add_node("agent", rag_agent)
        workflow.add_node("retrieve", retriever)
        workflow.add_node("doc_grader", document_grader)
        workflow.add_node("generate", text_generator)

        # connect as graph
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges(
            "agent",
            tools_condition,
            {
                "tools": "retrieve",  # if tools_condition return 'tools', then go to 'retrieve'
                END: END,  # if tools_condition return 'END', then go to END
            },
        )

        workflow.add_edge("retrieve", "doc_grader")

        # Loop back to the agent when the grader asks for a rewrite
        # (bounded by MAX_RETRY via should_retry).
        workflow.add_conditional_edges(
            "doc_grader",
            self.should_retry,
            {
                False: "generate",
                True: "agent",
            },
        )
        workflow.add_edge("generate", END)

        self.app = workflow.compile()

    def should_retry(self, state):
        """Return True to loop back to the agent for another retrieval attempt.

        Retries only when the grader voted "rewrite" AND fewer than MAX_RETRY
        rewrite-instruction messages have been injected into the history.
        """
        # first check how many retry attempts have been made
        num_retry = 0
        for m in state["messages"]:
            # Each failed grading appends a HumanMessage containing `instruction`,
            # so counting those occurrences counts past retries.
            if instruction in m.content:
                num_retry += 1

        print("**********Num retry: ", num_retry)

        if (num_retry < MAX_RETRY) and (state["doc_score"] == "rewrite"):
            return True
        else:
            return False

    def prepare_initial_state(self, query):
        """Build the initial graph state: a single HumanMessage carrying the query."""
        return {"messages": [HumanMessage(content=query)]}

    async def stream_generator(self, query, config):
        """Async generator yielding per-node progress lines plus SSE-style events.

        Exceptions are yielded as plain text rather than raised, so the HTTP
        stream terminates gracefully instead of breaking the response.
        """
        initial_state = self.prepare_initial_state(query)
        try:
            async for event in self.app.astream(initial_state, config=config):
                for node_name, node_state in event.items():
                    yield f"--- CALL {node_name} ---\n"
                    for k, v in node_state.items():
                        if v is not None:
                            yield f"{k}: {v}\n"

                yield f"data: {repr(event)}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            yield str(e)

    async def non_streaming_run(self, query, config):
        """Run the graph to completion and return the final message content.

        NOTE(review): exceptions are converted to str and returned as if they
        were answers — callers cannot distinguish an error from normal output.
        """
        initial_state = self.prepare_initial_state(query)
        try:
            # stream_mode="values" yields the full state after each step;
            # `s` retains the final state once the loop completes.
            async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                message = s["messages"][-1]
                if isinstance(message, tuple):
                    print(message)
                else:
                    message.pretty_print()

            last_message = s["messages"][-1]
            print("******Response: ", last_message.content)
            return last_message.content
        except Exception as e:
            return str(e)
36 changes: 36 additions & 0 deletions comps/agent/langchain/src/strategy/docgrader/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from langchain_core.prompts import ChatPromptTemplate, PromptTemplate

# Prompt for DocumentGraderV1: asks the LLM for a bare yes/no verdict on
# whether {context} is sufficient and relevant to answer {question}.
DOC_GRADER_PROMPT = """\
Given the QUERY, determine if a relevant answer can be derived from the DOCUMENT.\n
QUERY: {question} \n
DOCUMENT:\n{context}\n\n
Give score 'yes' if the document provides sufficient and relevant information to answer the question. Otherwise, give score 'no'. ONLY answer with 'yes' or 'no'. NOTHING ELSE."""


# Answer-generation template used by TextGeneratorV1. Expects {context}
# (retrieved docs), {question}, and {time} (query time in Pacific Time,
# formatted "mm/dd/yyyy, hh:mm:ss PT").
PROMPT = """\
### You are a helpful, respectful and honest assistant.
You are given a Question and the time when it was asked in the Pacific Time Zone (PT), referred to as "Query
Time". The query time is formatted as "mm/dd/yyyy, hh:mm:ss PT".
Please follow these guidelines when formulating your answer:
1. If the question contains a false premise or assumption, answer “invalid question”.
2. If you are uncertain or do not know the answer, respond with “I don’t know”.
3. Refer to the search results to form your answer.
4. Give concise, factual and relevant answers.

### Search results: {context} \n
### Question: {question} \n
### Query Time: {time} \n
### Answer:
"""

# Chat-format wrapper around PROMPT: a single human turn, as required by the
# chat-model pipeline in TextGeneratorV1 (prompt | llm | StrOutputParser).
RAGv1_PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "human",
            PROMPT,
        ),
    ]
)
Loading