
Commit 8aeccd9

Refactoring
1 parent 5c8867d commit 8aeccd9

6 files changed: +74 −72 lines


config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@ def load_config() -> Dict[str, Any]:
         print(f"Error parsing configuration file: {e}")
         sys.exit(1)
 
+
 def setup_logging(config_: Dict[str, Any]) -> logging.Logger:
     """Configure logging based on configuration."""
     log_level_str = config_.get("server", {}).get("log_level", "INFO")
@@ -41,6 +42,7 @@ def setup_logging(config_: Dict[str, Any]) -> logging.Logger:
 
     return logger_
 
+
 # Load configuration
 config = load_config()
 
```

key_manager.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -20,6 +20,7 @@ def _mask_key(key: str) -> str:
         return "****"
     return key[:4] + "****" + key[-4:]
 
+
 class KeyManager:
     """Manages OpenRouter API keys, including rotation and rate limit handling."""
     def __init__(self, keys: List[str], cooldown_seconds: int):
```
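For reference, `_mask_key` (touched here only for spacing) keeps just the first and last four characters of a key when logging. A minimal standalone sketch of that behavior; the short-key guard condition and the sample key are assumptions for illustration:

```python
def _mask_key(key: str) -> str:
    """Mask an API key for safe logging, keeping 4 characters on each end."""
    if len(key) <= 8:  # assumed guard: too short to mask meaningfully
        return "****"
    return key[:4] + "****" + key[-4:]

print(_mask_key("sk-or-v1-0123456789abcdef"))  # -> sk-o****cdef
```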

main.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -40,4 +40,4 @@
     log_config["loggers"]["uvicorn.access"]["level"] = http_log_level
     logger.info("HTTP access log level set to %s", http_log_level)
 
-    uvicorn.run(app, host=host, port=port, log_config=log_config)
+    uvicorn.run(app, host=host, port=port, log_config=log_config, timeout_graceful_shutdown=30)
```
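The only functional change here is `timeout_graceful_shutdown=30`: after a shutdown signal, uvicorn now waits up to 30 seconds for in-flight requests (such as open SSE streams from the proxy) before closing connections. A minimal sketch of the same setting in isolation; the app, host, and port are placeholders:

```python
import uvicorn
from fastapi import FastAPI

app = FastAPI()  # placeholder app

if __name__ == "__main__":
    # On SIGINT/SIGTERM, give active requests up to 30 s to finish
    # before connections are forcibly closed.
    uvicorn.run(app, host="127.0.0.1", port=8000, timeout_graceful_shutdown=30)
```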

routes.py

Lines changed: 65 additions & 71 deletions
```diff
@@ -64,6 +64,47 @@ async def get_openai_client(api_key: str, request: Request) -> AsyncOpenAI:
     return AsyncOpenAI(**client_params)
 
 
+async def check_httpx_err(body: str | bytes, api_key: str | None):
+    # too big for error
+    if len(body) > 4000 or not api_key:
+        return
+    if (isinstance(body, str) and body.startswith("data: ") or (
+            isinstance(body, bytes) and body.startswith(b"data: "))):
+        body = body[6:]
+    has_rate_limit_error, reset_time_ms = await check_rate_limit(body)
+    if has_rate_limit_error:
+        logger.warning("Rate limit detected in stream. Disabling key.")
+        await key_manager.disable_key(api_key, reset_time_ms)
+
+
+def remove_paid_models(body: bytes) -> bytes:
+    # {'prompt': '0', 'completion': '0', 'request': '0', 'image': '0', 'web_search': '0', 'internal_reasoning': '0'}
+    prices = ['prompt', 'completion', 'request', 'image', 'web_search', 'internal_reasoning']
+    try:
+        data = json.loads(body)
+    except (json.JSONDecodeError, ValueError) as e:
+        logger.warning("Error models deserialize: %s", str(e))
+    else:
+        if isinstance(data.get("data"), list):
+            clear_data = []
+            for model in data["data"]:
+                if all(model.get("pricing", {}).get(k, "1") == "0" for k in prices):
+                    clear_data.append(model)
+            if clear_data:
+                data["data"] = clear_data
+            body = json.dumps(data, ensure_ascii=False).encode("utf-8")
+    return body
+
+
+def prepare_forward_headers(request: Request) -> dict:
+    return {
+        k: v
+        for k, v in request.headers.items()
+        if k.lower()
+        not in ["host", "content-length", "connection", "authorization"]
+    }
+
+
 @router.api_route(
     "/api/v1{path:path}",
     methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"],
```
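The helpers moved above `proxy_endpoint` are now public (no leading underscore). To see what `remove_paid_models` keeps, here is a standalone sketch of the same filter over an invented `/models` payload: a model survives only if every pricing field is the string `"0"`, and the `"1"` default makes a missing field count as paid:

```python
import json

PRICES = ["prompt", "completion", "request", "image", "web_search", "internal_reasoning"]

body = json.dumps({"data": [
    {"id": "vendor/free-model", "pricing": {k: "0" for k in PRICES}},
    {"id": "vendor/paid-model", "pricing": {"prompt": "0.000001"}},
]}).encode("utf-8")

data = json.loads(body)
free = [m for m in data["data"]
        if all(m.get("pricing", {}).get(k, "1") == "0" for k in PRICES)]
print([m["id"] for m in free])  # -> ['vendor/free-model']
```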
```diff
@@ -89,37 +130,36 @@ async def proxy_endpoint(
         full_url, is_public, is_completion, is_openai
     )
 
-    # Parse request body (if any)
-    request_body = None
-    is_stream = False
     # Get API key to use
-    if not is_public:
+    if is_public:
+        # For public endpoints, we don't need an API key
+        api_key = ""
+    else:
         api_key = await key_manager.get_next_key()
         if not api_key:
             raise HTTPException(status_code=503, detail="No available API keys")
-    else:
-        # For public endpoints, we don't need an API key
-        api_key = ""
+
+    # Parse request body (if any)
+    request_body = None
+    is_stream = False
     try:
         body_bytes = await request.body()
         if body_bytes:
             request_body = json.loads(body_bytes)
             is_stream = request_body.get("stream", False)
 
-            # Log if this is a streaming request
-            if is_stream:
-                logger.info("Detected streaming request")
-
             # Check for model
             if is_openai and request.method == "POST":
                 model = request_body.get("model", "")
                 if model:
                     logger.info("Using model: %s", model)
-
     except Exception as e:
         logger.debug("Could not parse request body: %s", str(e))
         request_body = None
 
+    if is_stream:
+        logger.info("Detected streaming request")
+
     try:
         # For OpenAI-compatible endpoints, use the OpenAI library
         if is_openai:
```
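With this reordering, a key is reserved (or the request rejected with 503) before any body parsing, and the streaming log now runs after the `try/except`, so it reflects whatever the parse actually produced. The detection itself is a keyed lookup with a safe default; a tiny self-contained illustration (payload invented):

```python
import json

body_bytes = b'{"model": "deepseek/deepseek-r1:free", "stream": true}'
request_body = json.loads(body_bytes) if body_bytes else None
is_stream = bool(request_body) and request_body.get("stream", False)
print(is_stream)  # -> True
```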
```diff
@@ -148,12 +188,7 @@ async def handle_completions(
     """Handle chat completions using the OpenAI client."""
     try:
         # Extract headers to forward
-        forward_headers = {
-            k: v
-            for k, v in request.headers.items()
-            if k.lower()
-            not in ["host", "content-length", "connection", "authorization"]
-        }
+        forward_headers = prepare_forward_headers(request)
 
         # Create a copy of the request body to modify
         completion_args = request_body.copy()
@@ -172,15 +207,13 @@ async def handle_completions(
         # Create an OpenAI client
         client = await get_openai_client(api_key, request)
 
-        # Create a properly formatted request to the OpenAI API
-        if is_stream:
-            logger.info("Making streaming chat completion request")
-
-            response = await client.chat.completions.create(
-                **completion_args, extra_headers=forward_headers, extra_body=extra_body, stream=True
-            )
+        logger.info(f"Making {'streaming' if is_stream else 'regular'} chat completion request")
+        response = await client.chat.completions.create(
+            **completion_args, extra_headers=forward_headers, extra_body=extra_body, stream=is_stream
+        )
 
-            # Handle streaming response
+        # Handle streaming response
+        if is_stream:
             async def stream_response() -> AsyncGenerator[bytes, None]:
                 try:
                     async for chunk in response:
```
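Collapsing the two calls into one works because the OpenAI SDK changes its return type with the `stream` flag: a materialized completion for `stream=False`, an async iterator of chunks for `stream=True`. A hedged sketch of both shapes against OpenRouter; the key and model are placeholders:

```python
import asyncio
from openai import AsyncOpenAI

async def demo(is_stream: bool) -> None:
    client = AsyncOpenAI(base_url="https://openrouter.ai/api/v1", api_key="sk-...")
    response = await client.chat.completions.create(
        model="deepseek/deepseek-r1:free",
        messages=[{"role": "user", "content": "hello"}],
        stream=is_stream,
    )
    if is_stream:
        # AsyncStream[ChatCompletionChunk]: consume incrementally
        async for chunk in response:
            print(chunk.choices[0].delta.content or "", end="")
    else:
        # ChatCompletion: the whole result is already available
        print(response.choices[0].message.content)

# asyncio.run(demo(True))  # needs a real key to actually run
```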
```diff
@@ -217,13 +250,8 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
                     "X-Accel-Buffering": "no",
                 },
             )
-        # Non-streaming request
-        logger.info("Making regular chat completion request")
-
-        response = await client.chat.completions.create(
-            **completion_args, extra_headers=forward_headers, extra_body=extra_body
-        )
 
+        # Non-streaming request
         result = response.model_dump()
         if 'error' in result:
             raise APIError(result['error'].get("message", "Error"), None, body=result['error'])
@@ -257,36 +285,6 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
         raise HTTPException(code, detail) from e
 
 
-async def _check_httpx_err(body: str | bytes, api_key: str | None):
-    # too big for error
-    if len(body) > 4000 or not api_key:
-        return
-    if (isinstance(body, str) and body.startswith("data: ") or (
-            isinstance(body, bytes) and body.startswith(b"data: "))):
-        body = body[6:]
-    has_rate_limit_error, reset_time_ms = await check_rate_limit(body)
-    if has_rate_limit_error:
-        logger.warning("Rate limit detected in stream. Disabling key.")
-        await key_manager.disable_key(api_key, reset_time_ms)
-
-def _remove_paid_models(body: bytes) -> bytes:
-    # {'prompt': '0', 'completion': '0', 'request': '0', 'image': '0', 'web_search': '0', 'internal_reasoning': '0'}
-    prices = ['prompt', 'completion', 'request', 'image', 'web_search', 'internal_reasoning']
-    try:
-        data = json.loads(body)
-    except (json.JSONDecodeError, ValueError) as e:
-        logger.warning("Error models deserialize: %s", str(e))
-    else:
-        if isinstance(data.get("data"), list):
-            clear_data = []
-            for model in data["data"]:
-                if all(model.get("pricing", {}).get(k, "1") == "0" for k in prices):
-                    clear_data.append(model)
-            if clear_data:
-                data["data"] = clear_data
-            body = json.dumps(data, ensure_ascii=False).encode("utf-8")
-    return body
-
 async def proxy_with_httpx(
     request: Request,
     path: str,
```
```diff
@@ -297,12 +295,7 @@ async def proxy_with_httpx(
     """Fall back to httpx for endpoints not supported by the OpenAI SDK."""
     free_only = (any(f"/api/v1{path}" == ep for ep in MODELS_ENDPOINTS) and
                  config["openrouter"].get("free_only", False))
-    headers = {
-        k: v
-        for k, v in request.headers.items()
-        if k.lower()
-        not in ["host", "content-length", "connection", "authorization"]
-    }
+    headers = prepare_forward_headers(request)
     req_kwargs = {
         "method": request.method,
         "url": f"{OPENROUTER_BASE_URL}{path}",
@@ -326,9 +319,9 @@
 
     if not is_stream:
         body = await openrouter_resp.aread()
-        await _check_httpx_err(body, api_key)
+        await check_httpx_err(body, api_key)
         if free_only:
-            body = _remove_paid_models(body)
+            body = remove_paid_models(body)
         return Response(
             content=body,
             status_code=openrouter_resp.status_code,
@@ -357,7 +350,8 @@ async def stream_completion():
                 yield f"{line}\n\n".encode("utf-8")
         except Exception as err:
             logger.error("stream_completion error: %s", err)
-            await _check_httpx_err(data, api_key)
+            await check_httpx_err(data, api_key)
+
 
     return StreamingResponse(
         stream_completion(),
```
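Both `handle_completions` and `proxy_with_httpx` now share `prepare_forward_headers`, which strips `Host`, `Content-Length`, `Connection`, and `Authorization` before forwarding. A quick standalone check of the same predicate over a plain dict (sample headers invented):

```python
SKIP = {"host", "content-length", "connection", "authorization"}

def filter_headers(headers: dict) -> dict:
    # Same filter as prepare_forward_headers, minus the Request object
    return {k: v for k, v in headers.items() if k.lower() not in SKIP}

incoming = {
    "Host": "proxy.local",
    "Authorization": "Bearer client-token",
    "Content-Length": "42",
    "X-Title": "my-app",
}
print(filter_headers(incoming))  # -> {'X-Title': 'my-app'}
```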

test.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@ def load_config():
     with open("config.yml", encoding="utf-8") as file:
         return yaml.safe_load(file)
 
+
 # Get configuration
 config = load_config()
 server_config = config["server"]
@@ -37,6 +38,8 @@ def load_config():
 
 MODEL = "deepseek/deepseek-r1:free"
 # MODEL = "google/gemini-2.0-pro-exp-02-05:free"
+
+
 async def test_openrouter_streaming():
     """
     Test the OpenRouter proxy with streaming mode.
```

utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -28,6 +28,7 @@ def get_local_ip() -> str:
     except Exception:
         return "localhost"
 
+
 async def verify_access_key(
     authorization: Optional[str] = Header(None),
 ) -> bool:
@@ -102,6 +103,7 @@ async def is_google_error(data: str) -> bool:
         return True
     return False
 
+
 async def check_rate_limit_chat(err: APIError) -> Tuple[bool, Optional[int]]:
     """
     Check for rate limit error.
```
