-
-
Notifications
You must be signed in to change notification settings - Fork 502
Add/model dataset group #1481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add/model dataset group #1481
Changes from 81 commits
b7f3e42
0085f7a
ff23303
cea7182
2932b65
a4a826d
929bddf
361ee4a
cf65d08
4d9b36c
7216854
048fad2
bc1d8c4
af276e8
4cc8e42
47e604e
df7059c
bd82eb0
c4ae4ee
d3b04b5
ff09e24
64f93fc
95654d4
e134fec
c9beb7f
cf0b63f
a768a2e
6a81ec5
486ba4d
5e5102a
6430b5d
3609538
3730526
d3f1820
de09a30
81d67e3
47cc2e6
381e4ea
5010b4f
0f1bf42
07e13ad
ba28351
359c223
6d39e5b
f6fddf8
4983e88
cf2d46c
9d47020
a532aa5
a232fda
573f5ca
5fe5ccf
65d42aa
5495c7f
a5a629d
db6ea3e
b72e103
9e00a8a
c1131b1
008bd0e
890dc37
576897a
e45e694
97f0809
d264f0c
8c9f65b
c6cfb5d
b8a4668
9d5e99c
8a86bde
d62d451
454be16
1e208a4
b6a540b
2e5109a
8af076f
6e0a768
1571740
a89875d
4bfac62
a93f186
dc283d0
435b474
8c0ff17
1ffac84
d88d566
03df35b
6cfcd2d
887f040
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,237 @@ | ||||||||||||
| """ | ||||||||||||
| asset_versions.py | ||||||||||||
|
|
||||||||||||
| API router for managing versioned groups of models and datasets. | ||||||||||||
| Groups are stored as JSON files under the ``asset_groups/`` directory. | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| from typing import Optional | ||||||||||||
|
|
||||||||||||
| from fastapi import APIRouter, HTTPException, Query | ||||||||||||
| from pydantic import BaseModel | ||||||||||||
|
|
||||||||||||
| from transformerlab.services import asset_version_service | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| router = APIRouter(prefix="/asset_versions", tags=["asset_versions"]) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Request / Response schemas ─────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class CreateVersionRequest(BaseModel): | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: CreateVersionRequest lacks validation on asset_type and group_name fieldsCreateVersionRequest lacks validation on asset_type and group_name fields. Any string is accepted by Pydantic. Add Literal type or field validator. View DetailsLocation: AnalysisCreateVersionRequest lacks validation on asset_type and group_name fields
How to reproducePOST /asset_versions/versions with body {"asset_type": "model", "group_name": "", "asset_id": "test"}. Empty group_name is accepted and creates a version with empty group.Patch Details-class CreateVersionRequest(BaseModel):
- asset_type: str # 'model' or 'dataset'
- group_name: str
+class CreateVersionRequest(BaseModel):
+ asset_type: Literal["model", "dataset"]
+ group_name: str = Field(..., min_length=1, max_length=255)AI Fix PromptTip: Reply with |
||||||||||||
| asset_type: str # 'model' or 'dataset' | ||||||||||||
| group_name: str | ||||||||||||
| asset_id: str | ||||||||||||
| version_label: str = "v1" | ||||||||||||
| job_id: Optional[str] = None | ||||||||||||
| description: Optional[str] = None | ||||||||||||
| title: Optional[str] = None | ||||||||||||
| long_description: Optional[str] = None | ||||||||||||
| cover_image: Optional[str] = None | ||||||||||||
| evals: Optional[dict] = None | ||||||||||||
| extra_metadata: Optional[dict] = None | ||||||||||||
| tag: Optional[str] = "latest" | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class SetTagRequest(BaseModel): | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: SetTagRequest has no validation on tag value at schema levelSetTagRequest has no validation on tag value at schema level. Arbitrary strings bypass Pydantic, relying only on service-layer check. Add Literal constraint to the schema. View DetailsLocation: AnalysisSetTagRequest has no validation on tag value at schema level
How to reproducePUT /asset_versions/versions/model/group/1/tag with body {"tag": "invalid"}. Pydantic accepts it; only the service raises ValueError.Patch Details-class SetTagRequest(BaseModel):
- tag: str # 'latest', 'production', 'draft'
+class SetTagRequest(BaseModel):
+ tag: Literal["latest", "production", "draft"]AI Fix PromptTip: Reply with |
||||||||||||
| tag: str | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class UpdateVersionRequest(BaseModel): | ||||||||||||
| description: Optional[str] = None | ||||||||||||
| title: Optional[str] = None | ||||||||||||
| long_description: Optional[str] = None | ||||||||||||
| cover_image: Optional[str] = None | ||||||||||||
| evals: Optional[dict] = None | ||||||||||||
| extra_metadata: Optional[dict] = None | ||||||||||||
| tag: Optional[str] = None | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Group endpoints ───────────────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.get("/groups", summary="List all version groups for a given asset type.") | ||||||||||||
| async def list_groups(asset_type: str = Query(..., description="'model' or 'dataset'")): | ||||||||||||
| try: | ||||||||||||
| return await asset_version_service.list_groups(asset_type) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.delete( | ||||||||||||
| "/groups/{asset_type}/{group_name}", | ||||||||||||
| summary="Delete all versions in a group.", | ||||||||||||
| ) | ||||||||||||
| async def delete_group(asset_type: str, group_name: str): | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Security: delete_group and delete_version have no authorization checksdelete_group and delete_version have no authorization checks. Any authenticated user can delete any group's versions. Add ownership or role-based access control. View DetailsLocation: Analysisdelete_group and delete_version have no authorization checks
How to reproduceAs any authenticated user, call DELETE /asset_versions/groups/model/some_group. The group is deleted regardless of who created it.AI Fix PromptTip: Reply with |
||||||||||||
| try: | ||||||||||||
| count = await asset_version_service.delete_group(asset_type, group_name) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
| return {"status": "success", "deleted_count": count} | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Version CRUD ───────────────────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.post("/versions", summary="Create a new version in a group.") | ||||||||||||
| async def create_version(body: CreateVersionRequest): | ||||||||||||
| try: | ||||||||||||
| result = await asset_version_service.create_version( | ||||||||||||
| asset_type=body.asset_type, | ||||||||||||
| group_name=body.group_name, | ||||||||||||
| asset_id=body.asset_id, | ||||||||||||
| version_label=body.version_label, | ||||||||||||
| job_id=body.job_id, | ||||||||||||
| description=body.description, | ||||||||||||
| title=body.title, | ||||||||||||
| long_description=body.long_description, | ||||||||||||
| cover_image=body.cover_image, | ||||||||||||
| evals=body.evals, | ||||||||||||
| extra_metadata=body.extra_metadata, | ||||||||||||
| tag=body.tag, | ||||||||||||
| ) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
| return result | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.get( | ||||||||||||
| "/versions/{asset_type}/{group_name}", | ||||||||||||
| summary="List all versions in a group.", | ||||||||||||
| ) | ||||||||||||
| async def list_versions(asset_type: str, group_name: str): | ||||||||||||
| try: | ||||||||||||
| return await asset_version_service.list_versions(asset_type, group_name) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.get( | ||||||||||||
| "/versions/{asset_type}/{group_name}/{version_label}", | ||||||||||||
| summary="Get a specific version by its label.", | ||||||||||||
| ) | ||||||||||||
| async def get_version(asset_type: str, group_name: str, version_label: str): | ||||||||||||
| try: | ||||||||||||
| result = await asset_version_service.get_version(asset_type, group_name, version_label) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
| if result is None: | ||||||||||||
| raise HTTPException(status_code=404, detail="Version not found") | ||||||||||||
| return result | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.delete( | ||||||||||||
| "/versions/{asset_type}/{group_name}/{version_label}", | ||||||||||||
| summary="Delete a specific version.", | ||||||||||||
| ) | ||||||||||||
| async def delete_version(asset_type: str, group_name: str, version_label: str): | ||||||||||||
| try: | ||||||||||||
| deleted = await asset_version_service.delete_version(asset_type, group_name, version_label) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
| if not deleted: | ||||||||||||
| raise HTTPException(status_code=404, detail="Version not found") | ||||||||||||
| return {"status": "success"} | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Version update ────────────────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.patch( | ||||||||||||
| "/versions/{asset_type}/{group_name}/{version_label}", | ||||||||||||
| summary="Update metadata or tag on a specific version.", | ||||||||||||
| ) | ||||||||||||
| async def update_version(asset_type: str, group_name: str, version_label: str, body: UpdateVersionRequest): | ||||||||||||
| # Build kwargs only for fields the caller actually sent (present in the JSON body). | ||||||||||||
| # This lets the service layer distinguish "not provided" from "set to null". | ||||||||||||
| raw = body.model_dump(exclude_unset=True) | ||||||||||||
|
|
||||||||||||
| # Map body fields to service kwargs using the sentinel pattern | ||||||||||||
| kwargs = {} | ||||||||||||
| for field in ( | ||||||||||||
| "description", | ||||||||||||
| "title", | ||||||||||||
| "long_description", | ||||||||||||
| "cover_image", | ||||||||||||
| "evals", | ||||||||||||
| "extra_metadata", | ||||||||||||
| "tag", | ||||||||||||
| ): | ||||||||||||
| if field in raw: | ||||||||||||
| kwargs[field] = raw[field] | ||||||||||||
|
|
||||||||||||
| try: | ||||||||||||
| result = await asset_version_service.update_version(asset_type, group_name, version_label, **kwargs) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
| if result is None: | ||||||||||||
| raise HTTPException(status_code=404, detail="Version not found") | ||||||||||||
| return result | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Tag management ────────────────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.put( | ||||||||||||
| "/versions/{asset_type}/{group_name}/{version_label}/tag", | ||||||||||||
| summary="Set a tag on a specific version. Moves the tag from any other version in the group.", | ||||||||||||
| ) | ||||||||||||
| async def set_tag(asset_type: str, group_name: str, version_label: str, body: SetTagRequest): | ||||||||||||
| try: | ||||||||||||
| result = await asset_version_service.set_tag(asset_type, group_name, version_label, body.tag) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
| if result is None: | ||||||||||||
| raise HTTPException(status_code=404, detail="Version not found") | ||||||||||||
| return result | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.delete( | ||||||||||||
| "/versions/{asset_type}/{group_name}/{version_label}/tag", | ||||||||||||
| summary="Clear the tag from a specific version.", | ||||||||||||
| ) | ||||||||||||
| async def clear_tag(asset_type: str, group_name: str, version_label: str): | ||||||||||||
| try: | ||||||||||||
| result = await asset_version_service.clear_tag(asset_type, group_name, version_label) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
| if result is None: | ||||||||||||
| raise HTTPException(status_code=404, detail="Version not found") | ||||||||||||
| return result | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Resolution ────────────────────────────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.get( | ||||||||||||
| "/resolve/{asset_type}/{group_name}", | ||||||||||||
| summary="Resolve a group to a specific version. Defaults to 'latest' tag.", | ||||||||||||
| ) | ||||||||||||
| async def resolve( | ||||||||||||
| asset_type: str, | ||||||||||||
| group_name: str, | ||||||||||||
| tag: Optional[str] = Query(None, description="Tag to resolve (any string, e.g. 'latest', 'production')"), | ||||||||||||
| version_label: Optional[str] = Query(None, description="Exact version label to resolve"), | ||||||||||||
| ): | ||||||||||||
| try: | ||||||||||||
| result = await asset_version_service.resolve(asset_type, group_name, tag=tag, version_label=version_label) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
| if result is None: | ||||||||||||
| raise HTTPException(status_code=404, detail="No matching version found") | ||||||||||||
| return result | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # ─── Bulk lookups (used by list views) ──────────────────────────────────────── | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @router.get( | ||||||||||||
| "/map/{asset_type}", | ||||||||||||
| summary="Get a map of asset_id -> group memberships for annotating list views.", | ||||||||||||
| ) | ||||||||||||
| async def get_asset_group_map(asset_type: str): | ||||||||||||
| try: | ||||||||||||
| return await asset_version_service.get_all_asset_group_map(asset_type) | ||||||||||||
| except ValueError as e: | ||||||||||||
| raise HTTPException(status_code=400, detail=str(e)) | ||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just adding this here what we discussed on Discord.