@@ -113,16 +113,24 @@ async def get_dataset_plan() -> DatasetPlanResponse:
113113
114114 # Get model_family from active config
115115 model_family = None
116+ model_flavour = None
116117 try :
117118 from simpletuner .simpletuner_sdk .server .services .configs_service import ConfigsService
118119
119120 configs_service = ConfigsService ()
120121 active_config = configs_service .get_active_config ()
121- model_family = active_config ["config" ].get ("model_family" ) or active_config ["config" ].get ("--model_family" )
122+ config_blob = active_config ["config" ]
123+ model_family = config_blob .get ("model_family" ) or config_blob .get ("--model_family" )
124+ model_flavour = config_blob .get ("model_flavour" ) or config_blob .get ("--model_flavour" )
122125 except Exception :
123126 pass
124127
125- validations = compute_validations (datasets , get_dataset_blueprints (), model_family = model_family )
128+ validations = compute_validations (
129+ datasets ,
130+ get_dataset_blueprints (),
131+ model_family = model_family ,
132+ model_flavour = model_flavour ,
133+ )
126134 return DatasetPlanResponse (
127135 datasets = datasets ,
128136 validations = validations ,
@@ -149,16 +157,24 @@ def _persist_plan(payload: DatasetPlanPayload) -> DatasetPlanResponse:
149157
150158 # Get model_family from active config
151159 model_family = None
160+ model_flavour = None
152161 try :
153162 from simpletuner .simpletuner_sdk .server .services .configs_service import ConfigsService
154163
155164 configs_service = ConfigsService ()
156165 active_config = configs_service .get_active_config ()
157- model_family = active_config ["config" ].get ("model_family" ) or active_config ["config" ].get ("--model_family" )
166+ config_blob = active_config ["config" ]
167+ model_family = config_blob .get ("model_family" ) or config_blob .get ("--model_family" )
168+ model_flavour = config_blob .get ("model_flavour" ) or config_blob .get ("--model_flavour" )
158169 except Exception :
159170 pass
160171
161- validations = compute_validations (datasets , get_dataset_blueprints (), model_family = model_family )
172+ validations = compute_validations (
173+ datasets ,
174+ get_dataset_blueprints (),
175+ model_family = model_family ,
176+ model_flavour = model_flavour ,
177+ )
162178 errors = [message for message in validations if message .level == "error" ]
163179 if errors :
164180 raise HTTPException (
@@ -460,17 +476,25 @@ async def create_dataset(dataset: Dict[str, Any]) -> Dict[str, Any]:
460476
461477 # Get model_family from active config
462478 model_family = None
479+ model_flavour = None
463480 try :
464481 from simpletuner .simpletuner_sdk .server .services .configs_service import ConfigsService
465482
466483 configs_service = ConfigsService ()
467484 active_config = configs_service .get_active_config ()
468- model_family = active_config ["config" ].get ("model_family" ) or active_config ["config" ].get ("--model_family" )
485+ config_blob = active_config ["config" ]
486+ model_family = config_blob .get ("model_family" ) or config_blob .get ("--model_family" )
487+ model_flavour = config_blob .get ("model_flavour" ) or config_blob .get ("--model_flavour" )
469488 except Exception :
470489 pass
471490
472491 # Validate the updated plan
473- validations = compute_validations (datasets , get_dataset_blueprints (), model_family = model_family )
492+ validations = compute_validations (
493+ datasets ,
494+ get_dataset_blueprints (),
495+ model_family = model_family ,
496+ model_flavour = model_flavour ,
497+ )
474498 errors = [v for v in validations if v .level == "error" ]
475499
476500 if errors :
@@ -512,17 +536,25 @@ async def update_dataset(dataset_id: str, dataset: Dict[str, Any]) -> Dict[str,
512536
513537 # Get model_family from active config
514538 model_family = None
539+ model_flavour = None
515540 try :
516541 from simpletuner .simpletuner_sdk .server .services .configs_service import ConfigsService
517542
518543 configs_service = ConfigsService ()
519544 active_config = configs_service .get_active_config ()
520- model_family = active_config .get ("model_family" ) or active_config .get ("--model_family" )
545+ config_blob = active_config ["config" ]
546+ model_family = config_blob .get ("model_family" ) or config_blob .get ("--model_family" )
547+ model_flavour = config_blob .get ("model_flavour" ) or config_blob .get ("--model_flavour" )
521548 except Exception :
522549 pass
523550
524551 # Validate the updated plan
525- validations = compute_validations (datasets , get_dataset_blueprints (), model_family = model_family )
552+ validations = compute_validations (
553+ datasets ,
554+ get_dataset_blueprints (),
555+ model_family = model_family ,
556+ model_flavour = model_flavour ,
557+ )
526558 errors = [v for v in validations if v .level == "error" ]
527559
528560 if errors :
0 commit comments