Skip to content

Commit 493d1cb

Browse files
lilyz-ai and claude committed
fix: harden ModelWeightsManager background task reliability
- Hold a strong set reference to each asyncio.Task to prevent GC cancellation
- Deduplicate concurrent sync requests for the same hf_repo via _in_progress dict
- Surface task exceptions via logger.error in _on_task_done callback
- Store ModelWeightsManager as app.state singleton so state persists across requests
- Add recover_hf_syncs startup handler to re-trigger syncs after server restart

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
1 parent b14deb6 commit 493d1cb

File tree

4 files changed

+133
-9
lines changed

4 files changed

+133
-9
lines changed

model-engine/model_engine_server/api/app.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,56 @@ def load_redis():
312312
get_or_create_aioredis_pool()
313313

314314

315+
@app.on_event("startup")
316+
def init_model_weights_manager():
317+
from model_engine_server.core.config import infra_config
318+
from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager
319+
from model_engine_server.infra.gateways import (
320+
ABSLLMArtifactGateway,
321+
GCSLLMArtifactGateway,
322+
S3LLMArtifactGateway,
323+
)
324+
325+
provider = infra_config().cloud_provider
326+
if provider == "azure":
327+
gateway = ABSLLMArtifactGateway()
328+
elif provider == "gcp":
329+
gateway = GCSLLMArtifactGateway()
330+
else:
331+
gateway = S3LLMArtifactGateway()
332+
app.state.model_weights_manager = ModelWeightsManager(llm_artifact_gateway=gateway)
333+
334+
335+
@app.on_event("startup")
336+
async def recover_hf_syncs():
337+
"""Re-trigger weight syncs for endpoints that were syncing when server last stopped."""
338+
from model_engine_server.db.base import get_session_async
339+
from model_engine_server.infra.repositories.live_tokenizer_repository import (
340+
SUPPORTED_MODELS_INFO,
341+
)
342+
from sqlalchemy import text
343+
344+
session_factory = get_session_async()
345+
try:
346+
async with session_factory() as session:
347+
result = await session.execute(
348+
text(
349+
"SELECT DISTINCT endpoint_metadata->'_llm'->>'model_name' AS model_name "
350+
"FROM endpoints "
351+
"WHERE (endpoint_metadata->'_llm'->>'hf_weights_syncing')::boolean = true"
352+
)
353+
)
354+
model_names = [row.model_name for row in result if row.model_name]
355+
except Exception:
356+
logger.warning("Could not query pending HF sync endpoints at startup", exc_info=True)
357+
return
358+
for model_name in model_names:
359+
info = SUPPORTED_MODELS_INFO.get(model_name)
360+
if info and info.hf_repo:
361+
app.state.model_weights_manager.ensure_model_weights_available(info.hf_repo)
362+
logger.info(f"Startup: re-triggered HF weight sync for {model_name}")
363+
364+
315365
def healthcheck() -> Response:
316366
"""Returns 200 if the app is healthy."""
317367
return Response(status_code=200)

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@
8686
UpdateLLMModelEndpointV1UseCase,
8787
)
8888
from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
89-
from model_engine_server.domain.use_cases.model_weights_manager import ModelWeightsManager
9089
from pydantic import RootModel
9190
from sse_starlette.sse import EventSourceResponse
9291

