Skip to content
This repository was archived by the owner on Mar 29, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 10 additions & 22 deletions src/vanna/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@ def __init__(self, config=None):
if config and "embedding_function" in config:
self.embedding_function = config.get("embedding_function")
else:
from sentence_transformers import SentenceTransformer
self.embedding_function = SentenceTransformer("sentence-transformers/all-MiniLM-l6-v2")
from langchain_huggingface import HuggingFaceEmbeddings
self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

self.sql_vectorstore = PGVector(
self.sql_collection = PGVector(
embeddings=self.embedding_function,
collection_name="sql",
connection=self.connection_string,
)
self.ddl_vectorstore = PGVector(
self.ddl_collection = PGVector(
embeddings=self.embedding_function,
collection_name="ddl",
connection=self.connection_string,
)
self.documentation_vectorstore = PGVector(
self.documentation_collection = PGVector(
embeddings=self.embedding_function,
collection_name="documentation",
connection=self.connection_string,
Expand Down Expand Up @@ -94,16 +94,16 @@ def get_collection(self, collection_name):
case _:
raise ValueError("Specified collection does not exist.")

async def get_similar_question_sql(self, question: str) -> list:
def get_similar_question_sql(self, question: str) -> list:
documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
return [ast.literal_eval(document.page_content) for document in documents]

async def get_related_ddl(self, question: str, **kwargs) -> list:
documents = await self.ddl_collection.similarity_search(query=question, k=self.n_results)
def get_related_ddl(self, question: str, **kwargs) -> list:
documents = self.ddl_collection.similarity_search(query=question, k=self.n_results)
return [document.page_content for document in documents]

async def get_related_documentation(self, question: str, **kwargs) -> list:
documents = await self.documentation_collection.similarity_search(query=question, k=self.n_results)
def get_related_documentation(self, question: str, **kwargs) -> list:
documents = self.documentation_collection.similarity_search(query=question, k=self.n_results)
return [document.page_content for document in documents]

def train(
Expand Down Expand Up @@ -251,15 +251,3 @@ def remove_collection(self, collection_name: str) -> bool:

def generate_embedding(self, *args, **kwargs):
pass

def submit_prompt(self, *args, **kwargs):
Comment thread
edlouth marked this conversation as resolved.
pass

def system_message(self, message: str) -> any:
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}
39 changes: 32 additions & 7 deletions tests/test_pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,47 @@
from dotenv import load_dotenv

# from vanna.pgvector import PG_VectorStore
# from vanna.openai import OpenAI_Chat

# Assumes a .env file, placed next to this file, that provides the required env vars
load_dotenv()

# Removing these tests for now until the dependencies are sorted out
# def get_vanna_connection_string():
# server = os.environ.get("PG_SERVER")
# driver = "psycopg"
# port = 5434
# port = os.environ.get("PG_PORT", 5432)
# database = os.environ.get("PG_DATABASE")
# username = os.environ.get("PG_USERNAME")
# password = os.environ.get("PG_PASSWORD")

# return f"postgresql+psycopg://{username}:{password}@{server}:{port}/{database}"
# def test_pgvector_e2e():
# # configure Vanna to use OpenAI and PGVector
# class VannaCustom(PG_VectorStore, OpenAI_Chat):
# def __init__(self, config=None):
# PG_VectorStore.__init__(self, config=config)
# OpenAI_Chat.__init__(self, config=config)

# vn = VannaCustom(config={
# 'api_key': os.environ['OPENAI_API_KEY'],
# 'model': 'gpt-3.5-turbo',
# "connection_string": get_vanna_connection_string(),
# })

# # connect to SQLite database
# vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

# # train Vanna on DDLs
# df_ddl = vn.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
# for ddl in df_ddl['sql'].to_list():
# vn.train(ddl=ddl)
# assert len(vn.get_related_ddl("dummy question")) == 10 # assume 10 DDL chunks are retrieved by default

# question = "What are the top 7 customers by sales?"
# sql = vn.generate_sql(question)
# df = vn.run_sql(sql)
# assert len(df) == 7

# # test if Vanna can generate an answer
# answer = vn.ask(question)
# assert answer is not None

# def test_pgvector():
# connection_string = get_vanna_connection_string()
# pgclient = PG_VectorStore(config={"connection_string": connection_string})
# assert pgclient is not None