Skip to content

Commit 2526ece

Browse files
committed
Use one httpx.AsyncClient for all connections
1 parent 0cd2b9a commit 2526ece

File tree

3 files changed

+90
-95
lines changed

3 files changed

+90
-95
lines changed

constants.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,10 @@
1515
# Public endpoints that don't require authentication
1616
PUBLIC_ENDPOINTS = ["/api/v1/models"]
1717

18-
# Use httpx for proxy
19-
HTTPX_ENDPOINTS = ["/api/v1/generation", "/api/v1/models"]
20-
2118
MODELS_ENDPOINTS = ["/api/v1/models"]
2219

2320
# Use openai for proxy
24-
OPENAI_ENDPOINTS = ["/api/v1/completions", "/api/v1/chat/completions"]
21+
OPENAI_ENDPOINTS = ["/api/v1/chat/completions"]
2522

2623
# Read line by line
2724
COMPLETION_ENDPOINTS = ["/api/v1/completions", "/api/v1/chat/completions"]

main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
from fastapi import FastAPI
99

1010
from config import config, logger
11-
from routes import router
11+
from routes import router, lifespan
1212
from utils import get_local_ip
1313

1414
# Create FastAPI app
1515
app = FastAPI(
1616
title="OpenRouter API Proxy",
1717
description="Proxies requests to OpenRouter API and rotates API keys to bypass rate limits",
1818
version="1.0.0",
19+
lifespan=lifespan,
1920
)
2021

2122
# Include routes
@@ -39,4 +40,4 @@
3940
log_config["loggers"]["uvicorn.access"]["level"] = http_log_level
4041
logger.info("HTTP access log level set to %s", http_log_level)
4142

42-
uvicorn.run(app, host=host, port=port, log_config=log_config)
43+
uvicorn.run("main:app", host=host, port=port, log_config=log_config, reload=True)

routes.py

Lines changed: 86 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
"""
55

66
import json
7+
from contextlib import asynccontextmanager
78
from typing import Optional, Dict, Any, AsyncGenerator
89

910
import httpx
10-
from fastapi import APIRouter, Request, Header, HTTPException
11+
from fastapi import APIRouter, Request, Header, HTTPException, FastAPI
1112
from fastapi.responses import StreamingResponse, Response
1213
from openai import AsyncOpenAI, APIError
1314

1415
from config import config, logger
1516
from constants import (
1617
OPENROUTER_BASE_URL,
1718
PUBLIC_ENDPOINTS,
18-
HTTPX_ENDPOINTS,
1919
OPENAI_ENDPOINTS,
2020
COMPLETION_ENDPOINTS,
2121
MODELS_ENDPOINTS,
@@ -37,20 +37,30 @@
3737
)
3838

3939

40-
# Function to create OpenAI client with the right API key
41-
async def get_openai_client(api_key: str) -> AsyncOpenAI:
42-
"""Create an OpenAI client with the specified API key."""
43-
client_params = {
44-
"api_key": api_key,
45-
"base_url": OPENROUTER_BASE_URL
46-
}
47-
40+
@asynccontextmanager
41+
async def lifespan(app_: FastAPI):
42+
client_kwargs = {"timeout": 60.0} # Increase default timeout
4843
# Add proxy configuration if enabled
4944
if config.get("requestProxy", {}).get("enabled", False):
5045
proxy_url = config["requestProxy"]["url"]
51-
client_params["http_client"] = httpx.AsyncClient(proxy=proxy_url)
52-
logger.info("Using proxy for OpenAI client: %s", proxy_url)
46+
client_kwargs["proxy"] = proxy_url
47+
logger.info("Using proxy for httpx client: %s", proxy_url)
48+
app_.state.http_client = httpx.AsyncClient(**client_kwargs)
49+
yield
50+
await app_.state.http_client.aclose()
51+
52+
53+
async def get_async_client(request: Request) -> httpx.AsyncClient:
54+
return request.app.state.http_client
5355

56+
57+
async def get_openai_client(api_key: str, request: Request) -> AsyncOpenAI:
58+
"""Create an OpenAI client with the specified API key."""
59+
client_params = {
60+
"api_key": api_key,
61+
"base_url": OPENROUTER_BASE_URL,
62+
"http_client": await get_async_client(request)
63+
}
5464
return AsyncOpenAI(**client_params)
5565

5666

@@ -66,7 +76,6 @@ async def proxy_endpoint(
6676
"""
6777
is_public = any(f"/api/v1{path}".startswith(ep) for ep in PUBLIC_ENDPOINTS)
6878
is_completion = any(f"/api/v1{path}".startswith(ep) for ep in COMPLETION_ENDPOINTS)
69-
is_httpx = any(f"/api/v1{path}".startswith(ep) for ep in HTTPX_ENDPOINTS)
7079
is_openai = any(f"/api/v1{path}".startswith(ep) for ep in OPENAI_ENDPOINTS)
7180

