Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 117 additions & 2 deletions cli/src/transformerlab_cli/commands/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import os
import time
import zipfile
from pathlib import Path

Expand Down Expand Up @@ -192,6 +193,95 @@
return f"{size_bytes:.1f} TB"


# Statuses that make the follow loop stop polling and report the job as
# finished. Only "COMPLETE" is rendered as success; the others are shown in red.
TERMINAL_STATUSES = {
    "COMPLETE",
    "STOPPED",
    "FAILED",
    "CANCELLED",
    "DELETED",
    "UNAUTHORIZED",
}
# Pause between two consecutive polls while following a job, in seconds.
POLL_INTERVAL = 3


def _poll_status(job_id: str) -> str | None:
    """Ask the compute provider endpoint for the job's current status.

    Args:
        job_id: Identifier of the job to check.

    Returns:
        The ``current_status`` field of the response payload, or ``None``
        when the request does not return HTTP 200 or anything goes wrong.
    """
    try:
        reply = api.get(f"/compute_provider/jobs/{job_id}/check-status", timeout=5.0)
        if reply.status_code != 200:
            return None
        payload = reply.json()
        return payload.get("current_status")
    except Exception:
        # Best-effort probe: a failed poll is reported as "no status" so the
        # caller can simply try again on its next iteration.
        return None


def _poll_and_print_logs(job_id: str, experiment_id: str, lines_seen: int) -> int:
    """Fetch provider logs and print any lines that were not printed yet.

    Args:
        job_id: Job whose provider logs are requested.
        experiment_id: Experiment the job belongs to.
        lines_seen: Number of log lines already printed by earlier calls.

    Returns:
        The updated line count: the length of the fetched log on success,
        otherwise ``lines_seen`` unchanged.

    The call is best-effort. A non-200 response, an unexpected payload shape
    or any exception leaves the counter untouched so the caller's polling
    loop keeps running.
    """
    try:
        response = api.get(
            f"/experiment/{experiment_id}/jobs/{job_id}/provider_logs?tail_lines=2000",
            timeout=10.0,
        )
        if response.status_code != 200:
            return lines_seen

        data = response.json()
        logs = data.get("logs", "") if isinstance(data, dict) else ""
        if isinstance(logs, str):
            all_lines = logs.splitlines()
        elif isinstance(logs, list):
            all_lines = logs
        else:
            return lines_seen

        # NOTE(review): the request asks for tail_lines=2000. If the server
        # really returns only the last 2000 lines, then once the log grows
        # past that size this slice stays empty and new output is no longer
        # shown — confirm the endpoint's behaviour.
        for line in all_lines[lines_seen:]:
            # Log lines are arbitrary job output, not console markup. With
            # markup left on, bracketed text such as "[info]" is consumed as
            # a style tag and an unmatched "[/x]" raises, which used to abort
            # the print loop and leave the counter stuck.
            console.print(line, markup=False, highlight=False)

        return len(all_lines)
    except Exception:
        # Deliberately swallowed: log streaming must never break job following.
        return lines_seen


def _follow_job(job_id: str, experiment_id: str) -> None:
    """Stream a job's status changes and new log lines until it finishes.

    Each iteration polls the status, prints it when it changed, prints any
    new provider log lines, and then either stops (status is in
    ``TERMINAL_STATUSES``) or sleeps ``POLL_INTERVAL`` seconds.

    Args:
        job_id: Job to follow.
        experiment_id: Experiment the job belongs to (used for log fetching).

    Ctrl+C only stops following: the interrupt is caught and hints for
    checking on the job later are printed instead.
    """
    console.print(f"\n[bold]Following job {job_id}...[/bold] (Ctrl+C to disconnect)\n")

    printed_lines = 0
    previous_status = None

    try:
        while True:
            status = _poll_status(job_id)

            # Announce a status only when it is known and differs from the
            # one printed last.
            if status and status != previous_status:
                if status == "COMPLETE":
                    style = "bold green"
                elif status in TERMINAL_STATUSES:
                    style = "bold red"
                else:
                    style = "bold cyan"
                console.print(f"[{style}]Status: {status}[/{style}]")
                previous_status = status

            # Logs are fetched before the terminal check so the last lines of
            # a finished job are still printed.
            printed_lines = _poll_and_print_logs(job_id, experiment_id, printed_lines)

            if status in TERMINAL_STATUSES:
                if status == "COMPLETE":
                    color = "green"
                else:
                    color = "red"
                console.print(f"\n[bold {color}]Job {job_id} finished: {status}[/bold {color}]")
                console.print(f"[dim]View details: [bold]lab job info {job_id}[/bold][/dim]")
                return

            time.sleep(POLL_INTERVAL)

    except KeyboardInterrupt:
        console.print(f"\n[yellow]Disconnected.[/yellow] Job {job_id} continues running.")
        console.print(f"[dim]Check status: [bold]lab job info {job_id}[/bold][/dim]")
        console.print("[dim]View logs: [bold]lab job monitor[/bold][/dim]")


