Skip to content

Commit b250a53

Browse files
author
bghira
committed
add a prompt library creator/editor
1 parent d6ba84f commit b250a53

File tree

10 files changed

+1006
-14
lines changed

10 files changed

+1006
-14
lines changed

simpletuner/simpletuner_sdk/server/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def _include_router_if_present(router: object) -> None:
267267
from .routes.fields import router as fields_router
268268
from .routes.hardware import router as hardware_router
269269
from .routes.models import router as models_router
270+
from .routes.prompt_libraries import router as prompt_libraries_router
270271
from .routes.publishing import router as publishing_router
271272
from .routes.system import router as system_router
272273
from .routes.training import router as training_router
@@ -281,6 +282,7 @@ def _include_router_if_present(router: object) -> None:
281282
caption_filters_router,
282283
checkpoints_router,
283284
configs_router,
285+
prompt_libraries_router,
284286
validation_router,
285287
training_router,
286288
web_router,
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Routes for managing validation prompt libraries."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import asdict
6+
from typing import Dict, Optional
7+
8+
from fastapi import APIRouter, HTTPException, status
9+
from pydantic import BaseModel
10+
11+
from simpletuner.simpletuner_sdk.server.services.prompt_library_service import PromptLibraryError, PromptLibraryService
12+
13+
router = APIRouter(prefix="/api/prompt-libraries", tags=["prompt_libraries"])
14+
15+
16+
class PromptLibraryPayload(BaseModel):
17+
entries: Dict[str, str]
18+
previous_filename: Optional[str] = None
19+
20+
21+
def _call_service(func, *args, **kwargs):
22+
try:
23+
return func(*args, **kwargs)
24+
except PromptLibraryError as exc:
25+
raise HTTPException(status_code=exc.status_code, detail=exc.message)
26+
except Exception as exc:
27+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc))
28+
29+
30+
def _get_service() -> PromptLibraryService:
31+
return PromptLibraryService()
32+
33+
34+
@router.get("/")
35+
async def list_prompt_libraries() -> Dict[str, object]:
36+
service = _get_service()
37+
records = _call_service(service.list_libraries)
38+
return {"libraries": [asdict(record) for record in records], "count": len(records)}
39+
40+
41+
@router.get("/{filename}")
42+
async def get_prompt_library(filename: str) -> Dict[str, object]:
43+
service = _get_service()
44+
result = _call_service(service.read_library, filename)
45+
return {"entries": result["entries"], "library": asdict(result["library"])}
46+
47+
48+
@router.put("/{filename}")
49+
async def save_prompt_library(filename: str, payload: PromptLibraryPayload) -> Dict[str, object]:
50+
service = _get_service()
51+
result = _call_service(service.save_library, filename, payload.entries, payload.previous_filename)
52+
return {"entries": result["entries"], "library": asdict(result["library"])}

