Commit 8692892

reidliu41 authored and mawong-amd committed
[doc] Add RAG Integration example (vllm-project#17692)
Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
1 parent 909425e commit 8692892

File tree

4 files changed: +551 -0 lines changed


docs/source/deployment/frameworks/index.md

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ helm
 lws
 modal
 open-webui
+retrieval_augmented_generation
 skypilot
 streamlit
 triton

docs/source/deployment/frameworks/retrieval_augmented_generation.md

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
(deployment-retrieval-augmented-generation)=

# Retrieval-Augmented Generation

[Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. It modifies interactions with a large language model (LLM) so that the model responds to user queries with reference to a specified set of documents, using this information to supplement its pre-existing training data. This allows LLMs to use domain-specific and/or updated information. Use cases include giving chatbots access to internal company data or generating responses based on authoritative sources.

Here are the integrations:

- vLLM + [langchain](https://github.com/langchain-ai/langchain) + [milvus](https://github.com/milvus-io/milvus)
- vLLM + [llamaindex](https://github.com/run-llama/llama_index) + [milvus](https://github.com/milvus-io/milvus)
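
Both integrations implement the same basic flow: embed a set of documents with the vLLM embedding server, retrieve the chunks most similar to the question, and pass them as context to the vLLM chat server. A minimal, framework-free sketch of that flow, using only the `openai` client and `numpy` (model names and ports match the Deploy steps below):

```python
# Framework-free sketch of the RAG flow shared by both integrations below.
# Assumes the two vLLM servers from the Deploy sections are already running;
# model names and ports are the defaults used throughout this guide.
import numpy as np
from openai import OpenAI

embed_client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat_client = OpenAI(base_url="http://localhost:8001/v1", api_key="EMPTY")

documents = [
    "vLLM can be installed from PyPI with: pip install vllm",
    "vLLM exposes an OpenAI-compatible HTTP server via the vllm serve command.",
]
question = "How do I install vLLM?"

def embed(texts: list[str]) -> np.ndarray:
    resp = embed_client.embeddings.create(
        model="ssmits/Qwen2-7B-Instruct-embed-base", input=texts)
    return np.array([item.embedding for item in resp.data])

# 1) Retrieve: embed the corpus and the question, keep the closest chunk
doc_vecs, q_vec = embed(documents), embed([question])[0]
scores = (doc_vecs @ q_vec) / (
    np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(q_vec))
context = documents[int(np.argmax(scores))]

# 2) Generate: answer with the retrieved context included in the prompt
reply = chat_client.chat.completions.create(
    model="qwen/Qwen1.5-0.5B-Chat",
    messages=[{
        "role": "user",
        "content": f"Context: {context}\n\nQuestion: {question}\nAnswer:",
    }],
)
print(reply.choices[0].message.content)
```
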
## vLLM + langchain

### Prerequisites

- Set up the vLLM and langchain environment:

```console
pip install -U vllm \
            langchain_milvus langchain_openai \
            langchain_community beautifulsoup4 \
            langchain-text-splitters
```

### Deploy

- Start the vLLM server with a supported embedding model, e.g.:

```console
# Start embedding service (port 8000)
vllm serve ssmits/Qwen2-7B-Instruct-embed-base
```

- Start the vLLM server with a supported chat completion model, e.g.:

```console
# Start chat service (port 8001)
vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
```

- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_langchain.py>

- Run the script:

```console
python retrieval_augmented_generation_with_langchain.py
```
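
The script's behavior can be adjusted with command-line flags (defined in the argparse section of the script shown later in this commit); for example, to index a different page, retrieve five chunks per question, and ask questions interactively:

```console
python retrieval_augmented_generation_with_langchain.py \
    --url https://docs.vllm.ai/en/latest/ \
    --top-k 5 --interactive
```
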
## vLLM + llamaindex

### Prerequisites

- Set up the vLLM and llamaindex environment:

```console
pip install vllm \
            llama-index llama-index-readers-web \
            llama-index-llms-openai-like \
            llama-index-embeddings-openai-like \
            llama-index-vector-stores-milvus
```

### Deploy

- Start the vLLM server with a supported embedding model, e.g.:

```console
# Start embedding service (port 8000)
vllm serve ssmits/Qwen2-7B-Instruct-embed-base
```

- Start the vLLM server with a supported chat completion model, e.g.:

```console
# Start chat service (port 8001)
vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
```

- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_llamaindex.py>

- Run the script:

```console
python retrieval_augmented_generation_with_llamaindex.py
```
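
The llamaindex example script is referenced above but not included in this excerpt. As a rough sketch of how the same two vLLM endpoints plug into llama-index (class and argument names follow the `llama-index-llms-openai-like` and `llama-index-embeddings-openai-like` packages installed above and should be treated as assumptions to verify against your installed versions; Milvus is omitted here in favor of the default in-memory vector store for brevity):

```python
# Rough sketch, not the committed example script; parameter names are
# assumptions based on the llama-index OpenAI-like integrations.
from llama_index.core import Settings, VectorStoreIndex
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
from llama_index.llms.openai_like import OpenAILike
from llama_index.readers.web import SimpleWebPageReader

# Point llama-index at the two vLLM OpenAI-compatible servers
Settings.embed_model = OpenAILikeEmbedding(
    model_name="ssmits/Qwen2-7B-Instruct-embed-base",
    api_base="http://localhost:8000/v1",
    api_key="EMPTY",
)
Settings.llm = OpenAILike(
    model="qwen/Qwen1.5-0.5B-Chat",
    api_base="http://localhost:8001/v1",
    api_key="EMPTY",
    is_chat_model=True,
)

# Load a page, build an in-memory vector index, and ask a question
docs = SimpleWebPageReader(html_to_text=True).load_data(
    ["https://docs.vllm.ai/en/latest/getting_started/quickstart.html"])
index = VectorStoreIndex.from_documents(docs)
print(index.as_query_engine().query("How to install vLLM?"))
```
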
examples/online_serving/retrieval_augmented_generation_with_langchain.py

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
"""
Retrieval Augmented Generation (RAG) Implementation with Langchain
==================================================================

This script demonstrates a RAG implementation using LangChain, Milvus
and vLLM. RAG enhances LLM responses by retrieving relevant context
from a document collection.

Features:
- Web content loading and chunking
- Vector storage with Milvus
- Embedding generation with vLLM
- Question answering with context

Prerequisites:
1. Install dependencies:
    pip install -U vllm \
                langchain_milvus langchain_openai \
                langchain_community beautifulsoup4 \
                langchain-text-splitters

2. Start services:
    # Start embedding service (port 8000)
    vllm serve ssmits/Qwen2-7B-Instruct-embed-base

    # Start chat service (port 8001)
    vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001

Usage:
    python retrieval_augmented_generation_with_langchain.py

Notes:
    - Ensure both vLLM services are running before executing
    - Default ports: 8000 (embedding), 8001 (chat)
    - First run may take time to download models
"""

import argparse
from argparse import Namespace
from typing import Any

from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_milvus import Milvus
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter


def load_and_split_documents(config: dict[str, Any]):
    """
    Load and split documents from web URL
    """
    try:
        loader = WebBaseLoader(web_paths=(config["url"], ))
        docs = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=config["chunk_size"],
            chunk_overlap=config["chunk_overlap"],
        )
        return text_splitter.split_documents(docs)
    except Exception as e:
        print(f"Error loading document from {config['url']}: {str(e)}")
        raise


def init_vectorstore(config: dict[str, Any], documents: list[Document]):
    """
    Initialize vector store with documents
    """
    return Milvus.from_documents(
        documents=documents,
        embedding=OpenAIEmbeddings(
            model=config["embedding_model"],
            openai_api_key=config["vllm_api_key"],
            openai_api_base=config["vllm_embedding_endpoint"],
        ),
        connection_args={"uri": config["uri"]},
        drop_old=True,
    )


def init_llm(config: dict[str, Any]):
    """
    Initialize llm
    """
    return ChatOpenAI(
        model=config["chat_model"],
        openai_api_key=config["vllm_api_key"],
        openai_api_base=config["vllm_chat_endpoint"],
    )


def get_qa_prompt():
    """
    Get question answering prompt template
    """
    template = """You are an assistant for question-answering tasks.
    Use the following pieces of retrieved context to answer the question.
    If you don't know the answer, just say that you don't know.
    Use three sentences maximum and keep the answer concise.
    Question: {question}
    Context: {context}
    Answer:
    """
    return PromptTemplate.from_template(template)


def format_docs(docs: list[Document]):
    """
    Format documents for prompt
    """
    return "\n\n".join(doc.page_content for doc in docs)


def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate):
    """
    Set up question answering chain
    """
    return ({
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
            | prompt
            | llm
            | StrOutputParser())


def get_parser() -> argparse.ArgumentParser:
    """
    Build the command line argument parser
    """
    parser = argparse.ArgumentParser(description='RAG with vLLM and langchain')

    # Add command line arguments
    parser.add_argument('--vllm-api-key',
                        default="EMPTY",
                        help='API key for vLLM compatible services')
    parser.add_argument('--vllm-embedding-endpoint',
                        default="http://localhost:8000/v1",
                        help='Base URL for embedding service')
    parser.add_argument('--vllm-chat-endpoint',
                        default="http://localhost:8001/v1",
                        help='Base URL for chat service')
    parser.add_argument('--uri',
                        default="./milvus.db",
                        help='URI for Milvus database')
    parser.add_argument(
        '--url',
        default=("https://docs.vllm.ai/en/latest/getting_started/"
                 "quickstart.html"),
        help='URL of the document to process')
    parser.add_argument('--embedding-model',
                        default="ssmits/Qwen2-7B-Instruct-embed-base",
                        help='Model name for embeddings')
    parser.add_argument('--chat-model',
                        default="qwen/Qwen1.5-0.5B-Chat",
                        help='Model name for chat')
    parser.add_argument('-i',
                        '--interactive',
                        action='store_true',
                        help='Enable interactive Q&A mode')
    parser.add_argument('-k',
                        '--top-k',
                        type=int,
                        default=3,
                        help='Number of top results to retrieve')
    parser.add_argument('-c',
                        '--chunk-size',
                        type=int,
                        default=1000,
                        help='Chunk size for document splitting')
    parser.add_argument('-o',
                        '--chunk-overlap',
                        type=int,
                        default=200,
                        help='Chunk overlap for document splitting')

    return parser


def init_config(args: Namespace):
    """
    Initialize configuration settings from command line arguments
    """

    return {
        "vllm_api_key": args.vllm_api_key,
        "vllm_embedding_endpoint": args.vllm_embedding_endpoint,
        "vllm_chat_endpoint": args.vllm_chat_endpoint,
        "uri": args.uri,
        "embedding_model": args.embedding_model,
        "chat_model": args.chat_model,
        "url": args.url,
        "chunk_size": args.chunk_size,
        "chunk_overlap": args.chunk_overlap,
        "top_k": args.top_k
    }


def main():
    # Parse command line arguments
    args = get_parser().parse_args()

    # Initialize configuration
    config = init_config(args)

    # Load and split documents
    documents = load_and_split_documents(config)

    # Initialize vector store and retriever
    vectorstore = init_vectorstore(config, documents)
    retriever = vectorstore.as_retriever(search_kwargs={"k": config["top_k"]})

    # Initialize llm and prompt
    llm = init_llm(config)
    prompt = get_qa_prompt()

    # Set up QA chain
    qa_chain = create_qa_chain(retriever, llm, prompt)

    # Interactive mode
    if args.interactive:
        print("\nWelcome to the Interactive Q&A System!")
        print("Enter 'q' or 'quit' to exit.")

        while True:
            question = input("\nPlease enter your question: ")
            if question.lower() in ['q', 'quit']:
                print("\nThank you for using the Q&A system. Goodbye!")
                break

            output = qa_chain.invoke(question)
            print(output)
    else:
        # Default single question mode
        question = "How to install vLLM?"
        output = qa_chain.invoke(question)
        print("-" * 50)
        print(output)
        print("-" * 50)


if __name__ == "__main__":
    main()
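
Because `create_qa_chain` composes a LangChain Runnable, the single-question branch in `main()` could stream the answer token by token instead of printing it in one piece. A small sketch of that variation (not part of the committed script):

```python
# Sketch: replace `output = qa_chain.invoke(question)` in main() with a
# streaming loop; a chain ending in StrOutputParser yields string chunks.
for chunk in qa_chain.stream(question):
    print(chunk, end="", flush=True)
print()
```
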

0 commit comments
