Add RagAgentDocGrader to agent comp #480
Merged
Commits (28)
- `c99489d` set hf_hub to 0.24.0 (minmin-intel)
- `d6c347e` add docgrader to agent strategy openai llm code passed (minmin-intel)
- `6b9b9fa` add nonstreaming output for agent (minmin-intel)
- `781c8e8` add react langgraph and tests (minmin-intel)
- `d7469f1` fix react langchain bug (minmin-intel)
- `87a8080` Merge branch 'main' into agent-comp-dev (minmin-intel)
- `0500725` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- `d213892` fix test script (minmin-intel)
- `69129f1` fix bug in test script (minmin-intel)
- `9539d20` update readme and rm old agentic-rag strategy (minmin-intel)
- `829d53c` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- `40672ff` Merge branch 'main' into agent-comp-dev (minmin-intel)
- `55853dd` Merge branch 'main' into agent-comp-dev (minmin-intel)
- `48cc642` update test and docgrader readme (minmin-intel)
- `bc5e2ac` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- `5467f5d` fix bug in test (minmin-intel)
- `8840344` Merge branch 'agent-comp-dev' of https://github.com/minmin-intel/GenA… (minmin-intel)
- `bad9d14` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- `8d8001a` update test (minmin-intel)
- `95bafa4` Merge branch 'agent-comp-dev' of https://github.com/minmin-intel/GenA… (minmin-intel)
- `a0f2c61` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- `0c379a5` Merge branch 'main' into agent-comp-dev (minmin-intel)
- `1eaf009` update rag agent strategy name and update readme (minmin-intel)
- `a8d6e58` [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- `7844227` update test (minmin-intel)
- `5c32f58` Merge branch 'agent-comp-dev' of https://github.com/minmin-intel/GenA… (minmin-intel)
- `60ee995` Merge branch 'main' into agent-comp-dev (minmin-intel)
- [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
# Agentic Rag

This strategy is a practice provided with [LangGraph](https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag).
It includes the steps listed below:

1. RagAgent: decides whether the query needs extra help.
   - Yes: go to 'Retriever'
   - No: complete the query with the final answer
2. Retriever: gets relevant info from tools, then goes to 'DocumentGrader'
3. DocumentGrader: judges whether the retrieved info is relevant to the query.
   - Yes: complete the query with the final answer
   - No: go to 'Rewriter'
4. Rewriter: rewrites the query and goes back to 'RagAgent'
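The four-step loop above can be sketched as a tiny, dependency-free driver function. This is a hypothetical illustration of the control flow only — the actual strategy is implemented as a LangGraph state graph in `planner.py`; the callables passed in here stand in for the LLM-backed nodes.

```python
MAX_RETRY = 3  # matches the retry cap used by the strategy


def run_agentic_rag(query, retrieve, grade, rewrite, generate):
    """Drive the RagAgent -> Retriever -> DocumentGrader -> Rewriter loop.

    retrieve(query) -> docs, grade(query, docs) -> 'yes'/'no',
    rewrite(query) -> new query, generate(query, docs) -> final answer.
    """
    for _attempt in range(MAX_RETRY + 1):
        docs = retrieve(query)            # step 2: Retriever
        if grade(query, docs) == "yes":   # step 3: DocumentGrader
            return generate(query, docs)  # relevant -> final answer
        query = rewrite(query)            # step 4: Rewriter, back to step 1
    # give up rewriting after MAX_RETRY attempts and answer with what we have
    return generate(query, docs)
```

With stub functions in place of the LLM calls, one failed grade triggers exactly one rewrite before the answer is generated.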
```python
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .planner import RAGAgentDocGraderV1
```
comps/agent/langchain/src/strategy/docgrader/planner.py (246 additions, 0 deletions)
```python
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Annotated, Literal, Sequence, TypedDict

from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from ..base_agent import BaseAgent
from .prompt import DOC_GRADER_PROMPT, RAGv1_PROMPT

instruction = "Retrieved document is not sufficient or relevant to answer the query. Reformulate the query to search knowledge base again."
MAX_RETRY = 3


class AgentStateV1(TypedDict):
    # The add_messages function defines how an update should be processed.
    # The default is to replace; add_messages says "append".
    messages: Annotated[Sequence[BaseMessage], add_messages]
    output: str
    doc_score: str
    query_time: str


class RagAgent:
    """Invokes the agent model to generate a response based on the current state.

    Given the question, it will decide to retrieve using the retriever tool, or simply end.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the agent response appended to messages
    """

    def __init__(self, llm_endpoint, model_id, tools):
        if isinstance(llm_endpoint, HuggingFaceEndpoint):
            self.llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id).bind_tools(tools)
        elif isinstance(llm_endpoint, ChatOpenAI):
            self.llm = llm_endpoint.bind_tools(tools)

    def __call__(self, state):
        print("---CALL RagAgent---")
        messages = state["messages"]

        response = self.llm.invoke(messages)
        # We return a list, because this will get added to the existing list
        return {"messages": [response], "output": response}


class Retriever:
    @classmethod
    def create(cls, tools_descriptions):
        return ToolNode(tools_descriptions)


class DocumentGraderV1:
    """Determines whether the retrieved documents are relevant to the question.

    Args:
        state (messages): The current state

    Returns:
        str: A decision for whether the documents are relevant or not
    """

    def __init__(self, llm_endpoint, model_id=None):
        class grade(BaseModel):
            """Binary score for relevance check."""

            binary_score: str = Field(description="Relevance score 'yes' or 'no'")

        # Prompt
        prompt = PromptTemplate(
            template=DOC_GRADER_PROMPT,
            input_variables=["context", "question"],
        )

        if isinstance(llm_endpoint, HuggingFaceEndpoint):
            llm = ChatHuggingFace(llm=llm_endpoint, model_id=model_id).bind_tools([grade])
        elif isinstance(llm_endpoint, ChatOpenAI):
            llm = llm_endpoint.bind_tools([grade])
        output_parser = PydanticToolsParser(tools=[grade], first_tool_only=True)
        self.chain = prompt | llm | output_parser

    def __call__(self, state) -> Literal["generate", "rewrite"]:
        print("---CALL DocumentGrader---")
        messages = state["messages"]
        last_message = messages[-1]  # the latest retrieved doc

        question = messages[0].content  # the original query
        docs = last_message.content

        scored_result = self.chain.invoke({"question": question, "context": docs})

        score = scored_result.binary_score

        if score.startswith("yes"):
            print("---DECISION: DOCS RELEVANT---")
            return {"doc_score": "generate"}
        else:
            print(f"---DECISION: DOCS NOT RELEVANT, score is {score}---")
            return {"messages": [HumanMessage(content=instruction)], "doc_score": "rewrite"}


class TextGeneratorV1:
    """Generates the final answer.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the generated answer
    """

    def __init__(self, llm_endpoint, model_id=None):
        # Chain
        prompt = RAGv1_PROMPT
        self.rag_chain = prompt | llm_endpoint | StrOutputParser()

    def __call__(self, state):
        print("---GENERATE---")
        messages = state["messages"]
        question = messages[0].content
        query_time = state["query_time"]

        # find the latest retrieved doc, which is a ToolMessage
        for m in state["messages"][::-1]:
            if isinstance(m, ToolMessage):
                last_message = m
                break

        docs = last_message.content

        # Run
        response = self.rag_chain.invoke({"context": docs, "question": question, "time": query_time})
        print("@@@@ Used this doc for generation:\n", docs)
        print("@@@@ Generated response: ", response)
        return {"messages": [response], "output": response}


class RAGAgentDocGraderV1(BaseAgent):
    def __init__(self, args):
        super().__init__(args)

        # Define nodes
        document_grader = DocumentGraderV1(self.llm_endpoint, args.model)
        rag_agent = RagAgent(self.llm_endpoint, args.model, self.tools_descriptions)
        text_generator = TextGeneratorV1(self.llm_endpoint)
        retriever = Retriever.create(self.tools_descriptions)

        # Define graph
        workflow = StateGraph(AgentStateV1)

        # Define the nodes we will cycle between
        workflow.add_node("agent", rag_agent)
        workflow.add_node("retrieve", retriever)
        workflow.add_node("doc_grader", document_grader)
        workflow.add_node("generate", text_generator)

        # connect as graph
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges(
            "agent",
            tools_condition,
            {
                "tools": "retrieve",  # if tools_condition returns 'tools', go to 'retrieve'
                END: END,  # if tools_condition returns END, finish
            },
        )

        workflow.add_edge("retrieve", "doc_grader")

        workflow.add_conditional_edges(
            "doc_grader",
            self.should_retry,
            {
                False: "generate",
                True: "agent",
            },
        )
        workflow.add_edge("generate", END)

        self.app = workflow.compile()

    def should_retry(self, state):
        # first check how many retry attempts have been made
        num_retry = 0
        for m in state["messages"]:
            if instruction in m.content:
                num_retry += 1

        print("**********Num retry: ", num_retry)

        return (num_retry < MAX_RETRY) and (state["doc_score"] == "rewrite")

    def prepare_initial_state(self, query):
        return {"messages": [HumanMessage(content=query)]}

    async def stream_generator(self, query, config):
        initial_state = self.prepare_initial_state(query)
        try:
            async for event in self.app.astream(initial_state, config=config):
                for node_name, node_state in event.items():
                    yield f"--- CALL {node_name} ---\n"
                    for k, v in node_state.items():
                        if v is not None:
                            yield f"{k}: {v}\n"

                yield f"data: {repr(event)}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            yield str(e)

    async def non_streaming_run(self, query, config):
        initial_state = self.prepare_initial_state(query)
        try:
            async for s in self.app.astream(initial_state, config=config, stream_mode="values"):
                message = s["messages"][-1]
                if isinstance(message, tuple):
                    print(message)
                else:
                    message.pretty_print()

            last_message = s["messages"][-1]
            print("******Response: ", last_message.content)
            return last_message.content
        except Exception as e:
            return str(e)
```
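The retry bookkeeping in `should_retry` relies on the graph appending the fixed `instruction` message each time grading fails, so counting its occurrences recovers the number of rewrites so far. A dependency-free sketch of that check (hypothetical — message objects are reduced to plain strings here):

```python
MAX_RETRY = 3
INSTRUCTION = (
    "Retrieved document is not sufficient or relevant to answer the query. "
    "Reformulate the query to search knowledge base again."
)


def should_retry(messages, doc_score):
    """Retry only while under the cap AND the grader asked for a rewrite."""
    num_retry = sum(1 for m in messages if INSTRUCTION in m)
    return num_retry < MAX_RETRY and doc_score == "rewrite"
```

Because the count is derived from the message history rather than a separate counter field, the state schema stays small and the check works even after the graph state is serialized and restored.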
```python
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from langchain_core.prompts import ChatPromptTemplate

DOC_GRADER_PROMPT = """\
Given the QUERY, determine if a relevant answer can be derived from the DOCUMENT.\n
QUERY: {question} \n
DOCUMENT:\n{context}\n\n
Give score 'yes' if the document provides sufficient and relevant information to answer the question. Otherwise, give score 'no'. ONLY answer with 'yes' or 'no'. NOTHING ELSE."""


PROMPT = """\
### You are a helpful, respectful and honest assistant.
You are given a Question and the time when it was asked in the Pacific Time Zone (PT), referred to as "Query
Time". The query time is formatted as "mm/dd/yyyy, hh:mm:ss PT".
Please follow these guidelines when formulating your answer:
1. If the question contains a false premise or assumption, answer "invalid question".
2. If you are uncertain or do not know the answer, respond with "I don't know".
3. Refer to the search results to form your answer.
4. Give concise, factual and relevant answers.

### Search results: {context} \n
### Question: {question} \n
### Query Time: {time} \n
### Answer:
"""

RAGv1_PROMPT = ChatPromptTemplate.from_messages(
    [
        (
            "human",
            PROMPT,
        ),
    ]
)
```
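When the generation chain invokes `RAGv1_PROMPT`, the three placeholders are filled from the graph state. A trimmed copy of the template's tail, filled with plain `str.format`, mirrors that substitution (the context, question, and time values below are assumed example inputs, not from the PR):

```python
# Trimmed copy of the tail of PROMPT above, for illustration only.
template_tail = (
    "### Search results: {context}\n"
    "### Question: {question}\n"
    "### Query Time: {time}\n"
    "### Answer:"
)

# Fill the same placeholders the chain fills on invoke().
msg = template_tail.format(
    context="OPEA is an open platform project for enterprise AI.",  # assumed example
    question="What is OPEA?",
    time="08/01/2024, 10:00:00 PT",
)
print(msg)
```

The resulting string ends with the bare `### Answer:` cue, which is what prompts the model to continue with the answer rather than restate the question.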