Skip to content
This repository was archived by the owner on Mar 29, 2026. It is now read-only.

Commit e8e9b43

Browse files
authored
Merge pull request #830 from ai-yann/main
Add Cohere integration with Cohere_Chat and Cohere_Embeddings classes
2 parents 103da3a + fc7efdf commit e8e9b43

5 files changed

Lines changed: 188 additions & 2 deletions

File tree

src/vanna/cohere/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .cohere_chat import Cohere_Chat
2+
from .cohere_embeddings import Cohere_Embeddings

src/vanna/cohere/cohere_chat.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import os
2+
3+
from openai import OpenAI
4+
5+
from ..base import VannaBase
6+
7+
8+
class Cohere_Chat(VannaBase):
    """Chat backend that reaches Cohere through its OpenAI-compatible endpoint.

    The class builds chat messages in the role/content dict format and sends
    them with an ``openai.OpenAI`` client pointed at Cohere's compatibility
    base URL.
    """

    def __init__(self, client=None, config=None):
        """Set up model parameters and the API client.

        Args:
            client: Optional pre-built client. When given it is used as-is and
                no API key lookup is performed.
            config: Optional dict. Recognised keys are ``"temperature"``,
                ``"model"`` and ``"api_key"``.

        Raises:
            ValueError: If no client is supplied and no API key is found in
                ``config`` or the ``COHERE_API_KEY`` environment variable.
        """
        VannaBase.__init__(self, config=config)

        # default parameters - can be overridden using config
        self.temperature = 0.2  # Lower temperature for more precise SQL generation
        self.model = "command-a-03-2025"  # Cohere's default model

        if config is not None:
            if "temperature" in config:
                self.temperature = config["temperature"]
            if "model" in config:
                self.model = config["model"]

        if client is not None:
            self.client = client
            return

        # The environment variable is the fallback; an explicit config entry wins.
        api_key = os.getenv("COHERE_API_KEY")
        if config is not None and "api_key" in config:
            api_key = config["api_key"]

        if not api_key:
            raise ValueError("Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable.")

        self.client = OpenAI(
            base_url="https://api.cohere.ai/compatibility/v1",
            api_key=api_key,
        )

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a system-level chat message."""
        return {"role": "developer", "content": message}  # Cohere uses 'developer' for system role

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a user chat message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as an assistant chat message."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send the message log to the chat completions endpoint.

        Args:
            prompt: Non-empty list of message dicts, each with a ``"content"`` key.
            **kwargs: ``model`` may be passed to override the model for this call.

        Returns:
            The text content of the first choice in the response.

        Raises:
            Exception: If ``prompt`` is ``None`` or empty, or if the request
                fails or the response lacks the expected fields (the original
                error is chained as the cause).
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        num_tokens = sum(len(message["content"]) / 4 for message in prompt)

        # Precedence: explicit kwarg, then config, then the instance default.
        # An explicitly passed model is never overridden by config.
        if "model" in kwargs:
            model = kwargs["model"]
        elif self.config is not None and "model" in self.config:
            model = self.config["model"]
        else:
            model = self.model

        print(f"Using model {model} for {num_tokens} tokens (approx)")
        try:
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                temperature=self.temperature,
            )

            # Check if response has expected structure
            if not response or not getattr(response, "choices", None):
                raise ValueError("Received empty or malformed response from API")

            first_choice = response.choices[0]
            if not first_choice or not hasattr(first_choice, "message"):
                raise ValueError("Response is missing expected 'message' field")

            # A None content would violate the declared str return type, so
            # it is treated the same as a missing field.
            content = getattr(first_choice.message, "content", None)
            if content is None:
                raise ValueError("Response message is missing expected 'content' field")

            return content

        except Exception as e:
            # Log the error and raise a more informative exception
            error_msg = f"Error processing Cohere chat response: {str(e)}"
            print(error_msg)
            raise Exception(error_msg) from e
src/vanna/cohere/cohere_embeddings.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
3+
from openai import OpenAI
4+
5+
from ..base import VannaBase
6+
7+
8+
class Cohere_Embeddings(VannaBase):
    """Embedding backend that reaches Cohere through its OpenAI-compatible endpoint."""

    def __init__(self, client=None, config=None):
        """Set up the embedding model name and the API client.

        Args:
            client: Optional pre-built client. When given it is used as-is and
                no API key lookup is performed.
            config: Optional dict. Recognised keys are ``"model"`` and
                ``"api_key"``.

        Raises:
            ValueError: If no client is supplied and no API key is found in
                ``config`` or the ``COHERE_API_KEY`` environment variable.
        """
        VannaBase.__init__(self, config=config)

        # Default embedding model
        self.model = "embed-multilingual-v3.0"

        if config is not None and "model" in config:
            self.model = config["model"]

        if client is not None:
            self.client = client
            return

        # The environment variable is the fallback; an explicit config entry wins.
        api_key = os.getenv("COHERE_API_KEY")
        if config is not None and "api_key" in config:
            api_key = config["api_key"]

        if not api_key:
            raise ValueError("Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable.")

        self.client = OpenAI(
            base_url="https://api.cohere.ai/compatibility/v1",
            api_key=api_key,
        )

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        """Return the embedding vector for ``data``.

        Args:
            data: Non-empty text to embed.
            **kwargs: ``model`` may be passed to override the model for this call.

        Returns:
            The embedding of the first item in the response.

        Raises:
            ValueError: If ``data`` is empty.
            Exception: If the request fails or the response lacks the expected
                fields (the original error is chained as the cause).
        """
        if not data:
            raise ValueError("Cannot generate embedding for empty input data")

        # Precedence: explicit kwarg, then config, then the instance default.
        # An explicitly passed model is never overridden by config.
        if "model" in kwargs:
            model = kwargs["model"]
        elif self.config is not None and "model" in self.config:
            model = self.config["model"]
        else:
            model = self.model

        try:
            embedding = self.client.embeddings.create(
                model=model,
                input=data,
                encoding_format="float",  # Ensure we get float values
            )

            # Check if response has expected structure
            if not embedding or not getattr(embedding, "data", None):
                raise ValueError("Received empty or malformed embedding response from API")

            first_item = embedding.data[0]
            if not first_item or not hasattr(first_item, "embedding"):
                raise ValueError("Embedding response is missing expected 'embedding' field")

            vector = first_item.embedding
            if not vector:
                raise ValueError("Received empty embedding vector")

            return vector

        except Exception as e:
            # Log the error and raise a more informative exception
            error_msg = f"Error generating embedding with Cohere: {str(e)}"
            print(error_msg)
            raise Exception(error_msg) from e

tests/test_imports.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
2-
31
def test_regular_imports():
42
from vanna.anthropic.anthropic_chat import Anthropic_Chat
53
from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore
64
from vanna.base.base import VannaBase
75
from vanna.bedrock.bedrock_converse import Bedrock_Converse
86
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
7+
from vanna.cohere.cohere_chat import Cohere_Chat
8+
from vanna.cohere.cohere_embeddings import Cohere_Embeddings
99
from vanna.faiss.faiss import FAISS
1010
from vanna.google.bigquery_vector import BigQuery_VectorStore
1111
from vanna.google.gemini_chat import GoogleGeminiChat
@@ -40,6 +40,7 @@ def test_shortcut_imports():
4040
from vanna.azuresearch import AzureAISearch_VectorStore
4141
from vanna.base import VannaBase
4242
from vanna.chromadb import ChromaDB_VectorStore
43+
from vanna.cohere import Cohere_Chat, Cohere_Embeddings
4344
from vanna.faiss import FAISS
4445
from vanna.hf import Hf
4546
from vanna.marqo import Marqo_VectorStore

tests/test_vanna.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22

33
from vanna.anthropic.anthropic_chat import Anthropic_Chat
4+
from vanna.cohere.cohere_chat import Cohere_Chat
45
from vanna.google import GoogleGeminiChat
56
from vanna.mistral.mistral import Mistral
67
from vanna.openai.openai_chat import OpenAI_Chat
@@ -227,6 +228,23 @@ def test_vn_gemini():
227228
df = vn_gemini.run_sql(sql)
228229
assert len(df) == 9
229230

231+
class VannaCohere(VannaDB_VectorStore, Cohere_Chat):
    """Test fixture combining the VannaDB vector store with the Cohere chat backend."""

    def __init__(self, config=None):
        # Each base is initialised explicitly and receives the same config.
        VannaDB_VectorStore.__init__(
            self,
            vanna_model=MY_VANNA_MODEL,
            vanna_api_key=MY_VANNA_API_KEY,
            config=config,
        )
        Cohere_Chat.__init__(self, config=config)
235+
236+
# Only the environment lookup is guarded: a KeyError raised while building the
# client or connecting to the database must surface as a real failure instead
# of being reported as a missing API key.
try:
    COHERE_API_KEY = os.environ['COHERE_API_KEY']
except KeyError:
    print("Skipping Cohere tests - COHERE_API_KEY not found in environment variables")
else:
    vn_cohere = VannaCohere(config={'api_key': COHERE_API_KEY, 'model': 'command-a-03-2025'})
    vn_cohere.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

    def test_vn_cohere():
        """Generate SQL for a top-10 question and check that ten rows come back."""
        sql = vn_cohere.generate_sql("What are the top 10 customers by sales?")
        df = vn_cohere.run_sql(sql)
        assert len(df) == 10
247+
230248
def test_training_plan():
231249
vn_dummy = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY)
232250

0 commit comments

Comments
 (0)