2222from optillm .rto import round_trip_optimization
2323from optillm .self_consistency import advanced_self_consistency_approach
2424from optillm .pvg import inference_time_pv_game
25- from optillm .z3_solver import Z3SolverSystem
25+ from optillm .z3_solver import Z3SymPySolverSystem
2626from optillm .rstar import RStar
2727from optillm .cot_reflection import cot_reflection
2828from optillm .plansearch import plansearch
4444# Initialize Flask app
4545app = Flask (__name__ )
4646
47- # OpenAI, Azure, or LiteLLM API configuration
48- if os .environ .get ("OPENAI_API_KEY" ):
49- API_KEY = os .environ .get ("OPENAI_API_KEY" )
50- default_client = OpenAI (api_key = API_KEY )
51- elif os .environ .get ("AZURE_OPENAI_API_KEY" ):
52- API_KEY = os .environ .get ("AZURE_OPENAI_API_KEY" )
53- API_VERSION = os .environ .get ("AZURE_API_VERSION" )
54- AZURE_ENDPOINT = os .environ .get ("AZURE_API_BASE" )
55- if API_KEY is not None :
56- default_client = AzureOpenAI (
57- api_key = API_KEY ,
58- api_version = API_VERSION ,
59- azure_endpoint = AZURE_ENDPOINT ,
60- )
def get_config():
    """Build the default LLM client from environment variables.

    Provider selection, in priority order:

    1. ``OPENAI_API_KEY`` is set: an ``OpenAI`` client using that key.
    2. ``AZURE_OPENAI_API_KEY`` is set: an ``AzureOpenAI`` client configured
       from ``AZURE_API_VERSION`` and ``AZURE_API_BASE``.
    3. Otherwise: a ``LiteLLMWrapper`` and no API key.

    Returns:
        A ``(client, api_key)`` tuple. ``api_key`` is ``None`` when the
        LiteLLM fallback is selected.
    """
    openai_key = os.environ.get("OPENAI_API_KEY")
    if openai_key:
        return OpenAI(api_key=openai_key), openai_key

    azure_key = os.environ.get("AZURE_OPENAI_API_KEY")
    if not azure_key:
        # Neither OpenAI nor Azure credentials are present.
        return LiteLLMWrapper(), None

    api_version = os.environ.get("AZURE_API_VERSION")
    azure_endpoint = os.environ.get("AZURE_API_BASE")

    if azure_key is None:
        # NOTE(review): this branch cannot run, because azure_key was already
        # required to be non-empty above. It is kept so the Azure AD
        # token-provider path is not lost; reaching it needs a different
        # trigger than AZURE_OPENAI_API_KEY.
        from azure.identity import DefaultAzureCredential, get_bearer_token_provider

        credential = DefaultAzureCredential()
        token_provider = get_bearer_token_provider(
            credential, "https://cognitiveservices.azure.com/.default"
        )
        azure_client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=token_provider,
        )
        return azure_client, azure_key

    azure_client = AzureOpenAI(
        api_key=azure_key,
        api_version=api_version,
        azure_endpoint=azure_endpoint,
    )
    return azure_client, azure_key
7275
7376# Server configuration
7477server_config = {
@@ -156,7 +159,7 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
156159 elif approach == 'rto' :
157160 return round_trip_optimization (system_prompt , initial_query , client , model )
158161 elif approach == 'z3' :
159- z3_solver = Z3SolverSystem (system_prompt , client , model )
162+ z3_solver = Z3SymPySolverSystem (system_prompt , client , model )
160163 return z3_solver .process_query (initial_query )
161164 elif approach == "self_consistency" :
162165 return advanced_self_consistency_approach (system_prompt , initial_query , client , model )
@@ -263,6 +266,14 @@ def check_api_key():
263266def proxy ():
264267 logger .info ('Received request to /v1/chat/completions' )
265268 data = request .get_json ()
269+ auth_header = request .headers .get ("Authorization" )
270+ bearer_token = ""
271+
272+ if auth_header and auth_header .startswith ("Bearer " ):
273+ # Extract the bearer token
274+ bearer_token = auth_header .split ("Bearer " )[1 ].strip ()
275+ logger .debug (f"Intercepted Bearer Token: { bearer_token } " )
276+
266277 logger .debug (f'Request data: { data } ' )
267278
268279 stream = data .get ('stream' , False )
@@ -281,15 +292,20 @@ def proxy():
281292 model = f"{ optillm_approach } -{ model } "
282293
283294 base_url = server_config ['base_url' ]
284-
285- if base_url != "" :
286- client = OpenAI (api_key = API_KEY , base_url = base_url )
287- else :
288- client = default_client
295+ default_client , api_key = get_config ()
289296
290297 operation , approaches , model = parse_combined_approach (model , known_approaches , plugin_approaches )
291298 logger .info (f'Using approach(es) { approaches } , operation { operation } , with model { model } ' )
292299
300+ if bearer_token != "" and bearer_token .startswith ("sk-" ) and model .startswith ("gpt" ):
301+ api_key = bearer_token
302+ if base_url != "" :
303+ client = OpenAI (api_key = api_key , base_url = base_url )
304+ else :
305+ client = OpenAI (api_key = api_key )
306+ else :
307+ client = default_client
308+
293309 try :
294310 if operation == 'SINGLE' :
295311 final_response , completion_tokens = execute_single_approach (approaches [0 ], system_prompt , initial_query , client , model )
@@ -342,7 +358,7 @@ def proxy():
342358@app .route ('/v1/models' , methods = ['GET' ])
343359def proxy_models ():
344360 logger .info ('Received request to /v1/models' )
345-
361+ default_client , API_KEY = get_config ()
346362 try :
347363 if server_config ['base_url' ]:
348364 client = OpenAI (api_key = API_KEY , base_url = server_config ['base_url' ])
0 commit comments