def run_task(task_id: str, experiment_id: str, interactive: bool = True, disconnect: bool = False) -> None:
    """Queue a task, then either follow it or return right away.

    Args:
        task_id: Task to queue.
        experiment_id: Experiment the task belongs to.
        interactive: Forwarded to ``queue_task``.
        disconnect: When true, print follow-up hints and return without
            following the job; otherwise follow it via ``_follow_job``.
    """
    job_id = queue_task(task_id, experiment_id=experiment_id, interactive=interactive)

    if not disconnect:
        _follow_job(job_id, experiment_id)
        return

    # Detached mode: tell the user how to check on the job later.
    console.print("\n[dim]Task is running in the background.[/dim]")
    console.print(f"[dim]Check status: [bold]lab job info {job_id}[/bold][/dim]")
    console.print("[dim]View logs: [bold]lab job monitor[/bold][/dim]")


## COMMANDS ##


Expand Down Expand Up @@ -415,8 +505,8 @@
return values


def queue_task(task_id: str, experiment_id: str, interactive: bool = True) -> None:
"""Queue a task on a compute provider."""
def queue_task(task_id: str, experiment_id: str, interactive: bool = True) -> str:
"""Queue a task on a compute provider. Returns the job ID on success."""
with console.status("[bold green]Fetching task...[/bold green]", spinner="dots"):
response = api.get(f"/experiment/{experiment_id}/task/{task_id}/get")

Expand Down Expand Up @@ -463,10 +553,14 @@
data = launch_task_on_provider(provider_id, payload)
job_id = data.get("job_id", "unknown")
console.print(f"[green]✓[/green] Task queued successfully. Job ID: [bold]{job_id}[/bold]")
return str(job_id)
except RuntimeError as e:
console.print(f"[red]Error:[/red] {e}")
raise typer.Exit(1)

# Unreachable, but keeps type checkers happy
return ""


@app.command("queue")
def command_task_queue(
Expand All @@ -478,6 +572,27 @@
queue_task(task_id, experiment_id=current_experiment, interactive=not no_interactive)


@app.command("run")
def command_task_run(
    task_id: str = typer.Argument(..., help="Task ID to run"),
    disconnect: bool = typer.Option(
        False, "--disconnect", "-d", help="Queue the task and return immediately without following logs"
    ),
    no_interactive: bool = typer.Option(False, "--no-interactive", help="Skip interactive prompts, use defaults"),
):
    """Run a task: queue it and follow logs until completion."""
    # NOTE(review): CI (ruff F821) reports `check_configs` and `get_config` as
    # undefined names in this module. They must be imported at the top of the
    # file before this command can run — TODO confirm which module provides them.
    check_configs()
    current_experiment = get_config("current_experiment")

    # A missing or whitespace-only experiment name cannot be used to queue a
    # task, so explain how to set it and exit with a failure code.
    if not current_experiment or not str(current_experiment).strip():
        console.print("[yellow]current_experiment is not set in config.[/yellow]")
        console.print("Set it first with: [bold]lab config current_experiment <experiment_name>[/bold]")
        raise typer.Exit(1)

    run_task(
        task_id,
        experiment_id=current_experiment,
        interactive=not no_interactive,
        disconnect=disconnect,
    )
@app.command("interactive")
def command_task_interactive(
timeout: int = typer.Option(300, "--timeout", "-t", help="Timeout in seconds waiting for service readiness"),
Expand Down
Loading