Skip to content

Commit 7ecc44f

Browse files
committed
优化基于负载的令牌并发调度 (Optimize load-based token concurrency scheduling)
1 parent cb50ea0 commit 7ecc44f

File tree

5 files changed: +266 -74 lines changed

src/api/admin.py

Lines changed: 23 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,13 +13,15 @@
1313
from ..core.config import config
1414
from ..services.token_manager import TokenManager
1515
from ..services.proxy_manager import ProxyManager
16+
from ..services.concurrency_manager import ConcurrencyManager
1617

1718
router = APIRouter()
1819

1920
# Dependency injection
2021
token_manager: TokenManager = None
2122
proxy_manager: ProxyManager = None
2223
db: Database = None
24+
concurrency_manager: Optional[ConcurrencyManager] = None
2325

2426
# Store active admin session tokens (in production, use Redis or database)
2527
active_admin_tokens = set()
@@ -206,12 +208,13 @@ async def _solve_recaptcha_with_api_service(
206208
raise RuntimeError(f"{method} 获取 token 超时")
207209

208210

209-
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database):
211+
def set_dependencies(tm: TokenManager, pm: ProxyManager, database: Database, cm: Optional[ConcurrencyManager] = None):
210212
"""Set service instances"""
211-
global token_manager, proxy_manager, db
213+
global token_manager, proxy_manager, db, concurrency_manager
212214
token_manager = tm
213215
proxy_manager = pm
214216
db = database
217+
concurrency_manager = cm
215218

216219

217220
# ========== Request Models ==========
@@ -441,6 +444,14 @@ async def add_token(
441444
video_concurrency=request.video_concurrency
442445
)
443446

