Skip to content

Commit a3ecc2d

Browse files
authored
Merge pull request #1841 from bghira/issue/1811
(#1811) disable prompt management options when i2v is selected
2 parents 99a1edc + 6553b14 commit a3ecc2d

File tree

10 files changed

+316
-17
lines changed

10 files changed

+316
-17
lines changed

simpletuner/helpers/models/common.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ class ModelFoundation(ABC):
151151
MAXIMUM_CANVAS_SIZE = None
152152
SUPPORTS_LORA = None
153153
SUPPORTS_CONTROLNET = None
154+
STRICT_I2V_FLAVOURS = tuple()
155+
STRICT_I2V_FOR_ALL_FLAVOURS = False
154156

155157
def __init__(self, config: dict, accelerator):
156158
self.config = config
@@ -181,6 +183,49 @@ def supports_controlnet(cls) -> bool:
181183
return bool(cls.SUPPORTS_CONTROLNET)
182184
return False
183185

186+
@classmethod
187+
def strict_i2v_flavours(cls):
188+
"""
189+
Return flavour identifiers that require strict image-to-video validation inputs.
190+
"""
191+
flavours = getattr(cls, "STRICT_I2V_FLAVOURS", tuple())
192+
if not flavours:
193+
return []
194+
if isinstance(flavours, (list, tuple, set, frozenset)):
195+
result = []
196+
for entry in flavours:
197+
try:
198+
value = str(entry).strip()
199+
except Exception:
200+
continue
201+
if value:
202+
result.append(value)
203+
return list(dict.fromkeys(result))
204+
if isinstance(flavours, str):
205+
trimmed = flavours.strip()
206+
return [trimmed] if trimmed else []
207+
return []
208+
209+
@classmethod
210+
def is_strict_i2v_flavour(cls, flavour) -> bool:
211+
"""
212+
Determine whether the provided flavour requires strict image-to-video validation inputs.
213+
"""
214+
if getattr(cls, "STRICT_I2V_FOR_ALL_FLAVOURS", False):
215+
return True
216+
if not flavour:
217+
return False
218+
try:
219+
candidate = str(flavour).strip().lower()
220+
except Exception:
221+
return False
222+
if not candidate:
223+
return False
224+
for entry in cls.strict_i2v_flavours():
225+
if entry.strip().lower() == candidate:
226+
return True
227+
return False
228+
184229
def log_model_devices(self):
185230
"""
186231
Log the devices of the model components.

simpletuner/helpers/models/wan/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ class Wan(VideoModelFoundation):
316316
"ti2v-5b-2.2",
317317
}
318318
)
319+
STRICT_I2V_FLAVOURS = tuple(sorted((I2V_FLAVOURS | FLF2V_FLAVOURS)))
319320

320321
def __init__(self, config, accelerator):
321322
super().__init__(config, accelerator)

simpletuner/simpletuner_sdk/server/routes/datasets.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,24 @@ async def get_dataset_plan() -> DatasetPlanResponse:
113113

114114
# Get model_family from active config
115115
model_family = None
116+
model_flavour = None
116117
try:
117118
from simpletuner.simpletuner_sdk.server.services.configs_service import ConfigsService
118119

119120
configs_service = ConfigsService()
120121
active_config = configs_service.get_active_config()
121-
model_family = active_config["config"].get("model_family") or active_config["config"].get("--model_family")
122+
config_blob = active_config["config"]
123+
model_family = config_blob.get("model_family") or config_blob.get("--model_family")
124+
model_flavour = config_blob.get("model_flavour") or config_blob.get("--model_flavour")
122125
except Exception:
123126
pass
124127

125-
validations = compute_validations(datasets, get_dataset_blueprints(), model_family=model_family)
128+
validations = compute_validations(
129+
datasets,
130+
get_dataset_blueprints(),
131+
model_family=model_family,
132+
model_flavour=model_flavour,
133+
)
126134
return DatasetPlanResponse(
127135
datasets=datasets,
128136
validations=validations,
@@ -149,16 +157,24 @@ def _persist_plan(payload: DatasetPlanPayload) -> DatasetPlanResponse:
149157

150158
# Get model_family from active config
151159
model_family = None
160+
model_flavour = None
152161
try:
153162
from simpletuner.simpletuner_sdk.server.services.configs_service import ConfigsService
154163

155164
configs_service = ConfigsService()
156165
active_config = configs_service.get_active_config()
157-
model_family = active_config["config"].get("model_family") or active_config["config"].get("--model_family")
166+
config_blob = active_config["config"]
167+
model_family = config_blob.get("model_family") or config_blob.get("--model_family")
168+
model_flavour = config_blob.get("model_flavour") or config_blob.get("--model_flavour")
158169
except Exception:
159170
pass
160171

161-
validations = compute_validations(datasets, get_dataset_blueprints(), model_family=model_family)
172+
validations = compute_validations(
173+
datasets,
174+
get_dataset_blueprints(),
175+
model_family=model_family,
176+
model_flavour=model_flavour,
177+
)
162178
errors = [message for message in validations if message.level == "error"]
163179
if errors:
164180
raise HTTPException(
@@ -460,17 +476,25 @@ async def create_dataset(dataset: Dict[str, Any]) -> Dict[str, Any]:
460476

461477
# Get model_family from active config
462478
model_family = None
479+
model_flavour = None
463480
try:
464481
from simpletuner.simpletuner_sdk.server.services.configs_service import ConfigsService
465482

466483
configs_service = ConfigsService()
467484
active_config = configs_service.get_active_config()
468-
model_family = active_config["config"].get("model_family") or active_config["config"].get("--model_family")
485+
config_blob = active_config["config"]
486+
model_family = config_blob.get("model_family") or config_blob.get("--model_family")
487+
model_flavour = config_blob.get("model_flavour") or config_blob.get("--model_flavour")
469488
except Exception:
470489
pass
471490

472491
# Validate the updated plan
473-
validations = compute_validations(datasets, get_dataset_blueprints(), model_family=model_family)
492+
validations = compute_validations(
493+
datasets,
494+
get_dataset_blueprints(),
495+
model_family=model_family,
496+
model_flavour=model_flavour,
497+
)
474498
errors = [v for v in validations if v.level == "error"]
475499

476500
if errors:
@@ -512,17 +536,25 @@ async def update_dataset(dataset_id: str, dataset: Dict[str, Any]) -> Dict[str,
512536

513537
# Get model_family from active config
514538
model_family = None
539+
model_flavour = None
515540
try:
516541
from simpletuner.simpletuner_sdk.server.services.configs_service import ConfigsService
517542

518543
configs_service = ConfigsService()
519544
active_config = configs_service.get_active_config()
520-
model_family = active_config.get("model_family") or active_config.get("--model_family")
545+
config_blob = active_config["config"]
546+
model_family = config_blob.get("model_family") or config_blob.get("--model_family")
547+
model_flavour = config_blob.get("model_flavour") or config_blob.get("--model_flavour")
521548
except Exception:
522549
pass
523550

524551
# Validate the updated plan
525-
validations = compute_validations(datasets, get_dataset_blueprints(), model_family=model_family)
552+
validations = compute_validations(
553+
datasets,
554+
get_dataset_blueprints(),
555+
model_family=model_family,
556+
model_flavour=model_flavour,
557+
)
526558
errors = [v for v in validations if v.level == "error"]
527559

528560
if errors:

simpletuner/simpletuner_sdk/server/services/dataset_plan.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def compute_validations(
8989
datasets: List[Dict[str, Any]],
9090
blueprints: Optional[List[BackendBlueprint]] = None,
9191
model_family: Optional[str] = None,
92+
model_flavour: Optional[str] = None,
9293
) -> List[ValidationMessage]:
9394
"""Perform lightweight validation mirroring the UI logic."""
9495
validations: List[ValidationMessage] = []
@@ -149,6 +150,7 @@ def compute_validations(
149150

150151
# Check if this is a video-only model
151152
is_video_model = False
153+
requires_strict_video_inputs = False
152154
if model_family:
153155
try:
154156
from simpletuner.helpers.models.common import VideoModelFoundation
@@ -157,6 +159,8 @@ def compute_validations(
157159
model_cls = ModelRegistry.get(model_family)
158160
if model_cls:
159161
is_video_model = issubclass(model_cls, VideoModelFoundation)
162+
if hasattr(model_cls, "is_strict_i2v_flavour") and callable(model_cls.is_strict_i2v_flavour):
163+
requires_strict_video_inputs = bool(model_cls.is_strict_i2v_flavour(model_flavour))
160164
except Exception:
161165
pass
162166

@@ -167,13 +171,22 @@ def compute_validations(
167171

168172
if is_video_model:
169173
if video_count == 0:
170-
validations.append(
171-
ValidationMessage(
172-
field="datasets",
173-
message="at least one video dataset is required for video models",
174-
level="error",
174+
if requires_strict_video_inputs:
175+
validations.append(
176+
ValidationMessage(
177+
field="datasets",
178+
message="strict image-to-video flavours require at least one video dataset",
179+
level="error",
180+
)
181+
)
182+
elif image_count == 0:
183+
validations.append(
184+
ValidationMessage(
185+
field="datasets",
186+
message="add at least one image or video dataset for this model",
187+
level="error",
188+
)
175189
)
176-
)
177190
else:
178191
if image_count == 0:
179192
validations.append(

simpletuner/simpletuner_sdk/server/services/models_service.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,14 @@ def get_model_details(self, model_family: str) -> Dict[str, Any]:
154154
),
155155
"is_video_model": issubclass(model_cls, VideoModelFoundation),
156156
}
157+
strict_i2v_flavours: list[str] = []
158+
try:
159+
if hasattr(model_cls, "strict_i2v_flavours") and callable(model_cls.strict_i2v_flavours):
160+
strict_i2v_flavours = list(model_cls.strict_i2v_flavours())
161+
except Exception:
162+
strict_i2v_flavours = []
163+
capabilities["strict_i2v_flavours"] = strict_i2v_flavours
164+
capabilities["strict_i2v_all_flavours"] = bool(getattr(model_cls, "STRICT_I2V_FOR_ALL_FLAVOURS", False))
157165

158166
default_flavour = getattr(model_cls, "DEFAULT_MODEL_FLAVOUR", None)
159167
if default_flavour is None:

simpletuner/static/css/trainer.css

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,3 +1237,42 @@ optgroup option {
12371237
.config-card:first-child::before {
12381238
display: none;
12391239
}
1240+
1241+
.prompt-overlay {
1242+
display: none;
1243+
position: absolute;
1244+
inset: 0;
1245+
background: rgba(12, 15, 25, 0.55);
1246+
border-radius: 0.75rem;
1247+
display: flex;
1248+
align-items: center;
1249+
justify-content: center;
1250+
padding: 1.25rem;
1251+
z-index: 2;
1252+
backdrop-filter: blur(0.4rem);
1253+
}
1254+
1255+
.form-section.prompt-management-disabled {
1256+
position: relative;
1257+
}
1258+
1259+
.form-section.prompt-management-disabled .prompt-overlay {
1260+
display: flex;
1261+
}
1262+
1263+
.form-section.prompt-management-disabled .prompt-overlay-content {
1264+
display: flex;
1265+
gap: 0.75rem;
1266+
color: #fff;
1267+
text-align: left;
1268+
}
1269+
1270+
.form-section.prompt-management-disabled .prompt-overlay-content strong {
1271+
display: block;
1272+
margin-bottom: 0.25rem;
1273+
}
1274+
1275+
.form-section.prompt-management-disabled .prompt-overlay-content i {
1276+
font-size: 1.5rem;
1277+
}
1278+
}

simpletuner/templates/components/action_buttons_htmx.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@
9797
data-bs-toggle="modal"
9898
data-bs-target="#shutdownConfirmModal">
9999
<i class="fas fa-power-off me-1"></i>
100-
<span class="btn-text">&#x23FB; Shutdown</span>
100+
<span class="btn-text">Shutdown</span>
101101
</button>
102102
</div>
103103
</div>

0 commit comments

Comments
 (0)