Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 18 additions & 57 deletions api/transformerlab/routers/compute_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from transformerlab.services import quota_service
from transformerlab.services.task_service import task_service
from transformerlab.services.local_provider_queue import enqueue_local_launch
from transformerlab.services.remote_provider_queue import enqueue_remote_launch

from lab import storage
from lab.storage import STORAGE_PROVIDER
Expand Down Expand Up @@ -1618,8 +1619,7 @@ async def launch_template_on_provider(
if not has_quota:
raise HTTPException(status_code=403, detail=message)

# Get provider instance (resolves user's slurm_user for SLURM when user_id/team_id set)
provider_instance = await get_provider_instance(provider, user_id=user_id, team_id=team_id)
# NOTE: We no longer launch inline; provider instance is resolved in the remote launch worker.

# Interactive templates should start directly in INTERACTIVE state instead of LAUNCHING,
# except for LOCAL providers where we introduce a WAITING status while queued.
Expand Down Expand Up @@ -1656,6 +1656,8 @@ async def launch_template_on_provider(
minutes_requested=request.minutes_requested,
job_id=str(job_id),
)
# We return immediately after enqueuing remote launches, so persist the hold now.
await session.commit()

await job_service.job_update_launch_progress(
job_id,
Expand Down Expand Up @@ -2126,66 +2128,24 @@ async def launch_template_on_provider(
"message": "Local provider launch waiting in queue",
}

try:
launch_result = await asyncio.to_thread(
provider_instance.launch_cluster, formatted_cluster_name, cluster_config
)
except Exception as exc:
print(f"Failed to launch cluster: {exc}")
await job_service.job_update_launch_progress(
job_id,
request.experiment_id,
phase="failed",
percent=100,
message=f"Launch failed: {exc!s}",
)
# Release quota hold if launch failed
if quota_hold:
await quota_service.release_quota_hold(session, hold_id=quota_hold.id)
await session.commit()
await job_service.job_update_status(
job_id,
JobStatus.FAILED,
request.experiment_id,
error_msg=str(exc),
)
raise HTTPException(status_code=500, detail="Failed to launch cluster") from exc

await job_service.job_update_launch_progress(
job_id,
request.experiment_id,
phase="cluster_started",
percent=100,
message="Launch initiated",
await enqueue_remote_launch(
job_id=str(job_id),
experiment_id=str(request.experiment_id),
provider_id=str(provider.id),
team_id=str(team_id),
user_id=str(user.id),
cluster_name=formatted_cluster_name,
cluster_config=cluster_config,
quota_hold_id=str(quota_hold.id) if quota_hold else None,
subtype=request.subtype,
)

# Commit quota hold creation after successful launch
if quota_hold:
await session.commit()

request_id = None
if isinstance(launch_result, dict):
await job_service.job_update_job_data_insert_key_value(
job_id,
"provider_launch_result",
launch_result,
request.experiment_id,
)
request_id = launch_result.get("request_id")
if request_id:
await job_service.job_update_job_data_insert_key_value(
job_id,
"orchestrator_request_id",
request_id,
request.experiment_id,
)

return {
"status": "success",
"job_id": job_id,
"cluster_name": formatted_cluster_name,
"request_id": request_id,
"message": "Provider launch initiated",
"request_id": None,
"message": "Provider launch enqueued",
}


Expand Down Expand Up @@ -2741,7 +2701,8 @@ async def stop_cluster(

# Return the result directly from the provider
return result
except Exception:
except Exception as e:
print(f"Failed to stop cluster: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to stop cluster")


Expand Down
189 changes: 189 additions & 0 deletions api/transformerlab/services/remote_provider_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import asyncio
import logging
import os
from typing import Optional

from pydantic import BaseModel

from transformerlab.compute_providers.models import ClusterConfig
from transformerlab.db.session import async_session
from transformerlab.services import job_service, quota_service
from transformerlab.services.provider_service import get_provider_by_id, get_provider_instance
from lab import dirs as lab_dirs
from lab.job_status import JobStatus

logger = logging.getLogger(__name__)


class RemoteLaunchWorkItem(BaseModel):
    """Work item for launching a non-local provider job in the background.

    Instances are produced by ``enqueue_remote_launch`` (which stringifies all
    id fields) and consumed by the dispatcher's ``_process_launch_item``.
    """

    job_id: str  # job whose status/progress is updated as the launch proceeds
    experiment_id: str  # experiment the job belongs to (passed to job_service calls)
    provider_id: str  # compute provider to look up and instantiate at launch time
    team_id: str  # organization context; applied via lab_dirs.set_organization_id
    user_id: str  # launching user; forwarded to get_provider_instance
    cluster_name: str  # pre-formatted cluster name handed to launch_cluster
    cluster_config: ClusterConfig  # cluster spec handed to launch_cluster
    quota_hold_id: Optional[str] = None  # hold released if the launch fails
    subtype: Optional[str] = None  # e.g. "interactive"


def _read_max_concurrent_launches(default: int = 8) -> int:
    """Read the remote-launch parallelism cap from the environment.

    Returns *default* when ``TFL_MAX_CONCURRENT_REMOTE_LAUNCHES`` is unset or
    cannot be parsed as an integer.
    """
    raw = os.getenv("TFL_MAX_CONCURRENT_REMOTE_LAUNCHES", str(default))
    try:
        return int(raw)
    except Exception:  # noqa: BLE001
        return default


# Module-level dispatcher state shared by every enqueue call in this process.
_remote_launch_queue: "asyncio.Queue[RemoteLaunchWorkItem]" = asyncio.Queue()
_dispatcher_task: Optional[asyncio.Task] = None
_dispatcher_lock = asyncio.Lock()

# Concurrency: remote launches should start immediately, but we still cap total parallelism
_MAX_CONCURRENT_REMOTE_LAUNCHES = _read_max_concurrent_launches()

_remote_launch_semaphore = asyncio.Semaphore(_MAX_CONCURRENT_REMOTE_LAUNCHES)


async def enqueue_remote_launch(
    job_id: str,
    experiment_id: str,
    provider_id: str,
    team_id: str,
    user_id: str,
    cluster_name: str,
    cluster_config: ClusterConfig,
    quota_hold_id: Optional[str],
    subtype: Optional[str],
) -> None:
    """Enqueue a remote provider launch work item and ensure the dispatcher is running."""
    work_item = RemoteLaunchWorkItem(
        job_id=str(job_id),
        experiment_id=str(experiment_id),
        provider_id=str(provider_id),
        team_id=str(team_id),
        user_id=str(user_id),
        cluster_name=cluster_name,
        cluster_config=cluster_config,
        quota_hold_id=quota_hold_id,
        subtype=subtype,
    )
    await _remote_launch_queue.put(work_item)
    await _ensure_dispatcher()


async def _ensure_dispatcher() -> None:
    """Start the dispatcher loop task if none is running (or the last one died)."""
    global _dispatcher_task
    async with _dispatcher_lock:
        if _dispatcher_task is None or _dispatcher_task.done():
            _dispatcher_task = asyncio.create_task(_dispatcher_loop())


async def _dispatcher_loop() -> None:
    """Continuously dispatch queued launches into concurrent worker tasks.

    Runs indefinitely once started by ``enqueue_remote_launch``. Each work
    item gets its own task so one slow launch never blocks the queue; total
    parallelism is capped inside ``_process_launch_item`` by the module
    semaphore.
    """
    # The event loop keeps only weak references to tasks, so we must hold a
    # strong reference until each task completes — otherwise an in-flight
    # launch task may be garbage-collected mid-execution (see the warning in
    # the asyncio.create_task documentation).
    in_flight: set[asyncio.Task] = set()
    while True:
        item = await _remote_launch_queue.get()
        task = asyncio.create_task(_process_launch_item(item))
        in_flight.add(task)
        task.add_done_callback(in_flight.discard)
        task.add_done_callback(_log_task_exception)


def _log_task_exception(task: asyncio.Task) -> None:
    """Done-callback: log the exception, if any, raised by a launch task."""
    try:
        err = task.exception()
    except asyncio.CancelledError:
        # Cancelled tasks are expected shutdown behavior, not failures.
        return
    except Exception:  # noqa: BLE001
        logger.exception("Remote launch task failed while retrieving exception")
        return

    if err is None:
        return
    logger.exception("Remote launch task crashed", exc_info=err)


async def _process_launch_item(item: RemoteLaunchWorkItem) -> None:
    """Process a single remote launch work item.

    Runs under the module semaphore so at most
    ``_MAX_CONCURRENT_REMOTE_LAUNCHES`` launches execute concurrently. On a
    failed launch the job is marked FAILED and the quota hold (if any) is
    released; on success the job is left in LAUNCHING/INTERACTIVE for status
    polling to advance later.
    """
    async with _remote_launch_semaphore:
        async with async_session() as session:
            # Scope storage paths to the requesting team for this work item.
            # NOTE(review): if set_organization_id is a plain module global
            # rather than a contextvar, concurrent launch tasks on the same
            # loop could clobber each other's org context — confirm.
            lab_dirs.set_organization_id(item.team_id)
            try:
                # NOTE(review): the awaits before the inner try (progress
                # update, provider lookup, get_provider_instance) are not
                # guarded; if one raises, the quota hold is not released and
                # the job is not marked FAILED (only logged by the task's
                # done-callback). Consider an outer failure handler.
                await job_service.job_update_launch_progress(
                    item.job_id,
                    item.experiment_id,
                    phase="launching_cluster",
                    percent=70,
                    message="Launching cluster",
                )

                provider = await get_provider_by_id(session, item.provider_id)
                if not provider:
                    # Provider vanished between enqueue and dispatch: fail the
                    # job and release the hold so quota is not leaked.
                    await job_service.job_update_status(
                        item.job_id,
                        JobStatus.FAILED,
                        experiment_id=item.experiment_id,
                        error_msg="Provider not found for remote launch",
                        session=session,
                    )
                    if item.quota_hold_id:
                        await quota_service.release_quota_hold(session, hold_id=item.quota_hold_id)
                    await session.commit()
                    return

                provider_instance = await get_provider_instance(provider, user_id=item.user_id, team_id=item.team_id)

                loop = asyncio.get_running_loop()

                def _launch_with_org_context():
                    # Re-apply the org context inside the executor thread
                    # before calling the (blocking) provider launch.
                    lab_dirs.set_organization_id(item.team_id)
                    return provider_instance.launch_cluster(item.cluster_name, item.cluster_config)

                try:
                    # launch_cluster is synchronous; run it off the event loop.
                    launch_result = await loop.run_in_executor(None, _launch_with_org_context)
                except Exception as exc:  # noqa: BLE001
                    # Launch failed: surface the error in progress + status,
                    # release the quota hold, and persist everything at once.
                    await job_service.job_update_launch_progress(
                        item.job_id,
                        item.experiment_id,
                        phase="failed",
                        percent=100,
                        message=f"Launch failed: {exc!s}",
                    )
                    if item.quota_hold_id:
                        await quota_service.release_quota_hold(session, hold_id=item.quota_hold_id)
                    await job_service.job_update_status(
                        item.job_id,
                        JobStatus.FAILED,
                        experiment_id=item.experiment_id,
                        error_msg=str(exc),
                        session=session,
                    )
                    await session.commit()
                    return

                await job_service.job_update_launch_progress(
                    item.job_id,
                    item.experiment_id,
                    phase="cluster_started",
                    percent=100,
                    message="Launch initiated",
                )

                # Providers may return a dict of launch metadata; persist it
                # (and the orchestrator request id, when present) on the job.
                if isinstance(launch_result, dict):
                    await job_service.job_update_job_data_insert_key_value(
                        item.job_id,
                        "provider_launch_result",
                        launch_result,
                        item.experiment_id,
                    )
                    request_id = launch_result.get("request_id")
                    if request_id:
                        await job_service.job_update_job_data_insert_key_value(
                            item.job_id,
                            "orchestrator_request_id",
                            request_id,
                            item.experiment_id,
                        )

                # Keep the job in LAUNCHING/INTERACTIVE; status polling will advance it later.
                next_status = JobStatus.INTERACTIVE if item.subtype == "interactive" else JobStatus.LAUNCHING
                await job_service.job_update_status(
                    item.job_id,
                    next_status,
                    experiment_id=item.experiment_id,
                    session=session,
                )
                await session.commit()
            finally:
                # Always clear the org context so this worker's loop/thread
                # does not leak it into unrelated work.
                lab_dirs.set_organization_id(None)
Loading