447+
# 热更新并发限制,避免必须重启服务
448+
if concurrency_manager:
449+
await concurrency_manager.reset_token(
450+
new_token.id,
451+
image_concurrency=new_token.image_concurrency,
452+
video_concurrency=new_token.video_concurrency
453+
)
454+
444455
return {
445456
"success": True,
446457
"message": "Token添加成功",
@@ -495,6 +506,16 @@ async def update_token(
495506
video_concurrency=request.video_concurrency
496507
)
497508

509+
# 热更新并发限制,确保管理台修改立即生效
510+
if concurrency_manager:
511+
updated_token = await token_manager.get_token(token_id)
512+
if updated_token:
513+
await concurrency_manager.reset_token(
514+
token_id,
515+
image_concurrency=updated_token.image_concurrency,
516+
video_concurrency=updated_token.video_concurrency
517+
)
518+
498519
return {"success": True, "message": "Token更新成功"}
499520
except Exception as e:
500521
raise HTTPException(status_code=500, detail=str(e))

src/main.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -174,7 +174,7 @@ async def auto_unban_task():
174174

175175
# Set dependencies
176176
routes.set_generation_handler(generation_handler)
177-
admin.set_dependencies(token_manager, proxy_manager, db)
177+
admin.set_dependencies(token_manager, proxy_manager, db, concurrency_manager)
178178

179179
# Create FastAPI app
180180
app = FastAPI(

src/services/concurrency_manager.py

Lines changed: 108 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ class ConcurrencyManager:
99

1010
def __init__(self):
1111
"""Initialize concurrency manager"""
12-
self._image_concurrency: Dict[int, int] = {} # token_id -> remaining image concurrency
13-
self._video_concurrency: Dict[int, int] = {} # token_id -> remaining video concurrency
12+
# token_id -> max concurrency limit (only stores >0 values, missing means unlimited)
13+
self._image_limits: Dict[int, int] = {}
14+
self._video_limits: Dict[int, int] = {}
15+
# token_id -> current in-flight requests
16+
self._image_inflight: Dict[int, int] = {}
17+
self._video_inflight: Dict[int, int] = {}
1418
self._lock = asyncio.Lock() # Protect concurrent access
1519

1620
async def initialize(self, tokens: list):
@@ -21,11 +25,21 @@ async def initialize(self, tokens: list):
2125
tokens: List of Token objects with image_concurrency and video_concurrency fields
2226
"""
2327
async with self._lock:
28+
self._image_limits.clear()
29+
self._video_limits.clear()
30+
31+
# 初始化时重置 in-flight,避免重启后带入脏状态
32+
self._image_inflight.clear()
33+
self._video_inflight.clear()
34+
2435
for token in tokens:
36+
self._image_inflight[token.id] = 0
37+
self._video_inflight[token.id] = 0
38+
2539
if token.image_concurrency and token.image_concurrency > 0:
26-
self._image_concurrency[token.id] = token.image_concurrency
40+
self._image_limits[token.id] = token.image_concurrency
2741
if token.video_concurrency and token.video_concurrency > 0:
28-
self._video_concurrency[token.id] = token.video_concurrency
42+
self._video_limits[token.id] = token.video_concurrency
2943

3044
debug_logger.log_info(f"Concurrency manager initialized with {len(tokens)} tokens")
3145

@@ -40,13 +54,16 @@ async def can_use_image(self, token_id: int) -> bool:
4054
True if token has available image concurrency, False if concurrency is 0
4155
"""
4256
async with self._lock:
43-
# If not in dict, it means no limit (-1)
44-
if token_id not in self._image_concurrency:
57+
limit = self._image_limits.get(token_id)
58+
# Missing limit means unlimited (-1)
59+
if limit is None:
4560
return True
4661

47-
remaining = self._image_concurrency[token_id]
48-
if remaining <= 0:
49-
debug_logger.log_info(f"Token {token_id} image concurrency exhausted (remaining: {remaining})")
62+
inflight = self._image_inflight.get(token_id, 0)
63+
if inflight >= limit:
64+
debug_logger.log_info(
65+
f"Token {token_id} image concurrency exhausted (inflight: {inflight}/{limit})"
66+
)
5067
return False
5168

5269
return True
@@ -62,13 +79,16 @@ async def can_use_video(self, token_id: int) -> bool:
6279
True if token has available video concurrency, False if concurrency is 0
6380
"""
6481
async with self._lock:
65-
# If not in dict, it means no limit (-1)
66-
if token_id not in self._video_concurrency:
82+
limit = self._video_limits.get(token_id)
83+
# Missing limit means unlimited (-1)
84+
if limit is None:
6785
return True
6886

69-
remaining = self._video_concurrency[token_id]
70-
if remaining <= 0:
71-
debug_logger.log_info(f"Token {token_id} video concurrency exhausted (remaining: {remaining})")
87+
inflight = self._video_inflight.get(token_id, 0)
88+
if inflight >= limit:
89+
debug_logger.log_info(
90+
f"Token {token_id} video concurrency exhausted (inflight: {inflight}/{limit})"
91+
)
7292
return False
7393

7494
return True
@@ -84,15 +104,18 @@ async def acquire_image(self, token_id: int) -> bool:
84104
True if acquired, False if not available
85105
"""
86106
async with self._lock:
87-
if token_id not in self._image_concurrency:
88-
# No limit
89-
return True
107+
limit = self._image_limits.get(token_id)
108+
inflight = self._image_inflight.get(token_id, 0)
90109

91-
if self._image_concurrency[token_id] <= 0:
110+
if limit is not None and inflight >= limit:
92111
return False
93112

94-
self._image_concurrency[token_id] -= 1
95-
debug_logger.log_info(f"Token {token_id} acquired image slot (remaining: {self._image_concurrency[token_id]})")
113+
new_inflight = inflight + 1
114+
self._image_inflight[token_id] = new_inflight
115+
if limit is None:
116+
debug_logger.log_info(f"Token {token_id} acquired image slot (inflight: {new_inflight}, limit: unlimited)")
117+
else:
118+
debug_logger.log_info(f"Token {token_id} acquired image slot (inflight: {new_inflight}/{limit})")
96119
return True
97120

98121
async def acquire_video(self, token_id: int) -> bool:
@@ -106,15 +129,18 @@ async def acquire_video(self, token_id: int) -> bool:
106129
True if acquired, False if not available
107130
"""
108131
async with self._lock:
109-
if token_id not in self._video_concurrency:
110-
# No limit
111-
return True
132+
limit = self._video_limits.get(token_id)
133+
inflight = self._video_inflight.get(token_id, 0)
112134

113-
if self._video_concurrency[token_id] <= 0:
135+
if limit is not None and inflight >= limit:
114136
return False
115137

116-
self._video_concurrency[token_id] -= 1
117-
debug_logger.log_info(f"Token {token_id} acquired video slot (remaining: {self._video_concurrency[token_id]})")
138+
new_inflight = inflight + 1
139+
self._video_inflight[token_id] = new_inflight
140+
if limit is None:
141+
debug_logger.log_info(f"Token {token_id} acquired video slot (inflight: {new_inflight}, limit: unlimited)")
142+
else:
143+
debug_logger.log_info(f"Token {token_id} acquired video slot (inflight: {new_inflight}/{limit})")
118144
return True
119145

120146
async def release_image(self, token_id: int):
@@ -125,9 +151,19 @@ async def release_image(self, token_id: int):
125151
token_id: Token ID
126152
"""
127153
async with self._lock:
128-
if token_id in self._image_concurrency:
129-
self._image_concurrency[token_id] += 1
130-
debug_logger.log_info(f"Token {token_id} released image slot (remaining: {self._image_concurrency[token_id]})")
154+
inflight = self._image_inflight.get(token_id, 0)
155+
if inflight <= 0:
156+
self._image_inflight[token_id] = 0
157+
debug_logger.log_warning(f"Token {token_id} release_image called with inflight=0")
158+
return
159+
160+
new_inflight = inflight - 1
161+
self._image_inflight[token_id] = new_inflight
162+
limit = self._image_limits.get(token_id)
163+
if limit is None:
164+
debug_logger.log_info(f"Token {token_id} released image slot (inflight: {new_inflight}, limit: unlimited)")
165+
else:
166+
debug_logger.log_info(f"Token {token_id} released image slot (inflight: {new_inflight}/{limit})")
131167

132168
async def release_video(self, token_id: int):
133169
"""
@@ -137,9 +173,19 @@ async def release_video(self, token_id: int):
137173
token_id: Token ID
138174
"""
139175
async with self._lock:
140-
if token_id in self._video_concurrency:
141-
self._video_concurrency[token_id] += 1
142-
debug_logger.log_info(f"Token {token_id} released video slot (remaining: {self._video_concurrency[token_id]})")
176+
inflight = self._video_inflight.get(token_id, 0)
177+
if inflight <= 0:
178+
self._video_inflight[token_id] = 0
179+
debug_logger.log_warning(f"Token {token_id} release_video called with inflight=0")
180+
return
181+
182+
new_inflight = inflight - 1
183+
self._video_inflight[token_id] = new_inflight
184+
limit = self._video_limits.get(token_id)
185+
if limit is None:
186+
debug_logger.log_info(f"Token {token_id} released video slot (inflight: {new_inflight}, limit: unlimited)")
187+
else:
188+
debug_logger.log_info(f"Token {token_id} released video slot (inflight: {new_inflight}/{limit})")
143189

144190
async def get_image_remaining(self, token_id: int) -> Optional[int]:
145191
"""
@@ -152,7 +198,11 @@ async def get_image_remaining(self, token_id: int) -> Optional[int]:
152198
Remaining count or None if no limit
153199
"""
154200
async with self._lock:
155-
return self._image_concurrency.get(token_id)
201+
limit = self._image_limits.get(token_id)
202+
if limit is None:
203+
return None
204+
inflight = self._image_inflight.get(token_id, 0)
205+
return max(0, limit - inflight)
156206

157207
async def get_video_remaining(self, token_id: int) -> Optional[int]:
158208
"""
@@ -165,7 +215,21 @@ async def get_video_remaining(self, token_id: int) -> Optional[int]:
165215
Remaining count or None if no limit
166216
"""
167217
async with self._lock:
168-
return self._video_concurrency.get(token_id)
218+
limit = self._video_limits.get(token_id)
219+
if limit is None:
220+
return None
221+
inflight = self._video_inflight.get(token_id, 0)
222+
return max(0, limit - inflight)
223+
224+
async def get_image_inflight(self, token_id: int) -> int:
225+
"""Get current in-flight image request count for token"""
226+
async with self._lock:
227+
return self._image_inflight.get(token_id, 0)
228+
229+
async def get_video_inflight(self, token_id: int) -> int:
230+
"""Get current in-flight video request count for token"""
231+
async with self._lock:
232+
return self._video_inflight.get(token_id, 0)
169233

170234
async def reset_token(self, token_id: int, image_concurrency: int = -1, video_concurrency: int = -1):
171235
"""
@@ -178,13 +242,17 @@ async def reset_token(self, token_id: int, image_concurrency: int = -1, video_co
178242
"""
179243
async with self._lock:
180244
if image_concurrency > 0:
181-
self._image_concurrency[token_id] = image_concurrency
182-
elif token_id in self._image_concurrency:
183-
del self._image_concurrency[token_id]
245+
self._image_limits[token_id] = image_concurrency
246+
elif token_id in self._image_limits:
247+
del self._image_limits[token_id]
184248

185249
if video_concurrency > 0:
186-
self._video_concurrency[token_id] = video_concurrency
187-
elif token_id in self._video_concurrency:
188-
del self._video_concurrency[token_id]
250+
self._video_limits[token_id] = video_concurrency
251+
elif token_id in self._video_limits:
252+
del self._video_limits[token_id]
253+
254+
# 重置时确保存在 in-flight 计数字段
255+
self._image_inflight.setdefault(token_id, 0)
256+
self._video_inflight.setdefault(token_id, 0)
189257

190258
debug_logger.log_info(f"Token {token_id} concurrency reset (image: {image_concurrency}, video: {video_concurrency})")

0 commit comments

Comments
 (0)