simpletuner/simpletuner_sdk/server/services/field_registry/sections/validation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ def register_validation_fields(registry: "FieldRegistry") -> None:
516516
tooltip="See user_prompt_library.json.example for format",
517517
importance=ImportanceLevel.ADVANCED,
518518
order=4,
519+
custom_component="prompt_library_path",
519520
)
520521
)
521522

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""Service managing validation prompt library files."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
import logging
7+
import re
8+
from dataclasses import dataclass
9+
from datetime import datetime, timezone
10+
from pathlib import Path
11+
from typing import Any, Dict, List, Optional
12+
13+
from fastapi import status
14+
15+
from simpletuner.simpletuner_sdk.server.services.webui_state import WebUIStateStore
16+
from simpletuner.simpletuner_sdk.server.utils.paths import get_config_directory
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class PromptLibraryError(Exception):
22+
"""Domain error for prompt library operations."""
23+
24+
def __init__(self, message: str, status_code: int = status.HTTP_400_BAD_REQUEST) -> None:
25+
super().__init__(message)
26+
self.message = message
27+
self.status_code = status_code
28+
29+
30+
@dataclass
31+
class PromptLibraryRecord:
32+
filename: str
33+
relative_path: str
34+
absolute_path: str
35+
display_name: str
36+
library_name: str
37+
prompt_count: int
38+
updated_at: str
39+
40+
41+
class PromptLibraryService:
42+
"""Manages prompt library files stored under a configs directory."""
43+
44+
_PREFIX = "user_prompt_library"
45+
_LIBRARIES_SUBDIR = "validation_prompt_libraries"
46+
_FILENAME_PATTERN = re.compile(rf"^{_PREFIX}(?:-([A-Za-z0-9._-]+))?\.json$")
47+
48+
def __init__(self, config_dir: Optional[Path] = None, libraries_dir: Optional[Path] = None) -> None:
49+
self._config_dir = Path(config_dir) if config_dir else self._resolve_config_dir()
50+
self._config_dir.mkdir(parents=True, exist_ok=True)
51+
if libraries_dir:
52+
self._libraries_dir = Path(libraries_dir)
53+
else:
54+
self._libraries_dir = self._config_dir / self._LIBRARIES_SUBDIR
55+
self._libraries_dir.mkdir(parents=True, exist_ok=True)
56+
57+
def _resolve_config_dir(self) -> Path:
58+
try:
59+
defaults = WebUIStateStore().load_defaults()
60+
if defaults.configs_dir:
61+
candidate = Path(defaults.configs_dir).expanduser()
62+
candidate.mkdir(parents=True, exist_ok=True)
63+
return candidate
64+
except Exception as exc: # pragma: no cover - defensive guard
65+
logger.debug("Unable to resolve configs_dir from WebUI defaults", exc_info=exc)
66+
default = get_config_directory()
67+
default.mkdir(parents=True, exist_ok=True)
68+
return default
69+
70+
def _validate_filename(self, filename: str) -> str:
71+
if not filename:
72+
raise PromptLibraryError("Prompt library filename is required.")
73+
candidate = Path(filename).name
74+
match = self._FILENAME_PATTERN.fullmatch(candidate)
75+
if not match:
76+
raise PromptLibraryError(
77+
"Prompt library filenames must be user_prompt_library[-name].json and only contain letters, numbers, '.', '_', or '-'."
78+
)
79+
return candidate
80+
81+
def _normalize_entries(self, payload: Any) -> Dict[str, str]:
82+
if not isinstance(payload, dict):
83+
raise PromptLibraryError("Prompt library entries must be an object with ID -> prompt mappings.")
84+
normalized: Dict[str, str] = {}
85+
for key, value in payload.items():
86+
shortname = str(key).strip()
87+
if not shortname:
88+
continue
89+
prompt = "" if value is None else str(value)
90+
normalized[shortname] = prompt
91+
return normalized
92+
93+
def _load_entries(self, path: Path) -> Dict[str, str]:
94+
if not path.exists():
95+
raise PromptLibraryError(f"Prompt library '{path.name}' not found", status.HTTP_404_NOT_FOUND)
96+
try:
97+
with path.open("r", encoding="utf-8") as handle:
98+
payload = json.load(handle)
99+
except json.JSONDecodeError as exc:
100+
raise PromptLibraryError(f"Invalid JSON in '{path.name}': {exc}", status.HTTP_422_UNPROCESSABLE_CONTENT) from exc
101+
except OSError as exc:
102+
raise PromptLibraryError(f"Failed to read '{path.name}': {exc}", status.HTTP_500_INTERNAL_SERVER_ERROR) from exc
103+
return self._normalize_entries(payload)
104+
105+
def _build_metadata(self, path: Path, entries: Dict[str, str]) -> PromptLibraryRecord:
106+
match = self._FILENAME_PATTERN.fullmatch(path.name)
107+
library_name = match.group(1) if match else ""
108+
relative_path = self._relative_path(path)
109+
display_name = library_name if library_name else "default"
110+
absolute_path = str(path.resolve())
111+
updated_at = datetime.fromtimestamp(path.stat().st_mtime, tz=timezone.utc).isoformat()
112+
return PromptLibraryRecord(
113+
filename=path.name,
114+
relative_path=relative_path,
115+
absolute_path=absolute_path,
116+
display_name=display_name,
117+
library_name=library_name or "",
118+
prompt_count=len(entries),
119+
updated_at=updated_at,
120+
)
121+
122+
def _relative_path(self, path: Path) -> str:
123+
try:
124+
return path.relative_to(self._config_dir).as_posix()
125+
except Exception:
126+
return path.name
127+
128+
def list_libraries(self) -> List[PromptLibraryRecord]:
129+
records: List[PromptLibraryRecord] = []
130+
if not self._libraries_dir.exists():
131+
return records
132+
for path in sorted(self._libraries_dir.iterdir()):
133+
if not path.is_file():
134+
continue
135+
if not self._FILENAME_PATTERN.fullmatch(path.name):
136+
continue
137+
try:
138+
entries = self._load_entries(path)
139+
except PromptLibraryError as exc:
140+
logger.warning("Skipping prompt library '%s': %s", path.name, exc.message)
141+
continue
142+
records.append(self._build_metadata(path, entries))
143+
records.sort(key=lambda record: record.display_name.lower())
144+
return records
145+
146+
def read_library(self, filename: str) -> Dict[str, Any]:
147+
sanitized = self._validate_filename(filename)
148+
path = self._libraries_dir / sanitized
149+
entries = self._load_entries(path)
150+
metadata = self._build_metadata(path, entries)
151+
return {"entries": entries, "library": metadata}
152+
153+
def save_library(
154+
self,
155+
filename: str,
156+
entries: Dict[str, str],
157+
previous_filename: Optional[str] = None,
158+
) -> Dict[str, Any]:
159+
normalized = self._normalize_entries(entries)
160+
sanitized = self._validate_filename(filename)
161+
target = self._libraries_dir / sanitized
162+
target.parent.mkdir(parents=True, exist_ok=True)
163+
try:
164+
with target.open("w", encoding="utf-8") as handle:
165+
json.dump(normalized, handle, indent=2, ensure_ascii=False)
166+
handle.write("\n")
167+
except OSError as exc:
168+
raise PromptLibraryError(
169+
f"Failed to write '{target.name}': {exc}", status.HTTP_500_INTERNAL_SERVER_ERROR
170+
) from exc
171+
172+
if previous_filename:
173+
previous = self._validate_filename(previous_filename)
174+
if previous != sanitized:
175+
previous_path = self._libraries_dir / previous
176+
try:
177+
if previous_path.exists():
178+
previous_path.unlink()
179+
except OSError as exc:
180+
logger.warning("Could not remove old prompt library '%s': %s", previous, exc)
181+
182+
metadata = self._build_metadata(target, normalized)
183+
return {"entries": normalized, "library": metadata}

0 commit comments

Comments
 (0)