7281
# Verify authorization for non-public endpoints
@@ -76,8 +85,8 @@ async def proxy_endpoint(
7685
# Log the full request URL including query parameters
7786
full_url = str(request.url).replace(str(request.base_url), "/")
7887
logger.info(
79-
"Proxying request to %s (Public: %s, HTTPX: %s, Completion: %s, OpenAI: %s)",
80-
full_url, is_public, is_httpx, is_completion, is_openai
88+
"Proxying request to %s (Public: %s, Completion: %s, OpenAI: %s)",
89+
full_url, is_public, is_completion, is_openai
8190
)
8291

8392
# Parse request body (if any)
@@ -115,19 +124,11 @@ async def proxy_endpoint(
115124
logger.debug("Could not parse request body: %s", str(e))
116125
request_body = None
117126

118-
# For models, non-OpenAI-compatible endpoints or requests with model-specific parameters, fall back to httpx
119-
if is_httpx or not is_openai:
120-
return await proxy_with_httpx(request, path, api_key, is_stream, is_completion)
121-
122-
# For OpenAI-compatible endpoints, use the OpenAI library
123127
try:
124-
# Create an OpenAI client
125-
client = await get_openai_client(api_key)
126-
127-
# Process based on the endpoint
128+
# For OpenAI-compatible endpoints, use the OpenAI library
128129
if is_openai:
129130
return await handle_completions(
130-
client, request, request_body, api_key, is_stream
131+
request, request_body, api_key, is_stream
131132
)
132133
else:
133134
# Fallback for other endpoints
@@ -143,7 +144,6 @@ async def proxy_endpoint(
143144

144145

145146
async def handle_completions(
146-
client: AsyncOpenAI,
147147
request: Request,
148148
request_body: Dict[str, Any],
149149
api_key: str,
@@ -173,6 +173,9 @@ async def handle_completions(
173173
if param in completion_args:
174174
extra_body[param] = completion_args.pop(param)
175175

176+
# Create an OpenAI client
177+
client = await get_openai_client(api_key, request)
178+
176179
# Create a properly formatted request to the OpenAI API
177180
if is_stream:
178181
logger.info("Making streaming chat completion request")
@@ -194,7 +197,8 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
194197
# Send the end marker
195198
yield b"data: [DONE]\n\n"
196199
except APIError as err:
197-
logger.error("Error in streaming response: %s", err)
200+
logger.error("Error in streaming response %s: %s", err.code, err)
201+
logger.debug("Error body: %s", err.body)
198202
# Check if this is a rate limit error
199203
if api_key:
200204
has_rate_limit_error_, reset_time_ms_ = check_rate_limit_chat(err)
@@ -237,6 +241,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
237241
code = 500
238242
detail = f"Error processing chat completion: {str(e)}"
239243
if isinstance(e, APIError):
244+
logger.debug("Error body: %s", e.body)
240245
# Check if this is a rate limit error
241246
if api_key:
242247
has_rate_limit_error, reset_time_ms = check_rate_limit_chat(e)
@@ -247,9 +252,8 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
247252
# Try again with a new key
248253
new_api_key = await key_manager.get_next_key()
249254
if new_api_key:
250-
new_client = await get_openai_client(new_api_key)
251255
return await handle_completions(
252-
new_client, request, request_body, new_api_key, is_stream
256+
request, request_body, new_api_key, is_stream
253257
)
254258
code = e.code or code
255259
detail = e.body or detail
@@ -295,16 +299,8 @@ async def proxy_with_httpx(
295299
is_completion: bool,
296300
) -> Response:
297301
"""Fall back to httpx for endpoints not supported by the OpenAI SDK."""
298-
client_kwargs = {"timeout": 60.0} # Increase default timeout
299302
free_only = (any(f"/api/v1{path}" == ep for ep in MODELS_ENDPOINTS) and
300303
config["openrouter"].get("free_only", False))
301-
302-
# Add proxy configuration if enabled
303-
if config.get("requestProxy", {}).get("enabled", False):
304-
proxy_url = config["requestProxy"]["url"]
305-
client_kwargs["proxy"] = proxy_url
306-
logger.info("Using proxy for httpx client: %s", proxy_url)
307-
308304
headers = {
309305
k: v
310306
for k, v in request.headers.items()
@@ -323,63 +319,64 @@ async def proxy_with_httpx(
323319

324320
if api_key:
325321
req_kwargs["headers"]["Authorization"] = f"Bearer {api_key}"
326-
async with httpx.AsyncClient(**client_kwargs) as client:
327-
try:
328-
openrouter_resp = await client.request(**req_kwargs)
329-
headers = dict(openrouter_resp.headers)
330-
# Content has already been decoded
331-
headers.pop("content-encoding", None)
332-
headers.pop("Content-Encoding", None)
333-
334-
if not is_stream:
335-
body = await openrouter_resp.aread()
336-
await _check_httpx_err(body, api_key)
337-
if free_only:
338-
body = _remove_paid_models(body)
339-
return Response(
340-
content=body,
341-
status_code=openrouter_resp.status_code,
342-
headers=headers,
343-
)
344-
if not api_key and not is_completion:
345-
return StreamingResponse(
346-
openrouter_resp.aiter_bytes(),
347-
status_code=openrouter_resp.status_code,
348-
headers=headers,
349-
)
350-
351-
async def stream_completion():
352-
data = ''
353-
try:
354-
async for line in openrouter_resp.aiter_lines():
355-
if line.startswith("data: "):
356-
data = line[6:] # Get data without 'data: ' prefix
357-
if data == "[DONE]":
358-
yield "data: [DONE]\n\n".encode("utf-8")
359-
else:
360-
# Forward the original data without reformatting
361-
data = line
362-
yield f"{line}\n\n".encode("utf-8")
363-
elif line:
364-
yield f"{line}\n\n".encode("utf-8")
365-
except Exception as err:
366-
logger.error("stream_completion error: %s", err)
367-
await _check_httpx_err(data, api_key)
368322

323+
client = await get_async_client(request)
324+
try:
325+
openrouter_resp = await client.request(**req_kwargs)
326+
headers = dict(openrouter_resp.headers)
327+
# Content has already been decoded
328+
headers.pop("content-encoding", None)
329+
headers.pop("Content-Encoding", None)
330+
331+
if not is_stream:
332+
body = await openrouter_resp.aread()
333+
await _check_httpx_err(body, api_key)
334+
if free_only:
335+
body = _remove_paid_models(body)
336+
return Response(
337+
content=body,
338+
status_code=openrouter_resp.status_code,
339+
headers=headers,
340+
)
341+
if not api_key and not is_completion:
369342
return StreamingResponse(
370-
stream_completion(),
343+
openrouter_resp.aiter_bytes(),
371344
status_code=openrouter_resp.status_code,
372345
headers=headers,
373346
)
374-
except httpx.ConnectError as e:
375-
logger.error("Connection error to OpenRouter: %s", str(e))
376-
raise HTTPException(503, "Unable to connect to OpenRouter API") from e
377-
except httpx.TimeoutException as e:
378-
logger.error("Timeout connecting to OpenRouter: %s", str(e))
379-
raise HTTPException(504, "OpenRouter API request timed out") from e
380-
except Exception as e:
381-
logger.error("Error proxying request with httpx: %s", str(e))
382-
raise HTTPException(500, f"Proxy error: {str(e)}") from e
347+
348+
async def stream_completion():
349+
data = ''
350+
try:
351+
async for line in openrouter_resp.aiter_lines():
352+
if line.startswith("data: "):
353+
data = line[6:] # Get data without 'data: ' prefix
354+
if data == "[DONE]":
355+
yield "data: [DONE]\n\n".encode("utf-8")
356+
else:
357+
# Forward the original data without reformatting
358+
data = line
359+
yield f"{line}\n\n".encode("utf-8")
360+
elif line:
361+
yield f"{line}\n\n".encode("utf-8")
362+
except Exception as err:
363+
logger.error("stream_completion error: %s", err)
364+
await _check_httpx_err(data, api_key)
365+
366+
return StreamingResponse(
367+
stream_completion(),
368+
status_code=openrouter_resp.status_code,
369+
headers=headers,
370+
)
371+
except httpx.ConnectError as e:
372+
logger.error("Connection error to OpenRouter: %s", str(e))
373+
raise HTTPException(503, "Unable to connect to OpenRouter API") from e
374+
except httpx.TimeoutException as e:
375+
logger.error("Timeout connecting to OpenRouter: %s", str(e))
376+
raise HTTPException(504, "OpenRouter API request timed out") from e
377+
except Exception as e:
378+
logger.error("Error proxying request with httpx: %s", str(e))
379+
raise HTTPException(500, f"Proxy error: {str(e)}") from e
383380

384381

385382
@router.get("/health")

0 commit comments

Comments (0)