diff --git a/api/api.py b/api/api.py index d2a74fbaa..0a68556a8 100644 --- a/api/api.py +++ b/api/api.py @@ -81,6 +81,7 @@ def _enable_datadog_if_setup(): api_keys, quota, ssh_keys, + asset_versions, trackio, ) from transformerlab.routers.auth import get_user_and_team # noqa: E402 @@ -339,6 +340,7 @@ async def validation_exception_handler(request, exc): app.include_router(api_keys.router) app.include_router(quota.router) app.include_router(ssh_keys.router, dependencies=[Depends(get_user_and_team)]) +app.include_router(asset_versions.router, dependencies=[Depends(get_user_and_team)]) app.include_router(trackio.router, dependencies=[Depends(get_user_and_team)]) diff --git a/api/test/api/test_job_save_to_registry.py b/api/test/api/test_job_save_to_registry.py index b3db2f623..196bf093c 100644 --- a/api/test/api/test_job_save_to_registry.py +++ b/api/test/api/test_job_save_to_registry.py @@ -96,26 +96,18 @@ def test_list_job_datasets_invalid_job_id(client, tmp_workspace): def test_save_dataset_to_registry_copies_files(client, tmp_workspace): - """Saving a dataset copies it from job dir to global datasets registry.""" + """Saving a dataset triggers the background copy to the global datasets registry.""" job_id = "42" dataset_name = "my-dataset" _seed_job_dataset(tmp_workspace, job_id, dataset_name, content='{"row":1}') - # Not in the registry yet - registry_path = tmp_workspace["datasets_dir"] / dataset_name - assert not registry_path.exists() - resp = client.post(f"/experiment/alpha/jobs/{job_id}/datasets/{dataset_name}/save_to_registry") assert resp.status_code == 200 - assert resp.json()["status"] == "success" - - # Now in the registry - assert registry_path.exists() - assert (registry_path / "data.jsonl").read_text() == '{"row":1}' + assert resp.json()["status"] == "started" def test_save_dataset_to_registry_duplicate_gets_timestamped_name(client, tmp_workspace): - """Duplicate dataset name in registry gets a unique timestamped suffix.""" + """Duplicate dataset name: endpoint still returns started (copy runs in background).""" job_id = "42" dataset_name = "dup-dataset" @@ -130,15 +122,7 @@ def test_save_dataset_to_registry_duplicate_gets_timestamped_name(client, tmp_wo assert resp.status_code == 200 body = resp.json() - assert body["status"] == "success" - # Extract name from "Dataset saved to registry as ''" - saved_name = body["message"].split("'")[1] - assert saved_name.startswith(dataset_name) - assert saved_name != dataset_name - - # Both versions should exist - assert existing.exists() - assert (tmp_workspace["datasets_dir"] / saved_name).exists() + assert body["status"] == "started" def test_save_nonexistent_dataset_returns_404(client, tmp_workspace): @@ -189,24 +173,18 @@ def test_list_job_models_invalid_job_id(client, tmp_workspace): def test_save_model_to_registry_copies_files(client, tmp_workspace): - """Saving a model copies it from job dir to global models registry.""" + """Saving a model triggers the background copy to the global models registry.""" job_id = "42" model_name = "my-model" _seed_job_model(tmp_workspace, job_id, model_name, content="weights-v1") - registry_path = tmp_workspace["models_dir"] / model_name - assert not registry_path.exists() - resp = client.post(f"/experiment/alpha/jobs/{job_id}/models/{model_name}/save_to_registry") assert resp.status_code == 200 - assert resp.json()["status"] == "success" - - assert registry_path.exists() - assert (registry_path / "model.safetensors").read_text() == "weights-v1" + assert resp.json()["status"] == "started" def 
test_save_model_to_registry_duplicate_gets_timestamped_name(client, tmp_workspace): - """Duplicate model name in registry gets a unique timestamped suffix.""" + """Duplicate model name: endpoint still returns started (copy runs in background).""" job_id = "42" model_name = "dup-model" @@ -220,13 +198,7 @@ def test_save_model_to_registry_duplicate_gets_timestamped_name(client, tmp_work assert resp.status_code == 200 body = resp.json() - assert body["status"] == "success" - saved_name = body["message"].split("'")[1] - assert saved_name.startswith(model_name) - assert saved_name != model_name - - assert existing.exists() - assert (tmp_workspace["models_dir"] / saved_name).exists() + assert body["status"] == "started" def test_save_nonexistent_model_returns_404(client, tmp_workspace): @@ -244,7 +216,7 @@ def test_save_nonexistent_model_returns_404(client, tmp_workspace): def test_save_dataset_and_model_from_same_job(client, tmp_workspace): - """A job with both a dataset and model can save each to the registry.""" + """A job with both a dataset and model can trigger saves to the registry.""" job_id = "100" dataset_name = "generated-ds" model_name = "finetuned-model" @@ -253,10 +225,6 @@ def test_save_dataset_and_model_from_same_job(client, tmp_workspace): _seed_job_dataset(tmp_workspace, job_id, dataset_name, content='{"prompt":"hi"}') _seed_job_model(tmp_workspace, job_id, model_name, content="trained-weights") - # Neither should be in the registry yet - assert not (tmp_workspace["datasets_dir"] / dataset_name).exists() - assert not (tmp_workspace["models_dir"] / model_name).exists() - # List and verify they show up in job artifacts ds_resp = client.get(f"/experiment/alpha/jobs/{job_id}/datasets") assert ds_resp.status_code == 200 @@ -266,23 +234,14 @@ def test_save_dataset_and_model_from_same_job(client, tmp_workspace): assert model_resp.status_code == 200 assert model_name in [m["name"] for m in model_resp.json()["models"]] - # Save both to registry + # Save both to registry — both should start successfully ds_save = client.post(f"/experiment/alpha/jobs/{job_id}/datasets/{dataset_name}/save_to_registry") assert ds_save.status_code == 200 - assert ds_save.json()["status"] == "success" + assert ds_save.json()["status"] == "started" model_save = client.post(f"/experiment/alpha/jobs/{job_id}/models/{model_name}/save_to_registry") assert model_save.status_code == 200 - assert model_save.json()["status"] == "success" - - # Verify both now exist in the registry with correct content - reg_ds = tmp_workspace["datasets_dir"] / dataset_name - assert reg_ds.exists() - assert (reg_ds / "data.jsonl").read_text() == '{"prompt":"hi"}' - - reg_model = tmp_workspace["models_dir"] / model_name - assert reg_model.exists() - assert (reg_model / "model.safetensors").read_text() == "trained-weights" + assert model_save.json()["status"] == "started" # --------------------------------------------------------------------------- @@ -291,29 +250,22 @@ def test_save_dataset_and_model_from_same_job(client, tmp_workspace): def test_save_dataset_to_registry_with_custom_name(client, tmp_workspace): - """Saving a dataset with a custom target_name uses that name in the registry.""" + """Saving a dataset with a custom target_name starts the background copy.""" job_id = "42" dataset_name = "my-dataset" custom_name = "custom-dataset" _seed_job_dataset(tmp_workspace, job_id, dataset_name, content='{"row":1}') - registry_path = tmp_workspace["datasets_dir"] / custom_name - assert not registry_path.exists() - resp = client.post( 
f"/experiment/alpha/jobs/{job_id}/datasets/{dataset_name}/save_to_registry", params={"target_name": custom_name, "mode": "new"}, ) assert resp.status_code == 200 - assert resp.json()["status"] == "success" - assert custom_name in resp.json()["message"] - - assert registry_path.exists() - assert (registry_path / "data.jsonl").read_text() == '{"row":1}' + assert resp.json()["status"] == "started" def test_save_dataset_to_registry_custom_name_duplicate_gets_timestamp(client, tmp_workspace): - """Saving with a custom name that already exists adds a timestamp suffix.""" + """Saving with a custom name that already exists: endpoint still returns started.""" job_id = "42" dataset_name = "my-dataset" custom_name = "existing-ds" @@ -329,15 +281,7 @@ def test_save_dataset_to_registry_custom_name_duplicate_gets_timestamp(client, t params={"target_name": custom_name, "mode": "new"}, ) assert resp.status_code == 200 - body = resp.json() - saved_name = body["message"].split("'")[1] - assert saved_name.startswith(custom_name) - assert saved_name != custom_name - - # Original untouched - assert (existing / "data.jsonl").read_text() == "v1" - # New copy exists - assert (tmp_workspace["datasets_dir"] / saved_name).exists() + assert resp.json()["status"] == "started" # --------------------------------------------------------------------------- @@ -346,7 +290,7 @@ def test_save_dataset_to_registry_custom_name_duplicate_gets_timestamp(client, t def test_save_dataset_to_existing_registry_entry(client, tmp_workspace): - """mode='existing' merges files into an existing registry dataset.""" + """mode='existing' triggers background merge into an existing registry dataset.""" job_id = "42" dataset_name = "my-dataset" existing_name = "registry-dataset" @@ -362,8 +306,7 @@ def test_save_dataset_to_existing_registry_entry(client, tmp_workspace): params={"target_name": existing_name, "mode": "existing"}, ) assert resp.status_code == 200 - assert resp.json()["status"] == "success" - assert "merged" in resp.json()["message"].lower() or existing_name in resp.json()["message"] + assert resp.json()["status"] == "started" def test_save_dataset_to_existing_requires_target_name(client, tmp_workspace): @@ -398,29 +341,22 @@ def test_save_dataset_to_nonexistent_existing_returns_404(client, tmp_workspace) def test_save_model_to_registry_with_custom_name(client, tmp_workspace): - """Saving a model with a custom target_name uses that name in the registry.""" + """Saving a model with a custom target_name starts the background copy.""" job_id = "42" model_name = "my-model" custom_name = "custom-model" _seed_job_model(tmp_workspace, job_id, model_name, content="weights-v1") - registry_path = tmp_workspace["models_dir"] / custom_name - assert not registry_path.exists() - resp = client.post( f"/experiment/alpha/jobs/{job_id}/models/{model_name}/save_to_registry", params={"target_name": custom_name, "mode": "new"}, ) assert resp.status_code == 200 - assert resp.json()["status"] == "success" - assert custom_name in resp.json()["message"] - - assert registry_path.exists() - assert (registry_path / "model.safetensors").read_text() == "weights-v1" + assert resp.json()["status"] == "started" def test_save_model_to_registry_custom_name_duplicate_gets_timestamp(client, tmp_workspace): - """Saving with a custom name that already exists adds a timestamp suffix.""" + """Saving with a custom name that already exists: endpoint still returns started.""" job_id = "42" model_name = "my-model" custom_name = "existing-model" @@ -435,13 +371,7 @@ def 
test_save_model_to_registry_custom_name_duplicate_gets_timestamp(client, tmp params={"target_name": custom_name, "mode": "new"}, ) assert resp.status_code == 200 - body = resp.json() - saved_name = body["message"].split("'")[1] - assert saved_name.startswith(custom_name) - assert saved_name != custom_name - - assert (existing / "model.safetensors").read_text() == "weights-v1" - assert (tmp_workspace["models_dir"] / saved_name).exists() + assert resp.json()["status"] == "started" # --------------------------------------------------------------------------- @@ -450,7 +380,7 @@ def test_save_model_to_registry_custom_name_duplicate_gets_timestamp(client, tmp def test_save_model_to_existing_registry_entry(client, tmp_workspace): - """mode='existing' merges files into an existing registry model.""" + """mode='existing' triggers background merge into an existing registry model.""" job_id = "42" model_name = "my-model" existing_name = "registry-model" @@ -465,8 +395,7 @@ def test_save_model_to_existing_registry_entry(client, tmp_workspace): params={"target_name": existing_name, "mode": "existing"}, ) assert resp.status_code == 200 - assert resp.json()["status"] == "success" - assert "merged" in resp.json()["message"].lower() or existing_name in resp.json()["message"] + assert resp.json()["status"] == "started" def test_save_model_to_existing_requires_target_name(client, tmp_workspace): diff --git a/api/transformerlab/routers/asset_versions.py b/api/transformerlab/routers/asset_versions.py new file mode 100644 index 000000000..86a7cbe12 --- /dev/null +++ b/api/transformerlab/routers/asset_versions.py @@ -0,0 +1,237 @@ +""" +asset_versions.py + +API router for managing versioned groups of models and datasets. +Groups are stored as JSON files under the ``asset_groups/`` directory. 
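+
+Illustrative flow (names are invented; these are the service calls the
+endpoints below wrap):
+
+    # Create version 'v1' in group 'my-llm'; it is tagged 'latest' by default.
+    await asset_version_service.create_version(
+        asset_type="model", group_name="my-llm", asset_id="my-llm_20240301"
+    )
+    # Resolve the group back to that asset_id (defaults to the 'latest' tag).
+    await asset_version_service.resolve("model", "my-llm")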
+""" + +from typing import Optional + +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel + +from transformerlab.services import asset_version_service + + +router = APIRouter(prefix="/asset_versions", tags=["asset_versions"]) + + +# ─── Request / Response schemas ─────────────────────────────────────────────── + + +class CreateVersionRequest(BaseModel): + asset_type: str # 'model' or 'dataset' + group_name: str + asset_id: str + version_label: str = "v1" + job_id: Optional[str] = None + description: Optional[str] = None + title: Optional[str] = None + long_description: Optional[str] = None + cover_image: Optional[str] = None + evals: Optional[dict] = None + extra_metadata: Optional[dict] = None + tag: Optional[str] = "latest" + + +class SetTagRequest(BaseModel): + tag: str + + +class UpdateVersionRequest(BaseModel): + description: Optional[str] = None + title: Optional[str] = None + long_description: Optional[str] = None + cover_image: Optional[str] = None + evals: Optional[dict] = None + extra_metadata: Optional[dict] = None + tag: Optional[str] = None + + +# ─── Group endpoints ───────────────────────────────────────────────────────── + + +@router.get("/groups", summary="List all version groups for a given asset type.") +async def list_groups(asset_type: str = Query(..., description="'model' or 'dataset'")): + try: + return await asset_version_service.list_groups(asset_type) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete( + "/groups/{asset_type}/{group_name}", + summary="Delete all versions in a group.", +) +async def delete_group(asset_type: str, group_name: str): + try: + count = await asset_version_service.delete_group(asset_type, group_name) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return {"status": "success", "deleted_count": count} + + +# ─── Version CRUD ───────────────────────────────────────────────────────────── + + +@router.post("/versions", summary="Create a new version in a group.") +async def create_version(body: CreateVersionRequest): + try: + result = await asset_version_service.create_version( + asset_type=body.asset_type, + group_name=body.group_name, + asset_id=body.asset_id, + version_label=body.version_label, + job_id=body.job_id, + description=body.description, + title=body.title, + long_description=body.long_description, + cover_image=body.cover_image, + evals=body.evals, + extra_metadata=body.extra_metadata, + tag=body.tag, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return result + + +@router.get( + "/versions/{asset_type}/{group_name}", + summary="List all versions in a group.", +) +async def list_versions(asset_type: str, group_name: str): + try: + return await asset_version_service.list_versions(asset_type, group_name) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.get( + "/versions/{asset_type}/{group_name}/{version_label}", + summary="Get a specific version by its label.", +) +async def get_version(asset_type: str, group_name: str, version_label: str): + try: + result = await asset_version_service.get_version(asset_type, group_name, version_label) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if result is None: + raise HTTPException(status_code=404, detail="Version not found") + return result + + +@router.delete( + "/versions/{asset_type}/{group_name}/{version_label}", + summary="Delete a specific version.", +) 
+async def delete_version(asset_type: str, group_name: str, version_label: str): + try: + deleted = await asset_version_service.delete_version(asset_type, group_name, version_label) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if not deleted: + raise HTTPException(status_code=404, detail="Version not found") + return {"status": "success"} + + +# ─── Version update ────────────────────────────────────────────────────────── + + +@router.patch( + "/versions/{asset_type}/{group_name}/{version_label}", + summary="Update metadata or tag on a specific version.", +) +async def update_version(asset_type: str, group_name: str, version_label: str, body: UpdateVersionRequest): + # Build kwargs only for fields the caller actually sent (present in the JSON body). + # This lets the service layer distinguish "not provided" from "set to null". + raw = body.model_dump(exclude_unset=True) + + # Map body fields to service kwargs using the sentinel pattern + kwargs = {} + for field in ( + "description", + "title", + "long_description", + "cover_image", + "evals", + "extra_metadata", + "tag", + ): + if field in raw: + kwargs[field] = raw[field] + + try: + result = await asset_version_service.update_version(asset_type, group_name, version_label, **kwargs) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if result is None: + raise HTTPException(status_code=404, detail="Version not found") + return result + + +# ─── Tag management ────────────────────────────────────────────────────────── + + +@router.put( + "/versions/{asset_type}/{group_name}/{version_label}/tag", + summary="Set a tag on a specific version. Moves the tag from any other version in the group.", +) +async def set_tag(asset_type: str, group_name: str, version_label: str, body: SetTagRequest): + try: + result = await asset_version_service.set_tag(asset_type, group_name, version_label, body.tag) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if result is None: + raise HTTPException(status_code=404, detail="Version not found") + return result + + +@router.delete( + "/versions/{asset_type}/{group_name}/{version_label}/tag", + summary="Clear the tag from a specific version.", +) +async def clear_tag(asset_type: str, group_name: str, version_label: str): + try: + result = await asset_version_service.clear_tag(asset_type, group_name, version_label) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if result is None: + raise HTTPException(status_code=404, detail="Version not found") + return result + + +# ─── Resolution ────────────────────────────────────────────────────────────── + + +@router.get( + "/resolve/{asset_type}/{group_name}", + summary="Resolve a group to a specific version. Defaults to 'latest' tag.", +) +async def resolve( + asset_type: str, + group_name: str, + tag: Optional[str] = Query(None, description="Tag to resolve (any string, e.g. 
'latest', 'production')"), + version_label: Optional[str] = Query(None, description="Exact version label to resolve"), +): + try: + result = await asset_version_service.resolve(asset_type, group_name, tag=tag, version_label=version_label) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + if result is None: + raise HTTPException(status_code=404, detail="No matching version found") + return result + + +# ─── Bulk lookups (used by list views) ──────────────────────────────────────── + + +@router.get( + "/map/{asset_type}", + summary="Get a map of asset_id -> group memberships for annotating list views.", +) +async def get_asset_group_map(asset_type: str): + try: + return await asset_version_service.get_all_asset_group_map(asset_type) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/api/transformerlab/routers/data.py b/api/transformerlab/routers/data.py index c719d5beb..6bbe83d53 100644 --- a/api/transformerlab/routers/data.py +++ b/api/transformerlab/routers/data.py @@ -841,6 +841,20 @@ async def dataset_list(generated: bool = True): except Exception: merged_list = [] + # Augment each dataset with version group info if any + try: + from transformerlab.services import asset_version_service + + group_map = await asset_version_service.get_all_asset_group_map("dataset") + for entry in merged_list: + dataset_id = entry.get("dataset_id", "") + if dataset_id in group_map: + entry["version_groups"] = group_map[dataset_id] + else: + entry["version_groups"] = [] + except Exception as e: + print(f"Warning: could not fetch dataset version groups: {e}") + if generated: return merged_list diff --git a/api/transformerlab/routers/experiment/jobs.py b/api/transformerlab/routers/experiment/jobs.py index 94af77526..3df031e1b 100644 --- a/api/transformerlab/routers/experiment/jobs.py +++ b/api/transformerlab/routers/experiment/jobs.py @@ -1,6 +1,5 @@ import asyncio import csv -from datetime import datetime from fnmatch import fnmatch import json import os @@ -35,6 +34,8 @@ get_job_models_dir, get_models_dir, ) +from transformerlab.services import asset_version_service + router = APIRouter(prefix="/jobs", tags=["train"]) @@ -1402,159 +1403,238 @@ async def get_job_models(job_id: str, request: Request): async def save_dataset_to_registry( job_id: str, dataset_name: str, - target_name: Optional[str] = Query(None, description="Custom name for the dataset in the registry"), + experimentId: str, + target_name: Optional[str] = Query(None, description="Group name for the dataset in the registry"), + asset_name: Optional[str] = Query(None, description="Unique folder name for the dataset in the datasets directory"), mode: str = Query( "new", description="'new' to create a new entry, 'existing' to merge into an existing registry dataset" ), + tag: str = Query("latest", description="Tag to assign to the new version"), + version_label: str = Query("v1", description="Version label for this entry (e.g. 'v1', 'march-run')"), + description: Optional[str] = Query(None, description="Human-readable description for the version"), user_and_team=Depends(get_user_and_team), session: AsyncSession = Depends(get_async_session), ): - """Copy a dataset from job's datasets directory to the global datasets registry. - - Supports two modes: - - mode='new': Save as a new dataset. Uses target_name if provided, otherwise uses the original dataset_name. - If a dataset with that name already exists, a timestamped suffix is added. 
- - mode='existing': Merge into an existing dataset in the registry. target_name must be provided and must - refer to an existing dataset. Files from the job dataset are copied into the existing dataset directory. - """ + """Start a background copy of a dataset from the job's datasets directory to the global datasets registry and record a version entry.""" try: - # Secure the source dataset name dataset_name_secure = secure_filename(dataset_name) - # Get source path (job's datasets directory) job_datasets_dir = await get_job_datasets_dir(job_id) source_path = storage.join(job_datasets_dir, dataset_name_secure) if not await storage.exists(source_path): raise HTTPException(status_code=404, detail=f"Dataset '{dataset_name}' not found in job directory") - # Get the registry directory datasets_registry_dir = await get_datasets_dir() if mode == "existing": - # Merge into an existing dataset + # For mode='existing', the asset is merged into the target group folder. if not target_name: raise HTTPException(status_code=400, detail="target_name is required when mode is 'existing'") target_name_secure = secure_filename(target_name) dest_path = storage.join(datasets_registry_dir, target_name_secure) if not await storage.exists(dest_path): raise HTTPException(status_code=404, detail=f"Dataset '{target_name}' not found in registry") - - # Copy files from source into the existing dataset directory (merge) - try: - await storage.copy_dir(source_path, dest_path) - except Exception as copy_err: - print(f"Storage.copy_dir failed: {copy_err}") - - return { - "status": "success", - "message": f"Dataset merged into existing registry entry '{target_name_secure}'", - } + # When merging into an existing group, asset_name is the target folder itself. + effective_asset_name = target_name_secure else: - # Save as a new dataset - final_name = secure_filename(target_name) if target_name else dataset_name_secure - dest_path = storage.join(datasets_registry_dir, final_name) - - # Check if dataset already exists in registry and generate a unique name if needed - if await storage.exists(dest_path): - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - final_name = f"{final_name}_{timestamp}" - dest_path = storage.join(datasets_registry_dir, final_name) - - # Copy the dataset to the registry - try: - await storage.copy_dir(source_path, dest_path) - except Exception as copy_err: - print(f"Storage.copy_dir failed: {copy_err}") + # Determine the unique destination folder name (asset_name). + # If the caller explicitly supplied asset_name, enforce uniqueness (409 on conflict). + # Otherwise fall back to target_name / source name and auto-suffix on conflict. + if asset_name: + effective_asset_name = secure_filename(asset_name) + dest_path = storage.join(datasets_registry_dir, effective_asset_name) + if await storage.exists(dest_path): + raise HTTPException( + status_code=409, + detail=f"A dataset named '{effective_asset_name}' already exists in the registry. 
Please choose a different name.", + ) + else: + effective_asset_name = secure_filename(target_name) if target_name else dataset_name_secure + dest_path = storage.join(datasets_registry_dir, effective_asset_name) + if await storage.exists(dest_path): + from datetime import datetime + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + effective_asset_name = f"{effective_asset_name}_{timestamp}" + + asyncio.create_task( + _save_dataset_to_registry( + job_id=job_id, + dataset_name_secure=dataset_name_secure, + source_path=source_path, + datasets_registry_dir=datasets_registry_dir, + target_name=target_name, + asset_name=effective_asset_name, + mode=mode, + tag=tag, + version_label=version_label, + description=description, + ) + ) - return {"status": "success", "message": f"Dataset saved to registry as '{final_name}'"} + return {"status": "started", "message": "Dataset save to registry started"} except HTTPException: raise except Exception as e: - print(f"Error saving dataset to registry for job {job_id}: {str(e)}") + print(f"Error starting dataset save to registry for job {job_id}: {str(e)}") import traceback traceback.print_exc() - raise HTTPException(status_code=500, detail=f"Failed to save dataset to registry: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to start dataset save to registry: {str(e)}") + + +async def _save_dataset_to_registry( + job_id: str, + dataset_name_secure: str, + source_path: str, + datasets_registry_dir: str, + target_name: Optional[str], + asset_name: str, + mode: str, + tag: str, + version_label: str, + description: Optional[str], +): + """Coroutine that performs the copy and creates the version entry.""" + # asset_name is always the unique destination folder name. + # Existence check was already done in the endpoint before dispatching. + dest_path = storage.join(datasets_registry_dir, asset_name) + await storage.copy_dir(source_path, dest_path) + + group_name = secure_filename(target_name) if target_name else asset_name + version_description = description if description else f"Created from job {job_id}" + await asset_version_service.create_version( + asset_type="dataset", + group_name=group_name, + asset_id=asset_name, + version_label=version_label, + job_id=job_id, + description=version_description, + tag=tag, + ) + + return asset_name @router.post("/{job_id}/models/{model_name}/save_to_registry") async def save_model_to_registry( job_id: str, model_name: str, - target_name: Optional[str] = Query(None, description="Custom name for the model in the registry"), + experimentId: str, + target_name: Optional[str] = Query(None, description="Group name for the model in the registry"), + asset_name: Optional[str] = Query(None, description="Unique folder name for the model in the models directory"), mode: str = Query( "new", description="'new' to create a new entry, 'existing' to merge into an existing registry model" ), + tag: str = Query("latest", description="Tag to assign to the new version"), + version_label: str = Query("v1", description="Version label for this entry (e.g. 'v1', 'march-run')"), + description: Optional[str] = Query(None, description="Human-readable description for the version"), ): - """Copy a model from job's models directory to the global models registry. - - Supports two modes: - - mode='new': Save as a new model. Uses target_name if provided, otherwise uses the original model_name. - If a model with that name already exists, a timestamped suffix is added. - - mode='existing': Merge into an existing model in the registry. 
target_name must be provided and must - refer to an existing model. Files from the job model are copied into the existing model directory. - """ + """Start a background copy of a model from the job's models directory to the global models registry and record a version entry.""" try: - # Secure the source model name model_name_secure = secure_filename(model_name) - # Get source path (job's models directory) job_models_dir = await get_job_models_dir(job_id) source_path = storage.join(job_models_dir, model_name_secure) if not await storage.exists(source_path): raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found in job directory") - # Get the registry directory models_registry_dir = await get_models_dir() if mode == "existing": - # Merge into an existing model + # For mode='existing', the asset is merged into the target group folder. if not target_name: raise HTTPException(status_code=400, detail="target_name is required when mode is 'existing'") target_name_secure = secure_filename(target_name) dest_path = storage.join(models_registry_dir, target_name_secure) if not await storage.exists(dest_path): raise HTTPException(status_code=404, detail=f"Model '{target_name}' not found in registry") - - # Copy files from source into the existing model directory (merge) - try: - await storage.copy_dir(source_path, dest_path) - except Exception as copy_err: - print(f"storage.copy_dir failed: {copy_err}") - - return {"status": "success", "message": f"Model merged into existing registry entry '{target_name_secure}'"} + # When merging into an existing group, asset_name is the target folder itself. + effective_asset_name = target_name_secure else: - # Save as a new model - final_name = secure_filename(target_name) if target_name else model_name_secure - dest_path = storage.join(models_registry_dir, final_name) - - # Check if model already exists in registry and generate a unique name if needed - if await storage.exists(dest_path): - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - final_name = f"{final_name}_{timestamp}" - dest_path = storage.join(models_registry_dir, final_name) - - # Copy the model directory to the registry - try: - await storage.copy_dir(source_path, dest_path) - except Exception as copy_err: - print(f"storage.copy_dir failed: {copy_err}") + # Determine the unique destination folder name (asset_name). + # If the caller explicitly supplied asset_name, enforce uniqueness (409 on conflict). + # Otherwise fall back to target_name / source name and auto-suffix on conflict. + if asset_name: + effective_asset_name = secure_filename(asset_name) + dest_path = storage.join(models_registry_dir, effective_asset_name) + if await storage.exists(dest_path): + raise HTTPException( + status_code=409, + detail=f"A model named '{effective_asset_name}' already exists in the registry. 
Please choose a different name.", + ) + else: + effective_asset_name = secure_filename(target_name) if target_name else model_name_secure + dest_path = storage.join(models_registry_dir, effective_asset_name) + if await storage.exists(dest_path): + from datetime import datetime + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + effective_asset_name = f"{effective_asset_name}_{timestamp}" + + asyncio.create_task( + _save_model_to_registry( + job_id=job_id, + model_name_secure=model_name_secure, + source_path=source_path, + models_registry_dir=models_registry_dir, + target_name=target_name, + asset_name=effective_asset_name, + mode=mode, + tag=tag, + version_label=version_label, + description=description, + ) + ) - return {"status": "success", "message": f"Model saved to registry as '{final_name}'"} + return {"status": "started", "message": "Model save to registry started"} except HTTPException: raise except Exception as e: - print(f"Error saving model to registry for job {job_id}: {str(e)}") + print(f"Error starting model save to registry for job {job_id}: {str(e)}") import traceback traceback.print_exc() - raise HTTPException(status_code=500, detail=f"Failed to save model to registry: {str(e)}") + raise HTTPException(status_code=500, detail=f"Failed to start model save to registry: {str(e)}") + + +async def _save_model_to_registry( + job_id: str, + model_name_secure: str, + source_path: str, + models_registry_dir: str, + target_name: Optional[str], + asset_name: str, + mode: str, + tag: str, + version_label: str, + description: Optional[str], +): + """Coroutine that performs the copy and creates the version entry.""" + # asset_name is always the unique destination folder name. + # Existence check was already done in the endpoint before dispatching. 
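+    # Note: the check-then-copy is not atomic; a concurrent save could still
+    # create the same destination folder between that check and this copy.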
+ dest_path = storage.join(models_registry_dir, asset_name) + await storage.copy_dir(source_path, dest_path) + + group_name = secure_filename(target_name) if target_name else asset_name + version_description = description if description else f"Created from job {job_id}" + await asset_version_service.create_version( + asset_type="model", + group_name=group_name, + asset_id=asset_name, + version_label=version_label, + job_id=job_id, + description=version_description, + tag=tag, + ) + + return asset_name @@ router.get("/{job_id}/files") diff --git a/api/transformerlab/routers/model.py b/api/transformerlab/routers/model.py index c634bb4d3..7d8ee0c0d 100644 --- a/api/transformerlab/routers/model.py +++ b/api/transformerlab/routers/model.py @@ -642,7 +642,23 @@ async def get_model_prompt_template(model: str): @router.get("/model/list") async def model_local_list(embedding=False): # the model list is a combination of downloaded hugging face models and locally generated models - return await model_helper.list_installed_models(embedding) + models = await model_helper.list_installed_models(embedding) + + # Augment each model with version group info if any + try: + from transformerlab.services import asset_version_service + + group_map = await asset_version_service.get_all_asset_group_map("model") + for model in models: + model_id = model.get("model_id", "") + if model_id in group_map: + model["version_groups"] = group_map[model_id] + else: + model["version_groups"] = [] + except Exception as e: + print(f"Warning: could not fetch model version groups: {e}") + + return models @@ router.get("/model/provenance/{model_id}") diff --git a/api/transformerlab/services/asset_version_service.py b/api/transformerlab/services/asset_version_service.py new file mode 100644 index 000000000..4cf191293 --- /dev/null +++ b/api/transformerlab/services/asset_version_service.py @@ -0,0 +1,530 @@ +""" +asset_version_service.py + +Service layer for managing versioned groups of models and datasets. +Groups are stored as JSON files under the ``asset_groups/`` directory, completely +separate from the existing ``models/`` and ``datasets/`` directories. + +Directory layout +---------------- +asset_groups/ + models/ + <group_name>/ + index.json - group-level metadata + model_list.json - ordered list of version entries + datasets/ + <group_name>/ + index.json + dataset_list.json + +The actual model / dataset files stay in their original locations. Version +entries only store *references* (``asset_id``) pointing to those assets. +""" + +import json +import uuid +from datetime import datetime, timezone +from typing import Optional + +from lab import storage + +from transformerlab.shared.dirs import get_asset_groups_dir + + +VALID_ASSET_TYPES = {"model", "dataset"} + +# Mapping from asset_type to the filename that holds the version list. 
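+# e.g. asset_groups/models/<group_name>/model_list.json for a model group.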
+_LIST_FILENAME = { + "model": "model_list.json", + "dataset": "dataset_list.json", +} + + +# --- Internal helpers --------------------------------------------------------- + + +def _validate_asset_type(asset_type: str) -> None: + if asset_type not in VALID_ASSET_TYPES: + raise ValueError(f"asset_type must be one of {VALID_ASSET_TYPES}, got '{asset_type}'") + + +async def _group_dir(asset_type: str, group_name: str) -> str: + """Return the directory path for a specific group, creating parents as needed.""" + root = await get_asset_groups_dir() + # asset_type plural form for directory name + type_dir = storage.join(root, f"{asset_type}s") + await storage.makedirs(type_dir, exist_ok=True) + path = storage.join(type_dir, group_name) + await storage.makedirs(path, exist_ok=True) + return path + + +async def _read_json(path: str, default=None): + """Read and parse a JSON file from storage. Returns *default* if missing.""" + try: + if not await storage.exists(path): + return default + async with await storage.open(path, "r", encoding="utf-8") as f: + text = await f.read() + return json.loads(text) + except Exception: + return default + + +async def _write_json(path: str, data) -> None: + """Serialise *data* to JSON and write it to *path*.""" + text = json.dumps(data, indent=2, default=str) + async with await storage.open(path, "w", encoding="utf-8") as f: + await f.write(text) + + +async def _read_index(asset_type: str, group_name: str) -> dict: + gdir = await _group_dir(asset_type, group_name) + path = storage.join(gdir, "index.json") + data = await _read_json(path, default=None) + if data is None: + data = { + "name": group_name, + "created_at": datetime.now(timezone.utc).isoformat(), + "description": "", + "cover_image": None, + } + await _write_json(path, data) + return data + + +async def _write_index(asset_type: str, group_name: str, data: dict) -> None: + gdir = await _group_dir(asset_type, group_name) + path = storage.join(gdir, "index.json") + await _write_json(path, data) + + +async def _read_versions(asset_type: str, group_name: str) -> list[dict]: + """Return the list of version dicts for a group (in append order).""" + gdir = await _group_dir(asset_type, group_name) + filename = _LIST_FILENAME[asset_type] + path = storage.join(gdir, filename) + data = await _read_json(path, default=None) + if data is None: + return [] + return data.get("versions", []) + + +async def _write_versions(asset_type: str, group_name: str, versions: list[dict]) -> None: + gdir = await _group_dir(asset_type, group_name) + filename = _LIST_FILENAME[asset_type] + path = storage.join(gdir, filename) + await _write_json(path, {"versions": versions}) + + +def _version_to_dict(v: dict, asset_type: str, group_name: str) -> dict: + """Normalise a raw version dict into the shape expected by callers.""" + return { + "id": v.get("id", ""), + "asset_type": asset_type, + "group_name": group_name, + "version_label": v.get("version_label", ""), + "asset_id": v.get("asset_id", ""), + "tag": v.get("tag"), + "job_id": v.get("job_id"), + "description": v.get("description"), + "title": v.get("title"), + "long_description": v.get("long_description"), + "cover_image": v.get("cover_image"), + "evals": v.get("evals"), + "metadata": v.get("metadata"), + "created_at": v.get("created_at"), + } + + +# --- Public API --------------------------------------------------------------- + + +async def create_version( + *, + asset_type: str, + group_name: str, + asset_id: str, + version_label: str = "v1", + job_id: Optional[str] = None, + 
description: Optional[str] = None, + title: Optional[str] = None, + long_description: Optional[str] = None, + cover_image: Optional[str] = None, + evals: Optional[dict] = None, + extra_metadata: Optional[dict] = None, + tag: Optional[str] = "latest", +) -> dict: + """Create a new version in a group. + + By default the new version is tagged ``'latest'`` and any previous holder + of that tag in the same group has the tag cleared. + + Returns: + A dict representation of the newly created version entry. + """ + _validate_asset_type(asset_type) + + # Ensure group index exists + await _read_index(asset_type, group_name) + + versions = await _read_versions(asset_type, group_name) + + # Clear the tag from any other version in this group (one holder per tag) + if tag is not None: + for v in versions: + if v.get("tag") == tag: + v["tag"] = None + + new_entry: dict = { + "id": str(uuid.uuid4()), + "version_label": version_label, + "tag": tag, + "asset_id": asset_id, + "job_id": job_id, + "created_at": datetime.now(timezone.utc).isoformat(), + "title": title, + "description": description, + "long_description": long_description, + "cover_image": cover_image, + "evals": evals, + "metadata": extra_metadata, + } + + versions.append(new_entry) + await _write_versions(asset_type, group_name, versions) + + return _version_to_dict(new_entry, asset_type, group_name) + + +async def list_groups(asset_type: str) -> list[dict]: + """List all groups for a given asset type with summary info.""" + _validate_asset_type(asset_type) + + root = await get_asset_groups_dir() + type_dir = storage.join(root, f"{asset_type}s") + + # If the type directory doesn't exist yet, return empty list + if not await storage.exists(type_dir): + return [] + + groups: list[dict] = [] + try: + entries = await storage.ls(type_dir, detail=False) + except Exception: + return [] + + for entry in entries: + # entry may be a full path - take the basename + entry_name = entry.rsplit("/", 1)[-1] if "/" in str(entry) else str(entry) + if not entry_name or entry_name.startswith("."): + continue + + group_name = entry_name + versions = await _read_versions(asset_type, group_name) + tags = [v["tag"] for v in versions if v.get("tag")] + + groups.append( + { + "group_name": group_name, + "asset_type": asset_type, + "version_count": len(versions), + "latest_version_label": versions[-1].get("version_label") if versions else None, + "tags": tags, + } + ) + + groups.sort(key=lambda g: g["group_name"]) + return groups + + +async def list_versions(asset_type: str, group_name: str) -> list[dict]: + """List all versions in a group, newest first.""" + _validate_asset_type(asset_type) + versions = await _read_versions(asset_type, group_name) + # Return newest first (reverse of append order) + return [_version_to_dict(v, asset_type, group_name) for v in reversed(versions)] + + +async def get_version(asset_type: str, group_name: str, version_label: str) -> Optional[dict]: + """Get a specific version by its label.""" + _validate_asset_type(asset_type) + versions = await _read_versions(asset_type, group_name) + for v in versions: + if v.get("version_label") == version_label: + return _version_to_dict(v, asset_type, group_name) + return None + + +async def get_version_by_id(asset_type: str, group_name: str, version_id: str) -> Optional[dict]: + """Get a specific version by its UUID.""" + _validate_asset_type(asset_type) + versions = await _read_versions(asset_type, group_name) + for v in versions: + if v.get("id") == version_id: + return _version_to_dict(v, asset_type, 
group_name) + return None + + +async def update_version( + asset_type: str, + group_name: str, + version_label: str, + *, + description: Optional[str] = ..., + title: Optional[str] = ..., + long_description: Optional[str] = ..., + cover_image: Optional[str] = ..., + evals: Optional[dict] = ..., + extra_metadata: Optional[dict] = ..., + tag: Optional[str] = ..., +) -> Optional[dict]: + """Update mutable fields on a specific version. + + Uses sentinel default (``...``) so callers can distinguish between + "not provided" and "explicitly set to None". + """ + _validate_asset_type(asset_type) + + versions = await _read_versions(asset_type, group_name) + target = None + for v in versions: + if v.get("version_label") == version_label: + target = v + break + + if target is None: + return None + + updatable = { + "description": description, + "title": title, + "long_description": long_description, + "cover_image": cover_image, + "evals": evals, + "metadata": extra_metadata, + "tag": tag, + } + + for field, value in updatable.items(): + if value is ...: + continue # not provided by caller + + if field == "tag" and value is not None: + # Clear this tag from any other version in the group + for v in versions: + if v is not target and v.get("tag") == value: + v["tag"] = None + + target[field] = value + + await _write_versions(asset_type, group_name, versions) + return _version_to_dict(target, asset_type, group_name) + + +async def resolve_by_tag(asset_type: str, group_name: str, tag: str = "latest") -> Optional[dict]: + """Resolve a version by its tag.""" + _validate_asset_type(asset_type) + + versions = await _read_versions(asset_type, group_name) + for v in versions: + if v.get("tag") == tag: + return _version_to_dict(v, asset_type, group_name) + return None + + +async def resolve( + asset_type: str, + group_name: str, + tag: Optional[str] = None, + version_label: Optional[str] = None, +) -> Optional[dict]: + """Resolve a specific version of a group. + + Resolution priority: + 1. If version_label is provided, return that exact version. + 2. If tag is provided, return the version with that tag. + 3. Otherwise, return the version tagged 'latest'. + 4. If no 'latest' tag exists, return the most recently added version. + """ + _validate_asset_type(asset_type) + + if version_label is not None: + return await get_version(asset_type, group_name, version_label) + + if tag is not None: + return await resolve_by_tag(asset_type, group_name, tag) + + # Default: try 'latest' tag first + result = await resolve_by_tag(asset_type, group_name, "latest") + if result: + return result + + # Fallback: most recently added version (last in list) + versions = await _read_versions(asset_type, group_name) + if versions: + return _version_to_dict(versions[-1], asset_type, group_name) + return None + + +async def set_tag(asset_type: str, group_name: str, version_label: str, tag: str) -> Optional[dict]: + """Set a tag on a specific version. + + Clears the tag from any other version in the same group first. 
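+
+    Illustrative example (group and label names are invented):
+
+        await set_tag("model", "my-llm", "v2", "production")
+        # 'production' now points at v2; any previous holder is untagged.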
+ """ + _validate_asset_type(asset_type) + + versions = await _read_versions(asset_type, group_name) + + # Clear tag from all versions + for v in versions: + if v.get("tag") == tag: + v["tag"] = None + + # Assign to target + target = None + for v in versions: + if v.get("version_label") == version_label: + v["tag"] = tag + target = v + break + + if target is None: + return None + + await _write_versions(asset_type, group_name, versions) + return _version_to_dict(target, asset_type, group_name) + + +async def clear_tag(asset_type: str, group_name: str, version_label: str) -> Optional[dict]: + """Remove the tag from a specific version.""" + _validate_asset_type(asset_type) + + versions = await _read_versions(asset_type, group_name) + target = None + for v in versions: + if v.get("version_label") == version_label: + v["tag"] = None + target = v + break + + if target is None: + return None + + await _write_versions(asset_type, group_name, versions) + return _version_to_dict(target, asset_type, group_name) + + +async def delete_version(asset_type: str, group_name: str, version_label: str) -> bool: + """Delete a specific version from the registry. + + Returns True if the version existed and was deleted. + Does NOT delete the underlying filesystem asset. + """ + _validate_asset_type(asset_type) + + versions = await _read_versions(asset_type, group_name) + new_versions = [v for v in versions if v.get("version_label") != version_label] + + if len(new_versions) == len(versions): + return False # not found + + await _write_versions(asset_type, group_name, new_versions) + + # If group is now empty, clean up its directory + if not new_versions: + await _remove_group_dir(asset_type, group_name) + + return True + + +async def delete_group(asset_type: str, group_name: str) -> int: + """Delete all versions in a group. + + Returns the number of versions deleted. + Does NOT delete the underlying filesystem assets. + """ + _validate_asset_type(asset_type) + + versions = await _read_versions(asset_type, group_name) + count = len(versions) + + if count > 0: + await _remove_group_dir(asset_type, group_name) + + return count + + +async def _remove_group_dir(asset_type: str, group_name: str) -> None: + """Remove the group directory and all its JSON files.""" + root = await get_asset_groups_dir() + type_dir = storage.join(root, f"{asset_type}s") + gdir = storage.join(type_dir, group_name) + try: + if await storage.exists(gdir): + await storage.rm(gdir, recursive=True) + except Exception: + pass + + +async def get_groups_for_asset(asset_type: str, asset_id: str) -> list[dict]: + """Find all groups that contain a specific asset_id. + + Useful for showing version badges on the model/dataset list views. + """ + _validate_asset_type(asset_type) + + root = await get_asset_groups_dir() + type_dir = storage.join(root, f"{asset_type}s") + + if not await storage.exists(type_dir): + return [] + + results: list[dict] = [] + try: + entries = await storage.ls(type_dir, detail=False) + except Exception: + return [] + + for entry in entries: + group_name = entry.rsplit("/", 1)[-1] if "/" in str(entry) else str(entry) + if not group_name or group_name.startswith("."): + continue + versions = await _read_versions(asset_type, group_name) + for v in versions: + if v.get("asset_id") == asset_id: + results.append(_version_to_dict(v, asset_type, group_name)) + + return results + + +async def get_all_asset_group_map(asset_type: str) -> dict[str, list[dict]]: + """Build a map of asset_id -> list of group memberships. 
+ + This is used by the frontend to efficiently annotate list views + with version/group information without N+1 queries. + """ + _validate_asset_type(asset_type) + + root = await get_asset_groups_dir() + type_dir = storage.join(root, f"{asset_type}s") + + if not await storage.exists(type_dir): + return {} + + mapping: dict[str, list[dict]] = {} + try: + entries = await storage.ls(type_dir, detail=False) + except Exception: + return {} + + for entry in entries: + group_name = entry.rsplit("/", 1)[-1] if "/" in str(entry) else str(entry) + if not group_name or group_name.startswith("."): + continue + versions = await _read_versions(asset_type, group_name) + for v in versions: + d = _version_to_dict(v, asset_type, group_name) + mapping.setdefault(d["asset_id"], []).append(d) + + return mapping diff --git a/api/transformerlab/shared/dirs.py b/api/transformerlab/shared/dirs.py index 4031a8d28..fe610c5e2 100644 --- a/api/transformerlab/shared/dirs.py +++ b/api/transformerlab/shared/dirs.py @@ -76,5 +76,24 @@ async def initialize_dirs(): GALLERIES_LOCAL_FALLBACK_DIR = os.path.join(TFL_SOURCE_CODE_DIR, "transformerlab/galleries/") +async def get_asset_groups_dir() -> str: + """Return the root directory for filesystem-based asset groups. + + Layout: + <workspace_dir>/asset_groups/ + models/<group_name>/index.json, model_list.json + datasets/<group_name>/index.json, dataset_list.json + + The directory is created on first access. + """ + from lab import storage + from lab.dirs import get_workspace_dir + + workspace = await get_workspace_dir() + path = storage.join(workspace, "asset_groups") + await storage.makedirs(path, exist_ok=True) + return path + + # TEMPORARY: We want to move jobs back into the root directory instead of under experiment # But for now we need to leave this here. diff --git a/src/renderer/components/Data/Data.tsx b/src/renderer/components/Data/Data.tsx index bc14c6457..205f2f249 100644 --- a/src/renderer/components/Data/Data.tsx +++ b/src/renderer/components/Data/Data.tsx @@ -1,49 +1,79 @@ /* eslint-disable jsx-a11y/anchor-is-valid */ import Sheet from '@mui/joy/Sheet'; import { Tab, TabList, TabPanel, Tabs } from '@mui/joy'; -import { StoreIcon } from 'lucide-react'; +import { StoreIcon, LayersIcon } from 'lucide-react'; import DataStore from './DataStore'; import LocalDatasets from './LocalDatasets'; import GeneratedDatasets from './GeneratedDatasets'; +import DatasetRegistry from './DatasetRegistry'; -export default function Data() { +export default function Data({ tab = 'local' }) { const isLocalMode = window?.platform?.multiuser !== true; + return ( - - - Local Datasets - Generated Datasets - {isLocalMode && ( + {isLocalMode ? 
( + + + Local Datasets + Generated Datasets   Dataset Store - )} - - - - - - - - {isLocalMode && ( + + + + + + + - )} - + + ) : ( + + + + +   Dataset Registry + + + + + + + )} ); } diff --git a/src/renderer/components/Data/DatasetCard.tsx b/src/renderer/components/Data/DatasetCard.tsx index 538e60175..9b2b002f0 100644 --- a/src/renderer/components/Data/DatasetCard.tsx +++ b/src/renderer/components/Data/DatasetCard.tsx @@ -22,6 +22,8 @@ import { useAPI } from 'renderer/lib/transformerlab-api-sdk'; import PreviewDatasetModal from './PreviewDatasetModal'; import DatasetInfoModal from './DatasetInfoModal'; import EditDatasetModal from './EditDatasetModal'; +import VersionGroupChip from '../Shared/VersionGroupChip'; +import AssetVersionsDrawer from '../Shared/AssetVersionsDrawer'; export default function DatasetCard({ name, @@ -33,6 +35,7 @@ export default function DatasetCard({ parentMutate, local, friendlyName = null, + versionGroups = [], }) { const [installing, setInstalling] = useState(null); const [previewModalState, setPreviewModalState] = useState({ @@ -42,6 +45,10 @@ export default function DatasetCard({ }); const [datasetInfoModalOpen, setDatasetInfoModalOpen] = useState(false); const [editDatasetModalOpen, setEditDatasetModalOpen] = useState(false); + const [versionDrawer, setVersionDrawer] = useState<{ + open: boolean; + groupName: string; + }>({ open: false, groupName: '' }); const { data: datasetInfo } = useAPI('datasets', ['info'], { datasetId: name, @@ -89,6 +96,16 @@ export default function DatasetCard({ {location === 'huggingfacehub' && ' 🤗'} {location === 'local' && ' '} + {versionGroups && versionGroups.length > 0 && ( +
+ + setVersionDrawer({ open: true, groupName }) + } + /> +
+ )}
 {description} @@ -235,6 +252,12 @@ export default function DatasetCard({ )} + setVersionDrawer({ open: false, groupName: '' })} + assetType="dataset" + groupName={versionDrawer.groupName} + /> ); } diff --git a/src/renderer/components/Data/DatasetRegistry.tsx b/src/renderer/components/Data/DatasetRegistry.tsx new file mode 100644 index 000000000..a0d3730ca --- /dev/null +++ b/src/renderer/components/Data/DatasetRegistry.tsx @@ -0,0 +1,923 @@ +/** + * DatasetRegistry.tsx + * + * Displays asset-version groups for datasets (from the asset_versions API). + * Each group renders as an expandable accordion row showing its versions. + * This is the "Dataset Registry" tab inside Data. + * + * The UI mirrors the Local Models table: search bar, filters (disabled), + * refresh button, and a clean table of versions per group. + */ + +import { useState } from 'react'; +import { + Accordion, + AccordionDetails, + AccordionGroup, + AccordionSummary, + Box, + Chip, + CircularProgress, + Divider, + Drawer, + DialogTitle, + FormControl, + FormLabel, + IconButton, + Input, + ModalClose, + Option, + Select, + Sheet, + Skeleton, + Stack, + Table, + Tooltip, + Typography, +} from '@mui/joy'; +import { + BriefcaseIcon, + ChevronDownIcon, + DatabaseIcon, + InfoIcon, + RotateCcwIcon, + SearchIcon, + Trash2Icon, + XIcon, +} from 'lucide-react'; +import Markdown from 'react-markdown'; +import { useSWRWithAuth as useSWR } from 'renderer/lib/authContext'; +import { fetchWithAuth } from 'renderer/lib/authContext'; +import * as chatAPI from '../../lib/transformerlab-api-sdk'; +import { fetcher } from '../../lib/transformerlab-api-sdk'; + +// ─── Types ─────────────────────────────────────────────────────────────────── + +interface VersionEntry { + id: string; + asset_type: string; + group_name: string; + version_label: string; + asset_id: string; + tag: string | null; + job_id: string | null; + description: string | null; + title: string | null; + long_description: string | null; + cover_image: string | null; + evals: Record<string, any> | null; + metadata: Record<string, any> | null; + created_at: string | null; +} + +interface GroupSummary { + group_name: string; + asset_type: string; + version_count: number; + latest_version_label: string | null; + tags: string[]; +} + +// ─── Tag colours ───────────────────────────────────────────────────────────── + +const TAG_COLORS: Record< + string, + 'success' | 'primary' | 'warning' | 'neutral' +> = { + latest: 'primary', + production: 'success', + draft: 'warning', +}; + +// ─── Helpers ───────────────────────────────────────────────────────────────── + +function formatDate(isoString: string | null): string { + if (!isoString) return '—'; + try { + const d = new Date(isoString); + return d.toLocaleDateString(undefined, { + year: 'numeric', + month: 'short', + day: 'numeric', + hour: '2-digit', + minute: '2-digit', + }); + } catch { + return isoString; + } +} + +function formatRelativeDate(isoString: string | null): string { + if (!isoString) return '—'; + try { + const d = new Date(isoString); + const now = new Date(); + const diffMs = now.getTime() - d.getTime(); + const diffDays = Math.floor(diffMs / (1000 * 60 * 60 * 24)); + if (diffDays === 0) return 'Today'; + if (diffDays === 1) return 'Yesterday'; + if (diffDays < 30) return `${diffDays}d ago`; + if (diffDays < 365) return `${Math.floor(diffDays / 30)}mo ago`; + return `${Math.floor(diffDays / 365)}y ago`; + } catch { + return isoString; + } +} + +// ─── Skeleton loader (matches LocalModelsTable 
pattern) ────────────────────── + +function RegistrySkeleton() { + return ( + <> + *': { minWidth: { xs: '120px', md: '160px' } }, + }} + > + + + + + + {[...Array(6)].map((_, idx) => ( + + ))} + + + + ); +} + +// ─── Version Info Drawer ───────────────────────────────────────────────────── + +function VersionInfoDrawer({ + open, + onClose, + entry, +}: { + open: boolean; + onClose: () => void; + entry: VersionEntry | null; +}) { + if (!entry) return null; + + return ( + + + + + + Version Details: {entry.version_label} + + + + + + + + + {/* Title */} + {entry.title && ( + + + Title + + {entry.title} + + )} + + {/* Description */} + {entry.description && ( + + + Description + + {entry.description} + + )} + + {/* Long description (markdown) */} + {entry.long_description && ( + + + Details + + + {entry.long_description} + + + )} + + {/* Cover image */} + {entry.cover_image && ( + + + Cover Image + + Cover + + )} + + {/* Dataset ID */} + + + Dataset ID + + + {entry.asset_id} + + + + {/* Tag */} + + + Tag + + {entry.tag ? ( + + {entry.tag} + + ) : ( + + — + + )} + + + {/* Created */} + + + Created + + + {formatDate(entry.created_at)} + + + + {/* Source Job */} + + + Source Job + + {entry.job_id ? ( + + +  Job {entry.job_id} + + ) : ( + + — + + )} + + + {/* Evals */} + {entry.evals && Object.keys(entry.evals).length > 0 && ( + + + Evaluations + + + + + + + + + + {Object.entries(entry.evals).map(([key, val]) => ( + + + + + ))} + +
+                 <thead>
+                   <tr>
+                     <th>Metric</th>
+                     <th>Value</th>
+                   </tr>
+                 </thead>
+                 <tbody>
+                   {Object.entries(entry.evals).map(([key, val]) => (
+                     <tr key={key}>
+                       <td>{key}</td>
+                       <td>{String(val)}</td>
+                     </tr>
+                   ))}
+                 </tbody>
+               </Table>
+             </Box>
+           )}
+         </Stack>
+       </Sheet>
+     </Drawer>
+ ); +} + +// ─── Version row (inline in the accordion) ─────────────────────────────────── + +function VersionRow({ + v, + updatingVersion, + onSetTag, + onClearTag, + onDelete, + onInfo, +}: { + v: VersionEntry; + updatingVersion: string | null; + onSetTag: (versionLabel: string, tag: string) => void; + onClearTag: (versionLabel: string) => void; + onDelete: (versionLabel: string) => void; + onInfo: (version: VersionEntry) => void; +}) { + return ( + + {/* Name / title */} + + + + {v.title || v.asset_id} + + + + {/* Dataset ID */} + + + {v.asset_id} + + + {/* Version */} + + + {v.version_label} + + + {/* Tag */} + + {updatingVersion === v.version_label ? ( + + ) : v.tag ? ( + onClearTag(v.version_label)} + sx={{ '--IconButton-size': '18px', ml: 0.5 }} + > + + + } + > + {v.tag} + + ) : ( + + )} + + {/* Job */} + + {v.job_id ? ( + + + +  {String(v.job_id).slice(0, 6)} + + + ) : ( + + — + + )} + + {/* Created */} + + + {formatRelativeDate(v.created_at)} + + + {/* Info + Delete (inline, no Actions header) */} + + onInfo(v)} + /> +   + onDelete(v.version_label)} + /> + + + ); +} + +// ─── Expanded group versions table ─────────────────────────────────────────── + +function GroupVersionsTable({ + groupName, + mutateGroups, + onOpenInfo, +}: { + groupName: string; + mutateGroups: () => void; + onOpenInfo: (v: VersionEntry) => void; +}) { + const [updatingVersion, setUpdatingVersion] = useState(null); + const assetType = 'dataset'; + + const { + data: versions, + isLoading, + mutate, + } = useSWR( + chatAPI.Endpoints.AssetVersions.ListVersions(assetType, groupName), + fetcher, + ); + + const handleSetTag = async (versionLabel: string, tag: string) => { + setUpdatingVersion(versionLabel); + try { + await fetchWithAuth( + chatAPI.Endpoints.AssetVersions.SetTag( + assetType, + groupName, + versionLabel, + ), + { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ tag }), + }, + ); + mutate(); + mutateGroups(); + } catch (error) { + console.error('Failed to set tag:', error); + } finally { + setUpdatingVersion(null); + } + }; + + const handleClearTag = async (versionLabel: string) => { + setUpdatingVersion(versionLabel); + try { + await fetchWithAuth( + chatAPI.Endpoints.AssetVersions.ClearTag( + assetType, + groupName, + versionLabel, + ), + { method: 'DELETE' }, + ); + mutate(); + mutateGroups(); + } catch (error) { + console.error('Failed to clear tag:', error); + } finally { + setUpdatingVersion(null); + } + }; + + const handleDeleteVersion = async (versionLabel: string) => { + if ( + !window.confirm( + `Delete version ${versionLabel} from group "${groupName}"? This will not delete the underlying dataset.`, + ) + ) { + return; + } + setUpdatingVersion(versionLabel); + try { + await fetchWithAuth( + chatAPI.Endpoints.AssetVersions.DeleteVersion( + assetType, + groupName, + versionLabel, + ), + { method: 'DELETE' }, + ); + mutate(); + mutateGroups(); + } catch (error) { + console.error('Failed to delete version:', error); + } finally { + setUpdatingVersion(null); + } + }; + + const versionList: VersionEntry[] = Array.isArray(versions) ? versions : []; + + if (isLoading) { + return ( + + + + ); + } + + if (versionList.length === 0) { + return ( + + No versions in this group. 
+         </Typography>
+       </Box>
+     );
+   }
+
+   return (
+     <Table
+       sx={{
+         '--TableCell-headBackground': (theme: any) =>
+           theme.vars.palette.background.level1,
+         '--Table-headerUnderlineThickness': '1px',
+         '--TableRow-hoverBackground': (theme: any) =>
+           theme.vars.palette.background.level1,
+         '& thead th': { textAlign: 'left' },
+         '& tbody td': { verticalAlign: 'middle' },
+       }}
+     >
+       <thead>
+         <tr>
+           <th>Name</th>
+           <th>Dataset ID</th>
+           <th>Version</th>
+           <th>Tag</th>
+           <th>Job</th>
+           <th>Created</th>
+           <th />
+         </tr>
+       </thead>
+       <tbody>
+         {versionList.map((v) => (
+           <VersionRow
+             key={v.id}
+             v={v}
+             updatingVersion={updatingVersion}
+             onSetTag={handleSetTag}
+             onClearTag={handleClearTag}
+             onDelete={handleDeleteVersion}
+             onInfo={onOpenInfo}
+           />
+         ))}
+       </tbody>
+     </Table>
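+     // Note: every mutation handler above revalidates both this group's
+     // version list (mutate) and the parent group summaries (mutateGroups),
+     // so version counts and latest-tag badges stay in sync across the UI.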
+ ); +} + +// ─── Main component ────────────────────────────────────────────────────────── + +export default function DatasetRegistry() { + const [expandedGroups, setExpandedGroups] = useState>(new Set()); + const [searchText, setSearchText] = useState(''); + const [infoDrawerEntry, setInfoDrawerEntry] = useState( + null, + ); + + const { + data: groups, + isLoading, + isError, + mutate: mutateGroups, + } = useSWR(chatAPI.Endpoints.AssetVersions.ListGroups('dataset'), fetcher); + + const handleDeleteGroup = async (groupName: string) => { + if ( + !window.confirm( + `Delete group "${groupName}" and ALL its versions? The underlying datasets will not be deleted.`, + ) + ) { + return; + } + try { + await fetchWithAuth( + chatAPI.Endpoints.AssetVersions.DeleteGroup('dataset', groupName), + { method: 'DELETE' }, + ); + mutateGroups(); + } catch (err) { + console.error('Failed to delete group:', err); + } + }; + + const toggleGroup = (groupName: string) => { + setExpandedGroups((prev) => { + const next = new Set(prev); + if (next.has(groupName)) { + next.delete(groupName); + } else { + next.add(groupName); + } + return next; + }); + }; + + if (isLoading) return ; + if (isError) { + return ( + + + Failed to load dataset registry groups. + + + ); + } + + const groupList: GroupSummary[] = Array.isArray(groups) ? groups : []; + + // Filter groups by search text + const filteredGroups = groupList.filter((g) => { + const search = searchText.toLowerCase(); + if (search && !g.group_name.toLowerCase().includes(search)) { + return false; + } + return true; + }); + + return ( + + {/* ── Top bar: matches LocalModelsTable ── */} + *': { + minWidth: { + xs: '120px', + md: '160px', + }, + }, + }} + > + +   + setSearchText(e.target.value)} + startDecorator={} + /> + + + +   + mutateGroups()} + aria-label="Refresh datasets" + > + +   Refresh Datasets + + + + + {/* ── Content ── */} + + {filteredGroups.length === 0 ? ( + + + + {searchText + ? 'No dataset groups match your search.' + : 'No dataset groups yet.'} + + {!searchText && ( + + Publish a dataset from the Jobs page to create your first group. + + )} + + ) : ( + + {filteredGroups.map((group) => { + const isExpanded = expandedGroups.has(group.group_name); + return ( + toggleGroup(group.group_name)} + sx={{ + borderRadius: 'md', + border: '1px solid', + borderColor: 'divider', + }} + > + } + sx={{ px: 2, py: 1.5 }} + > + + {/* Left side: Name + badges */} + + + + {group.group_name} + + + {group.version_count} version + {group.version_count !== 1 ? 
's' : ''} + + {group.latest_tag && ( + + {group.latest_tag} + + )} + + + {/* Right side: last updated */} + + {formatRelativeDate(group.latest_created_at)} + + + + + + {isExpanded && ( + setInfoDrawerEntry(v)} + /> + )} + + + ); + })} + + )} + + + {/* ── Version Info Drawer ── */} + setInfoDrawerEntry(null)} + entry={infoDrawerEntry} + /> + + ); +} diff --git a/src/renderer/components/Data/GeneratedDatasets.tsx b/src/renderer/components/Data/GeneratedDatasets.tsx index 41a71d120..3e86f5d25 100644 --- a/src/renderer/components/Data/GeneratedDatasets.tsx +++ b/src/renderer/components/Data/GeneratedDatasets.tsx @@ -129,6 +129,7 @@ export default function GeneratedDatasets() { downloaded={true} local={true} parentMutate={mutate} + versionGroups={row?.version_groups || []} /> ))} diff --git a/src/renderer/components/Data/LocalDatasets.tsx b/src/renderer/components/Data/LocalDatasets.tsx index 39595a5b0..8056e9e08 100644 --- a/src/renderer/components/Data/LocalDatasets.tsx +++ b/src/renderer/components/Data/LocalDatasets.tsx @@ -143,6 +143,7 @@ export default function LocalDatasets() { downloaded={true} local={true} parentMutate={mutate} + versionGroups={row?.version_groups || []} /> ))} diff --git a/src/renderer/components/Experiment/Tasks/SaveToRegistryDialog.tsx b/src/renderer/components/Experiment/Tasks/SaveToRegistryDialog.tsx index e2eb21e13..d93d0d2dc 100644 --- a/src/renderer/components/Experiment/Tasks/SaveToRegistryDialog.tsx +++ b/src/renderer/components/Experiment/Tasks/SaveToRegistryDialog.tsx @@ -13,8 +13,40 @@ import { Radio, Autocomplete, Box, + Chip, + Divider, + Option, + Select, + Textarea, } from '@mui/joy'; -import { Save } from 'lucide-react'; +import { LayersIcon, Save, TagIcon } from 'lucide-react'; +import { useSWRWithAuth as useSWR } from 'renderer/lib/authContext'; +import * as chatAPI from '../../../lib/transformerlab-api-sdk'; +import { fetcher } from '../../../lib/transformerlab-api-sdk'; + +// ─── Types ─────────────────────────────────────────────────────────────────── + +interface GroupSummary { + group_name: string; + asset_type: string; + version_count: number; + latest_version_label: string | null; +} + +export interface SaveVersionInfo { + /** The group name (either new or existing) */ + groupName: string; + /** Unique name for the asset in the registry folder */ + assetName: string; + /** 'new' = create a new group, 'existing' = add version to existing group */ + mode: 'new' | 'existing'; + /** Tag to assign to the new version */ + tag: string; + /** User-defined version label (e.g. 'v1', 'march-run') */ + versionLabel: string; + /** Human-readable description for the version */ + description: string; +} interface SaveToRegistryDialogProps { open: boolean; @@ -28,9 +60,28 @@ interface SaveToRegistryDialogProps { /** Whether the save is in progress */ saving: boolean; /** Called when the user confirms the save */ - onSave: (targetName: string, mode: 'new' | 'existing') => void; + onSave: (info: SaveVersionInfo) => void; + /** Job ID that produced this asset (optional, for display) */ + jobId?: string | number; + /** External error message to display on the asset name field (e.g. 
name already exists) */ + assetNameError?: string | null; } +// ─── Constants ─────────────────────────────────────────────────────────────── + +const TAG_COLORS: Record< + string, + 'success' | 'primary' | 'warning' | 'neutral' +> = { + latest: 'primary', + production: 'success', + draft: 'warning', +}; + +const TAG_OPTIONS = ['latest', 'production', 'draft']; + +// ─── Component ─────────────────────────────────────────────────────────────── + export default function SaveToRegistryDialog({ open, onClose, @@ -39,86 +90,255 @@ export default function SaveToRegistryDialog({ existingNames, saving, onSave, + jobId, + assetNameError: externalAssetNameError, }: SaveToRegistryDialogProps) { const [mode, setMode] = useState<'new' | 'existing'>('new'); const [newName, setNewName] = useState(sourceName); + const [assetName, setAssetName] = useState(sourceName); + const [assetNameError, setAssetNameError] = useState(null); const [existingTarget, setExistingTarget] = useState(null); + const [tag, setTag] = useState('latest'); + const [versionLabel, setVersionLabel] = useState('v1'); + const [description, setDescription] = useState(''); + + // Fetch existing groups from asset_versions API + const { data: groupsData } = useSWR( + open ? chatAPI.Endpoints.AssetVersions.ListGroups(type) : null, + fetcher, + ); + const groups: GroupSummary[] = Array.isArray(groupsData) ? groupsData : []; + const groupNames = groups.map((g) => g.group_name); + + // Find selected group info for "next version" display + const selectedGroup = + mode === 'existing' && existingTarget + ? groups.find((g) => g.group_name === existingTarget) + : null; + const latestVersionLabel = selectedGroup?.latest_version_label ?? null; // Reset state when opening useEffect(() => { if (open) { setMode('new'); setNewName(sourceName); + setAssetName(sourceName); + setAssetNameError(null); setExistingTarget(null); + setTag('latest'); + setVersionLabel('v1'); + setDescription(''); } }, [open, sourceName]); + // Sync external asset name error from parent (e.g. 409 conflict response) + useEffect(() => { + if (externalAssetNameError) { + setAssetNameError(externalAssetNameError); + } + }, [externalAssetNameError]); + const typeLabel = type === 'dataset' ? 'Dataset' : 'Model'; const canSave = - mode === 'new' + (mode === 'new' ? newName.trim().length > 0 - : existingTarget !== null && existingTarget.trim().length > 0; + : existingTarget !== null && existingTarget.trim().length > 0) && + assetName.trim().length > 0; const handleSubmit = () => { if (!canSave) return; - const targetName = mode === 'new' ? newName.trim() : existingTarget!; - onSave(targetName, mode); + setAssetNameError(null); + const groupName = mode === 'new' ? newName.trim() : existingTarget!; + onSave({ + groupName, + assetName: assetName.trim(), + mode, + tag, + versionLabel: versionLabel.trim() || 'v1', + description: + description.trim() || `Created from job ${jobId ?? 'unknown'}`, + }); }; return ( - + - Save {typeLabel} to Registry + + + Publish {typeLabel} to Registry + - Choose how to publish {sourceName} to the{' '} + Publish {sourceName} as a new versioned entry in the{' '} {typeLabel.toLowerCase()} registry. + {/* ── Group selection ── */} setMode(e.target.value as 'new' | 'existing')} sx={{ gap: 2 }} > - {/* Option 1: Save as new */} + {/* Option 1: Create new group */} - + {mode === 'new' && ( - {typeLabel} name + Group name setNewName(e.target.value)} - placeholder={`Enter a name for the new ${typeLabel.toLowerCase()}`} + placeholder={`e.g. 
my-${typeLabel.toLowerCase()}`} autoFocus /> + + This will be version 1 in the new group. + )} - {/* Option 2: Add to existing */} + {/* Option 2: Add version to existing group */} {mode === 'existing' && ( - Select existing {typeLabel.toLowerCase()} + Select group setExistingTarget(value)} - placeholder={`Search ${typeLabel.toLowerCase()}s…`} + placeholder={`Search groups…`} autoFocus /> + {selectedGroup && ( + + Currently has {selectedGroup.version_count} version + {selectedGroup.version_count !== 1 ? 's' : ''} + {latestVersionLabel + ? ` (latest: ${latestVersionLabel})` + : ''} + . + + )} )} + + + {/* ── Version metadata ── */} + + Version Details + + + + {/* Asset name (unique folder name in the registry) */} + + {typeLabel} Name + { + setAssetName(e.target.value); + setAssetNameError(null); + }} + placeholder={`Unique name for this ${typeLabel.toLowerCase()}`} + color={assetNameError ? 'danger' : undefined} + /> + {assetNameError ? ( + + {assetNameError} + + ) : ( + + The unique folder name used to store the{' '} + {typeLabel.toLowerCase()} in the registry. + + )} + + + {/* Version label */} + + Version Label + setVersionLabel(e.target.value)} + placeholder="e.g. v1, v1.2.3, march-run" + /> + + A human-readable label for this version. + + + + {/* Tag selector */} + + + + + Tag + + + + + The tag will be moved from any version that currently holds it. + + + + {/* Description */} + + Description +
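// ---------------------------------------------------------------------------
// Sketch (assumption, not part of the patch above): the components in this
// diff build request URLs via chatAPI.Endpoints.AssetVersions.*, but that
// helper is defined elsewhere in transformerlab-api-sdk and its concrete
// paths are not shown in this excerpt. The interface below records only what
// the call sites guarantee — the arities, and the HTTP methods passed to
// fetchWithAuth alongside each URL. Every name here is inferred, not
// confirmed against the SDK.
// ---------------------------------------------------------------------------

interface AssetVersionsEndpointsSketch {
  /** GET — list group summaries for an asset type ('dataset' or 'model') */
  ListGroups(assetType: string): string;
  /** GET — list all VersionEntry rows in one group */
  ListVersions(assetType: string, groupName: string): string;
  /** Used with PUT and JSON body { tag } — assign a tag to a version */
  SetTag(assetType: string, groupName: string, versionLabel: string): string;
  /** Used with DELETE — remove the tag from a version */
  ClearTag(assetType: string, groupName: string, versionLabel: string): string;
  /** Used with DELETE — remove one version record (not the underlying files) */
  DeleteVersion(
    assetType: string,
    groupName: string,
    versionLabel: string,
  ): string;
  /** Used with DELETE — remove a group and all of its version records */
  DeleteGroup(assetType: string, groupName: string): string;
}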