Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
"soundfile==0.13.1",
"tensorboardX==2.6.2.2",
"timm==1.0.15",
"transformerlab==0.0.99",
"transformerlab==0.1.0",
"transformerlab-inference==0.2.52",
"transformers==4.57.1",
"wandb==0.23.1",
Expand Down
8 changes: 8 additions & 0 deletions api/transformerlab/routers/compute_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,9 @@ async def launch_template_on_provider(
if provider.type != ProviderType.LOCAL.value:
setup_commands.append("pip install -q transformerlab")

# Install torch as well if torch profiler is enabled
if request.enable_profiling_torch:
setup_commands.append("pip install -q torch")
# If GitHub repo fields are missing, fall back to the stored task's fields.
# This handles GitHub-sourced interactive tasks where the CLI/TUI doesn't
# send these fields and relies on the backend to resolve them from the task.
Expand Down Expand Up @@ -1778,6 +1781,11 @@ async def launch_template_on_provider(
if request.enable_trackio:
env_vars["TLAB_TRACKIO_AUTO_INIT"] = "true"

if request.enable_profiling:
env_vars["_TFL_PROFILING"] = "1"
if request.enable_profiling_torch:
env_vars["_TFL_PROFILING_TORCH"] = "1"

# Get TFL_STORAGE_URI from storage context
tfl_storage_uri = None
try:
Expand Down
29 changes: 29 additions & 0 deletions api/transformerlab/routers/experiment/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,3 +1698,32 @@ async def generate():
return StreamingResponse(
generate(), media_type=media_type, headers={"Content-Disposition": f'inline; filename="{filename}"'}
)


@router.get("/{job_id}/profiling_report")
async def get_profiling_report(
    job_id: str,
    experimentId: str,
    session: AsyncSession = Depends(get_async_session),
    user_and_team: dict = Depends(get_user_and_team),
):
    """
    Serve the parsed contents of profiling_report.json from the job's profiling folder.

    The report is produced when _TFL_PROFILING=1 and is copied into the job's
    profiling folder on lab.finish/error or when the remote trap exits.

    Raises:
        HTTPException(404): profiling was not enabled or the report is not yet available.
        HTTPException(500): the report exists but could not be read or parsed as JSON.
    """
    from lab.dirs import get_job_profiling_dir

    # Resolve <job profiling dir>/profiling_report.json through the storage layer.
    report_path = storage.join(await get_job_profiling_dir(job_id), "profiling_report.json")

    report_available = await storage.exists(report_path)
    if not report_available:
        raise HTTPException(status_code=404, detail="Profiling report not found for this job")

    try:
        async with await storage.open(report_path, "r", encoding="utf-8") as report_file:
            raw_report = await report_file.read()
        # Parsing stays inside the try so malformed JSON is reported as a 500 too.
        return json.loads(raw_report)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Failed to read profiling report: {exc}") from exc
8 changes: 8 additions & 0 deletions api/transformerlab/schemas/compute_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ class ProviderTemplateLaunchRequest(BaseModel):
default=False,
description="When True, set TLAB_TRACKIO_AUTO_INIT=true in the job environment so lab SDK can auto-integrate with Trackio.",
)
enable_profiling: Optional[bool] = Field(
default=False,
description="When True, set _TFL_PROFILING=1 to enable system-level CPU/GPU/memory sampling via tfl-remote-trap.",
)
enable_profiling_torch: Optional[bool] = Field(
default=False,
description="When True (requires enable_profiling), also set _TFL_PROFILING_TORCH=1 to inject torch.profiler and export a Chrome trace.",
)


class ProviderTemplateFileUploadResponse(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion lab-sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "transformerlab"
version = "0.0.99"
version = "0.1.0"
description = "Python SDK for Transformer Lab"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
11 changes: 11 additions & 0 deletions lab-sdk/src/lab/dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,17 @@ async def get_job_artifacts_dir(job_id: str | int) -> str:
return path


async def get_job_profiling_dir(job_id: str | int) -> str:
    """
    Build the path of a job's "profiling" subdirectory, make sure it exists, and return it.
    Example: ~/.transformerlab/workspace/jobs/<job_id>/profiling
    """
    profiling_path = storage.join(await get_job_dir(job_id), "profiling")
    # exist_ok=True keeps repeated calls for the same job from failing.
    await storage.makedirs(profiling_path, exist_ok=True)
    return profiling_path


async def get_job_checkpoints_dir(job_id: str | int) -> str:
"""
Return the checkpoints directory for a specific job, creating it if needed.
Expand Down
6 changes: 6 additions & 0 deletions lab-sdk/src/lab/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ async def get_artifacts_dir(self):
"""
return await dirs.get_job_artifacts_dir(self.id)

async def get_profiling_dir(self):
    """
    Return the profiling directory path for this job, as resolved by
    dirs.get_job_profiling_dir for this job's id.
    """
    profiling_dir = await dirs.get_job_profiling_dir(self.id)
    return profiling_dir

async def get_checkpoint_paths(self):
"""
Get list of checkpoint paths for this job.
Expand Down
18 changes: 18 additions & 0 deletions lab-sdk/src/lab/lab_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,15 @@ def finish(
Mark the job as successfully completed and set completion metadata.
"""
self._ensure_initialized()
# Copy profiling from temp dir into job's profiling folder (when run under remote trap).
try:
profiling_temp = os.environ.get("_TFL_PROFILING_TEMP_DIR")
if profiling_temp and self._job:
from lab.profiling import copy_profiling_to_job

_run_async(copy_profiling_to_job(profiling_temp, str(self._job.id))) # type: ignore[union-attr]
except Exception:
pass
_run_async(self._job.update_progress(100)) # type: ignore[union-attr]
_run_async(self._job.update_status(JobStatus.COMPLETE)) # type: ignore[union-attr]
_run_async(self._job.update_job_data_field("completion_status", "success")) # type: ignore[union-attr]
Expand Down Expand Up @@ -1435,6 +1444,15 @@ def error(
Mark the job as failed and set completion metadata.
"""
self._ensure_initialized()
# Copy profiling from temp dir into job's profiling folder (when run under remote trap).
try:
profiling_temp = os.environ.get("_TFL_PROFILING_TEMP_DIR")
if profiling_temp and self._job:
from lab.profiling import copy_profiling_to_job

_run_async(copy_profiling_to_job(profiling_temp, str(self._job.id))) # type: ignore[union-attr]
except Exception:
pass
_run_async(self._job.update_status(JobStatus.COMPLETE)) # type: ignore[union-attr]
_run_async(self._job.update_job_data_field("completion_status", "failed")) # type: ignore[union-attr]
_run_async(self._job.update_job_data_field("completion_details", message)) # type: ignore[union-attr]
Expand Down
Loading
Loading