1+ import os
2+
3+ from openai import OpenAI
4+
5+ from ..base import VannaBase
6+
7+
8+ class Cohere_Chat (VannaBase ):
9+ def __init__ (self , client = None , config = None ):
10+ VannaBase .__init__ (self , config = config )
11+
12+ # default parameters - can be overridden using config
13+ self .temperature = 0.2 # Lower temperature for more precise SQL generation
14+ self .model = "command-a-03-2025" # Cohere's default model
15+
16+ if config is not None :
17+ if "temperature" in config :
18+ self .temperature = config ["temperature" ]
19+ if "model" in config :
20+ self .model = config ["model" ]
21+
22+ if client is not None :
23+ self .client = client
24+ return
25+
26+ # Check for API key in environment variable
27+ api_key = os .getenv ("COHERE_API_KEY" )
28+
29+ # Check for API key in config
30+ if config is not None and "api_key" in config :
31+ api_key = config ["api_key" ]
32+
33+ # Validate API key
34+ if not api_key :
35+ raise ValueError ("Cohere API key is required. Please provide it via config or set the COHERE_API_KEY environment variable." )
36+
37+ # Initialize client with validated API key
38+ self .client = OpenAI (
39+ base_url = "https://api.cohere.ai/compatibility/v1" ,
40+ api_key = api_key ,
41+ )
42+
43+ def system_message (self , message : str ) -> any :
44+ return {"role" : "developer" , "content" : message } # Cohere uses 'developer' for system role
45+
46+ def user_message (self , message : str ) -> any :
47+ return {"role" : "user" , "content" : message }
48+
49+ def assistant_message (self , message : str ) -> any :
50+ return {"role" : "assistant" , "content" : message }
51+
52+ def submit_prompt (self , prompt , ** kwargs ) -> str :
53+ if prompt is None :
54+ raise Exception ("Prompt is None" )
55+
56+ if len (prompt ) == 0 :
57+ raise Exception ("Prompt is empty" )
58+
59+ # Count the number of tokens in the message log
60+ # Use 4 as an approximation for the number of characters per token
61+ num_tokens = 0
62+ for message in prompt :
63+ num_tokens += len (message ["content" ]) / 4
64+
65+ # Use model from kwargs, config, or default
66+ model = kwargs .get ("model" , self .model )
67+ if self .config is not None and "model" in self .config and model == self .model :
68+ model = self .config ["model" ]
69+
70+ print (f"Using model { model } for { num_tokens } tokens (approx)" )
71+ try :
72+ response = self .client .chat .completions .create (
73+ model = model ,
74+ messages = prompt ,
75+ temperature = self .temperature ,
76+ )
77+
78+ # Check if response has expected structure
79+ if not response or not hasattr (response , 'choices' ) or not response .choices :
80+ raise ValueError ("Received empty or malformed response from API" )
81+
82+ if not response .choices [0 ] or not hasattr (response .choices [0 ], 'message' ):
83+ raise ValueError ("Response is missing expected 'message' field" )
84+
85+ if not hasattr (response .choices [0 ].message , 'content' ):
86+ raise ValueError ("Response message is missing expected 'content' field" )
87+
88+ return response .choices [0 ].message .content
89+
90+ except Exception as e :
91+ # Log the error and raise a more informative exception
92+ error_msg = f"Error processing Cohere chat response: { str (e )} "
93+ print (error_msg )
94+ raise Exception (error_msg )
0 commit comments