@@ -149,14 +148,15 @@ def handle_streaming_exception(
149148
@llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
150149
async def create_model_endpoint(
151150
wrapped_request: RootModel[CreateLLMModelEndpointV1Request],
151+
request: Request,
152152
auth: User = Depends(verify_authentication),
153153
external_interfaces: ExternalInterfaces = Depends(get_external_interfaces),
154154
) -> CreateLLMModelEndpointV1Response:
155-
request = wrapped_request.root
155+
llm_request = wrapped_request.root
156156
"""
157157
Creates an LLM endpoint for the current user.
158158
"""
159-
logger.info(f"POST /llm/model-endpoints with {request} for {auth}")
159+
logger.info(f"POST /llm/model-endpoints with {llm_request} for {auth}")
160160
try:
161161
create_model_bundle_use_case = CreateModelBundleV2UseCase(
162162
model_bundle_repository=external_interfaces.model_bundle_repository,
@@ -169,17 +169,15 @@ async def create_model_endpoint(
169169
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
170170
docker_repository=external_interfaces.docker_repository,
171171
)
172-
model_weights_manager = ModelWeightsManager(
173-
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
174-
)
172+
model_weights_manager = request.app.state.model_weights_manager
175173
use_case = CreateLLMModelEndpointV1UseCase(
176174
create_llm_model_bundle_use_case=create_llm_model_bundle_use_case,
177175
model_endpoint_service=external_interfaces.model_endpoint_service,
178176
docker_repository=external_interfaces.docker_repository,
179177
llm_artifact_gateway=external_interfaces.llm_artifact_gateway,
180178
model_weights_manager=model_weights_manager,
181179
)
182-
return await use_case.execute(user=auth, request=request)
180+
return await use_case.execute(user=auth, request=llm_request)
183181
except ObjectAlreadyExistsException as exc:
184182
raise HTTPException(
185183
status_code=400,

model-engine/model_engine_server/domain/use_cases/model_weights_manager.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import functools
33
import tempfile
4-
from typing import List
4+
from typing import Dict, List, Set
55

66
from huggingface_hub import snapshot_download
77
from model_engine_server.common.config import hmi_config
@@ -24,6 +24,8 @@
2424
class ModelWeightsManager:
2525
def __init__(self, llm_artifact_gateway: LLMArtifactGateway):
    """Initialize the manager with the gateway used to list/upload weight files."""
    self.llm_artifact_gateway = llm_artifact_gateway
    # Maps hf_repo -> its running sync task, used to dedupe concurrent requests.
    self._in_progress: Dict[str, asyncio.Task] = dict()
    # Strong references keep pending tasks alive: the event loop itself only
    # holds weak references, so an unreferenced task can be garbage-collected.
    self._background_tasks: Set[asyncio.Task] = set()
2729

2830
def get_remote_path(self, hf_repo: str) -> str:
2931
prefix = hmi_config.hf_user_fine_tuned_weights_prefix.rstrip("/")
@@ -38,16 +40,35 @@ def ensure_model_weights_available(self, hf_repo: str) -> str:
3840
Callers receive the checkpoint path right away and can proceed with
3941
any following actions (e.g. endpoint creation) without blocking.
4042
43+
A second call for the same ``hf_repo`` while a sync is already in
44+
progress is a no-op: the existing task is reused and the same remote
45+
path is returned.
46+
4147
Args:
4248
hf_repo: HuggingFace repository ID, e.g. ``"meta-llama/Meta-Llama-3-8B"``.
4349
4450
Returns:
4551
The remote path (s3://, gs://, or https://) where the weights will be stored.
4652
"""
4753
remote_path = self.get_remote_path(hf_repo)
48-
asyncio.create_task(self._sync_weights(hf_repo, remote_path))
54+
if hf_repo not in self._in_progress:
55+
task = asyncio.create_task(self._sync_weights(hf_repo, remote_path))
56+
self._background_tasks.add(task)
57+
self._in_progress[hf_repo] = task
58+
task.add_done_callback(lambda t: self._on_task_done(t, hf_repo))
4959
return remote_path
5060

61+
def _on_task_done(self, task: asyncio.Task, hf_repo: str) -> None:
    """Done-callback for a weight-sync task: drop bookkeeping refs, surface failures.

    Runs whether the task succeeded, failed, or was cancelled; only genuine
    exceptions are logged.
    """
    # Release the strong reference and the dedup slot regardless of outcome.
    self._background_tasks.discard(task)
    self._in_progress.pop(hf_repo, None)
    if task.cancelled():
        # task.exception() would raise CancelledError here, so bail out first.
        return
    exc = task.exception()
    if not exc:
        return
    logger.error(
        f"Background weight sync failed for {hf_repo}: {exc}",
        exc_info=exc,
    )
71+
5172
async def _sync_weights(self, hf_repo: str, remote_path: str) -> None:
5273
"""Downloads weights from HuggingFace Hub and uploads to remote storage if not cached."""
5374
files = self.llm_artifact_gateway.list_files(remote_path)

model-engine/tests/unit/domain/test_model_weights_manager.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,61 @@ def test_s3_path_construction(monkeypatch):
116116
assert path == "s3://bucket/prefix/myorg/mymodel"
117117

118118

119+
def test_deduplication_same_hf_repo():
    """While a sync for a repo is in flight, a repeat call must reuse it, not spawn a task."""
    manager = ModelWeightsManager(llm_artifact_gateway=FakeArtifactGateway(existing_files=[]))

    mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
    with patch(f"{mwm_base}.asyncio.create_task") as mock_create_task:
        paths = [manager.ensure_model_weights_available("org/model") for _ in range(2)]

    # Only the first call schedules work; both callers get the same remote path.
    assert mock_create_task.call_count == 1
    assert paths[0] == paths[1]
131+
132+
133+
def test_task_reference_held_until_done():
    """The manager must hold a strong task reference until the done callback fires."""
    gateway = FakeArtifactGateway(existing_files=[])
    manager = ModelWeightsManager(llm_artifact_gateway=gateway)

    fake_task = MagicMock()
    target = "model_engine_server.domain.use_cases.model_weights_manager.asyncio.create_task"
    with patch(target, return_value=fake_task):
        manager.ensure_model_weights_available("org/model")

    # While "running": referenced in both the strong set and the dedup map.
    assert fake_task in manager._background_tasks
    assert "org/model" in manager._in_progress

    # Drive the done callback by hand, simulating clean completion.
    fake_task.cancelled.return_value = False
    fake_task.exception.return_value = None
    manager._on_task_done(fake_task, "org/model")

    # After completion: both bookkeeping structures are cleared.
    assert fake_task not in manager._background_tasks
    assert "org/model" not in manager._in_progress
153+
154+
155+
def test_error_surfaced_on_task_failure():
    """A failed background task must be reported through logger.error, not swallowed."""
    manager = ModelWeightsManager(llm_artifact_gateway=FakeArtifactGateway(existing_files=[]))

    failure = RuntimeError("Download failed")
    failed_task = MagicMock()
    failed_task.cancelled.return_value = False
    failed_task.exception.return_value = failure

    mwm_base = "model_engine_server.domain.use_cases.model_weights_manager"
    with patch(f"{mwm_base}.logger") as mock_logger:
        manager._on_task_done(failed_task, "org/model")

    mock_logger.error.assert_called_once()
    args, kwargs = mock_logger.error.call_args
    # The log line names the repo and carries the exception for the traceback.
    assert "org/model" in args[0]
    assert kwargs["exc_info"] == failure
172+
173+
119174
@pytest.mark.asyncio
120175
async def test_create_llm_model_endpoint_calls_weights_manager_on_hf_source():
121176
"""CreateLLMModelEndpointV1UseCase should call ensure_model_weights_available (sync),

0 commit comments

Comments (0)