diff --git a/simpletuner/helpers/configuration/cmd_args.py b/simpletuner/helpers/configuration/cmd_args.py
index c2569b6ac..93a7a6c56 100644
--- a/simpletuner/helpers/configuration/cmd_args.py
+++ b/simpletuner/helpers/configuration/cmd_args.py
@@ -837,6 +837,47 @@ def _set_tf32(enabled: bool) -> None:
     elif args.sana_complex_human_instruction == "None":
         args.sana_complex_human_instruction = None
 
+    if isinstance(getattr(args, "validation_adapter_path", None), str):
+        candidate = args.validation_adapter_path.strip()
+        args.validation_adapter_path = candidate or None
+
+    if getattr(args, "validation_adapter_config", None):
+        args.validation_adapter_config = _parse_json_like_option(
+            args.validation_adapter_config,
+            "--validation_adapter_config",
+        )
+
+    if args.validation_adapter_path and args.validation_adapter_config:
+        raise ValueError("Provide either --validation_adapter_path or --validation_adapter_config, not both.")
+
+    if isinstance(getattr(args, "validation_adapter_name", None), str):
+        candidate = args.validation_adapter_name.strip()
+        args.validation_adapter_name = candidate or None
+
+    strength_value = getattr(args, "validation_adapter_strength", None)
+    if strength_value is None or strength_value in ("", "None"):
+        args.validation_adapter_strength = 1.0
+    else:
+        try:
+            strength = float(strength_value)
+        except (TypeError, ValueError):
+            raise ValueError(f"Invalid --validation_adapter_strength value: {strength_value}") from None
+        if strength <= 0:
+            raise ValueError("--validation_adapter_strength must be greater than 0.")
+        args.validation_adapter_strength = strength
+
+    mode_value = getattr(args, "validation_adapter_mode", None)
+    if mode_value in (None, "", "None"):
+        args.validation_adapter_mode = "adapter_only"
+    else:
+        normalized_mode = str(mode_value).strip().lower()
+        valid_modes = {"adapter_only", "comparison", "none"}
+        if normalized_mode not in valid_modes:
+            raise ValueError(
+                f"Invalid --validation_adapter_mode '{mode_value}'. Expected one of: {', '.join(sorted(valid_modes))}."
+            )
+        args.validation_adapter_mode = normalized_mode
+
     if args.attention_mechanism != "diffusers" and not torch.cuda.is_available():
         warning_log("For non-CUDA systems, only Diffusers attention mechanism is officially supported.")
 
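Reviewer note: a minimal sketch of the strength/mode rules added above, for reference. The `normalize` helper and the example values are illustrative, not part of the patch:

    def normalize(strength, mode):
        # Empty strings and the literal "None" collapse to defaults, as in the hunk above.
        strength = 1.0 if strength in (None, "", "None") else float(strength)
        if strength <= 0:
            raise ValueError("--validation_adapter_strength must be greater than 0.")
        mode = "adapter_only" if mode in (None, "", "None") else str(mode).strip().lower()
        if mode not in {"adapter_only", "comparison", "none"}:
            raise ValueError(f"Invalid --validation_adapter_mode '{mode}'.")
        return strength, mode

    assert normalize("", "Comparison") == (1.0, "comparison")   # case-insensitive mode
    assert normalize("0.75", None) == (0.75, "adapter_only")    # default mode
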
diff --git a/simpletuner/helpers/models/chroma/pipeline.py b/simpletuner/helpers/models/chroma/pipeline.py
index a49a6128f..75c8e07dd 100644
--- a/simpletuner/helpers/models/chroma/pipeline.py
+++ b/simpletuner/helpers/models/chroma/pipeline.py
@@ -45,6 +45,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
 
 from simpletuner.helpers.models.chroma.transformer import ChromaTransformer2DModel
+from simpletuner.helpers.utils.offloading import restore_offload_state, unpack_offload_state
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -423,7 +424,12 @@ def load_lora_into_controlnet(
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
@@ -447,7 +453,8 @@ def load_lora_into_controlnet(
             if warn_msg:
                 logger.warning(warn_msg)
 
-        cls._optionally_enable_offloading(is_model_cpu_offload, is_sequential_cpu_offload, _pipeline)
+        # Offload back.
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     @classmethod
     def load_lora_into_transformer(
@@ -527,7 +534,12 @@ def load_lora_into_transformer(
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
@@ -551,7 +563,8 @@ def load_lora_into_transformer(
             if warn_msg:
                 logger.warning(warn_msg)
 
-        cls._optionally_enable_offloading(is_model_cpu_offload, is_sequential_cpu_offload, _pipeline)
+        # Offload back.
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
     def load_lora_into_text_encoder(
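Reviewer note: the tuple dance above presumably guards against diffusers builds where `_optionally_disable_offloading` returns two flags rather than three; `unpack_offload_state` (added to helpers/utils/offloading.py later in this diff) pads either shape to three booleans. A quick sketch:

    from simpletuner.helpers.utils.offloading import unpack_offload_state

    assert unpack_offload_state((True, False)) == (True, False, False)          # legacy 2-tuple
    assert unpack_offload_state((False, True, False)) == (False, True, False)   # 3-tuple passthrough
    assert unpack_offload_state(False) == (False, False, False)                 # non-tuple fallback
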
diff --git a/simpletuner/helpers/models/flux/pipeline.py b/simpletuner/helpers/models/flux/pipeline.py
index e26ee33c0..9f0f702a2 100644
--- a/simpletuner/helpers/models/flux/pipeline.py
+++ b/simpletuner/helpers/models/flux/pipeline.py
@@ -66,6 +66,8 @@
     T5TokenizerFast,
 )
 
+from simpletuner.helpers.utils.offloading import restore_offload_state, unpack_offload_state
+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -521,7 +523,12 @@ def load_lora_into_controlnet(
             adapter_name = get_adapter_name(controlnet)
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
@@ -534,10 +541,7 @@ def load_lora_into_controlnet(
             logger.info(f"Loaded ControlNet LoRA with incompatible keys: {incompatible_keys}")
 
         # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     @classmethod
     def load_lora_into_transformer(
@@ -617,7 +621,12 @@ def load_lora_into_transformer(
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
@@ -652,10 +661,7 @@ def load_lora_into_transformer(
                 logger.warning(warn_msg)
 
         # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     # Unsafe code />
     @classmethod
@@ -769,7 +775,12 @@ def load_lora_into_text_encoder(
         if adapter_name is None:
            adapter_name = get_adapter_name(text_encoder)
 
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
@@ -786,10 +797,8 @@ def load_lora_into_text_encoder(
             text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
 
         # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
+
     # Unsafe code />
 
     @classmethod
diff --git a/simpletuner/helpers/models/hidream/pipeline.py b/simpletuner/helpers/models/hidream/pipeline.py
index 4b6e95d9c..b725214a2 100644
--- a/simpletuner/helpers/models/hidream/pipeline.py
+++ b/simpletuner/helpers/models/hidream/pipeline.py
@@ -46,6 +46,7 @@
 )
 
 from simpletuner.helpers.models.hidream.schedule import FlowUniPCMultistepScheduler
+from simpletuner.helpers.utils.offloading import restore_offload_state, unpack_offload_state
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -486,7 +487,12 @@ def load_lora_into_controlnet(
             adapter_name = get_adapter_name(controlnet)
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
@@ -499,10 +505,7 @@ def load_lora_into_controlnet(
             logger.info(f"Loaded ControlNet LoRA with incompatible keys: {incompatible_keys}")
 
         # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     @classmethod
     def load_lora_into_transformer(
@@ -638,7 +641,12 @@ def load_lora_into_text_encoder(
         if adapter_name is None:
             adapter_name = get_adapter_name(text_encoder)
 
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
@@ -655,10 +663,7 @@ def load_lora_into_text_encoder(
             text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
 
         # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     @classmethod
     def save_lora_weights(
diff --git a/simpletuner/helpers/models/sd3/pipeline.py b/simpletuner/helpers/models/sd3/pipeline.py
index a3509842d..6f9639ede 100644
--- a/simpletuner/helpers/models/sd3/pipeline.py
+++ b/simpletuner/helpers/models/sd3/pipeline.py
@@ -53,6 +53,8 @@
 from huggingface_hub.utils import validate_hf_hub_args
 from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
+from simpletuner.helpers.utils.offloading import restore_offload_state, unpack_offload_state
+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -498,7 +500,12 @@ def load_lora_into_controlnet(
             adapter_name = get_adapter_name(controlnet)
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
@@ -511,10 +518,7 @@ def load_lora_into_controlnet(
             logger.info(f"Loaded ControlNet LoRA with incompatible keys: {incompatible_keys}")
 
         # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     @classmethod
     def load_lora_into_transformer(
@@ -584,7 +588,12 @@ def load_lora_into_transformer(
 
         # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
         # otherwise loading LoRA weights will lead to an error
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         peft_kwargs = {}
         if is_peft_version(">=", "0.13.1"):
@@ -619,10 +628,7 @@ def load_lora_into_transformer(
                 logger.warning(warn_msg)
 
         # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     # Unsafe code />
     @classmethod
@@ -736,7 +742,12 @@ def load_lora_into_text_encoder(
         if adapter_name is None:
             adapter_name = get_adapter_name(text_encoder)
 
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+        offload_state = cls._optionally_disable_offloading(_pipeline)
+        (
+            is_model_cpu_offload,
+            is_sequential_cpu_offload,
+            is_group_offload,
+        ) = unpack_offload_state(offload_state)
 
         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
@@ -753,10 +764,7 @@ def load_lora_into_text_encoder(
             text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
 
         # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
+        restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
     # Unsafe code />
 
     @classmethod
diff --git a/simpletuner/helpers/training/validation.py b/simpletuner/helpers/training/validation.py
index 4766bb0af..6d87707ed 100644
--- a/simpletuner/helpers/training/validation.py
+++ b/simpletuner/helpers/training/validation.py
@@ -3,6 +3,7 @@
 import logging
 import os
 import sys
+from contextlib import contextmanager
 from io import BytesIO
 from typing import Union
 
@@ -39,6 +40,11 @@
 from simpletuner.helpers.training.deepspeed import deepspeed_zero_init_disabled_context_manager, prepare_model_for_deepspeed
 from simpletuner.helpers.training.exceptions import MultiDatasetExhausted
 from simpletuner.helpers.training.state_tracker import StateTracker
+from simpletuner.helpers.training.validation_adapters import (
+    ValidationAdapterRun,
+    ValidationAdapterSpec,
+    build_validation_adapter_runs,
+)
 
 logger = logging.getLogger("Validation")
 from simpletuner.helpers.training.multi_process import should_log
@@ -755,6 +761,14 @@ def __init__(
         logger.debug(f"Using model evaluator: {self.model_evaluator}")
         self._update_state()
         self.eval_scores = {}
+        self.validation_adapter_runs = build_validation_adapter_runs(
+            getattr(args, "validation_adapter_path", None),
+            getattr(args, "validation_adapter_config", None),
+            adapter_name=getattr(args, "validation_adapter_name", None),
+            adapter_strength=float(getattr(args, "validation_adapter_strength", 1.0) or 1.0),
+            adapter_mode=getattr(args, "validation_adapter_mode", None),
+        )
+        self._active_adapter_run: ValidationAdapterRun | None = None
 
     def _validation_seed_source(self):
         if self.config.validation_seed_source == "gpu":
@@ -1131,7 +1145,19 @@ def run_validations(
                 self.validation_images = None
                 return self
         self.setup_scheduler()
-        self.process_prompts(validation_type=validation_type)
+        master_validation_images: dict = {}
+        self.validation_prompt_dict = {}
+        self.validation_video_paths.clear()
+        self.eval_scores = {}
+        for adapter_run in self.validation_adapter_runs:
+            self._log_adapter_run(adapter_run)
+            with self._temporary_validation_adapters(adapter_run):
+                self.process_prompts(
+                    validation_type=validation_type,
+                    adapter_run=adapter_run,
+                    image_accumulator=master_validation_images,
+                )
+        self.validation_images = master_validation_images
         self.finalize_validation(validation_type)
         if self.evaluation_result is not None:
             logger.info(f"Evaluation result: {self.evaluation_result}")
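Reviewer note: the loop above now runs one full prompt pass per adapter run, accumulating everything into a single image dict. The contract, sketched (assuming `validator` is a constructed Validation instance):

    for run in validator.validation_adapter_runs:
        with validator._temporary_validation_adapters(run):
            ...  # adapters loaded and scaled on entry; deleted and verified detached on exit
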
@@ -1298,15 +1324,144 @@ def clean_pipeline(self):
             del self.model.pipeline
             self.model.pipeline = None
 
-    def process_prompts(self, validation_type: str = None):
+    def _has_adapter_variants(self) -> bool:
+        return any(run.adapters for run in self.validation_adapter_runs if not run.is_base)
+
+    def _decorate_shortname(self, shortname: str, adapter_run: ValidationAdapterRun | None) -> str:
+        if adapter_run is None or adapter_run.is_base:
+            return shortname
+        suffix = adapter_run.slug
+        if not shortname:
+            return suffix
+        return f"{shortname}__{suffix}"
+
+    def _log_adapter_run(self, adapter_run: ValidationAdapterRun):
+        if adapter_run.is_base and not self._has_adapter_variants():
+            return
+        if adapter_run.is_base:
+            logger.info("Running validation without additional adapters.")
+            return
+        logger.info(
+            "Running validation with adapter set '%s' containing %d adapter(s).",
+            adapter_run.label,
+            len(adapter_run.adapters),
+        )
+
+    def _next_adapter_name(
+        self, adapter_run: ValidationAdapterRun, adapter_spec: ValidationAdapterSpec, idx: int, existing: list[str]
+    ) -> str:
+        base_name = adapter_spec.adapter_name or (adapter_run.slug or f"validation_adapter_{idx}")
+        candidate = base_name.strip() or f"validation_adapter_{idx}"
+        suffix = 2
+        while candidate in existing:
+            candidate = f"{base_name}_{suffix}"
+            suffix += 1
+        return candidate
+
+    @contextmanager
+    def _temporary_validation_adapters(self, adapter_run: ValidationAdapterRun):
+        if adapter_run is None or not adapter_run.adapters:
+            yield
+            return
+        pipeline = getattr(self.model, "pipeline", None)
+        if pipeline is None:
+            yield
+            return
+        if not hasattr(pipeline, "load_lora_weights"):
+            raise ValueError(
+                "The current pipeline does not support loading LoRA adapters. "
+                "Remove --validation_adapter_path/--validation_adapter_config to continue."
+            )
+        adapter_names: list[str] = []
+        adapter_scales: list[float] = []
+        for idx, adapter in enumerate(adapter_run.adapters):
+            adapter_name = self._next_adapter_name(adapter_run, adapter, idx, adapter_names)
+            load_kwargs = {"adapter_name": adapter_name}
+            if adapter.weight_name:
+                load_kwargs["weight_name"] = adapter.weight_name
+            try:
+                if adapter.is_local:
+                    pipeline.load_lora_weights(adapter.location, **load_kwargs)
+                else:
+                    pipeline.load_lora_weights(adapter.repo_id, **load_kwargs)
+            except Exception as exc:  # pragma: no cover - defensive log
+                logger.error("Failed to load validation adapter '%s': %s", adapter.location, exc)
+                raise
+            adapter_names.append(adapter_name)
+            adapter_scales.append(adapter.strength)
+        self._set_validation_adapter_weights(pipeline, adapter_names, adapter_scales)
+        try:
+            yield
+        finally:
+            self._remove_validation_adapters(pipeline, adapter_names)
+
+    def _set_validation_adapter_weights(self, pipeline, adapter_names: list[str], adapter_scales: list[float]):
+        if not adapter_names:
+            return
+        if hasattr(pipeline, "set_adapters"):
+            names = adapter_names if len(adapter_names) > 1 else adapter_names[0]
+            scales = adapter_scales if len(adapter_scales) > 1 else adapter_scales[0]
+            pipeline.set_adapters(names, scales)
+        elif hasattr(pipeline, "set_adapter"):
+            pipeline.set_adapter(adapter_names[0], adapter_scales[0])
+        else:
+            logger.warning("Pipeline does not expose set_adapters; using adapter defaults.")
+
+    def _remove_validation_adapters(self, pipeline, adapter_names: list[str]):
+        if not adapter_names:
+            return
+        if hasattr(pipeline, "delete_adapters"):
+            names = adapter_names if len(adapter_names) > 1 else adapter_names[0]
+            pipeline.delete_adapters(names)
+        else:
+            logger.warning("Could not delete temporary validation adapters: %s", adapter_names)
+        self._assert_adapters_detached(pipeline, adapter_names)
+
+    def _assert_adapters_detached(self, pipeline, adapter_names: list[str]):
+        if pipeline is None or not adapter_names:
+            return
+        lingering: set[str] = set()
+        components = []
+        if hasattr(pipeline, "components") and isinstance(pipeline.components, dict):
+            components.extend(pipeline.components.values())
+        for attr in (
+            "transformer",
+            "text_encoder",
+            "text_encoder_2",
+            "text_encoder_3",
+            "text_encoder_4",
+            "controlnet",
+            "unet",
+        ):
+            component = getattr(pipeline, attr, None)
+            if component is not None:
+                components.append(component)
+        for module in components:
+            config = getattr(module, "peft_config", None)
+            if not isinstance(config, dict):
+                continue
+            for name in adapter_names:
+                if name in config:
+                    lingering.add(name)
+        if lingering:
+            raise RuntimeError(
+                f"Failed to detach temporary validation adapters: {', '.join(sorted(lingering))}. "
+                "Please ensure your pipeline supports adapter removal."
+            )
+
+    def process_prompts(
+        self,
+        validation_type: str = None,
+        adapter_run: ValidationAdapterRun | None = None,
+        image_accumulator: dict | None = None,
+    ):
         """Processes each validation prompt and logs the result."""
-        self.validation_prompt_dict = {}
         self.evaluation_result = None
-        validation_images = {}
+        if self.validation_prompt_dict is None:
+            self.validation_prompt_dict = {}
+        validation_images = image_accumulator if image_accumulator is not None else {}
         _content = self.validation_prompt_metadata["validation_prompts"]
         total_samples = len(_content) if _content is not None else 0
-        self.eval_scores = {}
-        self.validation_video_paths.clear()
         if self.validation_image_inputs:
             # Override the pipeline inputs to be entirely based upon the validation image inputs.
             _content = self.validation_image_inputs
@@ -1344,23 +1499,24 @@ def process_prompts(self, validation_type: str = None):
                 ) = prompt
             else:
                 shortname = self.validation_prompt_metadata["validation_shortnames"][idx]
-            logger.debug(f"validation prompt (shortname={shortname}): '{prompt}'")
-            self.validation_prompt_dict[shortname] = prompt
+            decorated_shortname = self._decorate_shortname(shortname, adapter_run)
+            logger.debug(f"validation prompt (shortname={decorated_shortname}): '{prompt}'")
+            self.validation_prompt_dict[decorated_shortname] = prompt
             logger.debug(f"Processing validation for prompt: {prompt}")
             (
                 stitched_validation_images,
                 checkpoint_validation_images,
                 ema_validation_images,
-            ) = self.validate_prompt(prompt, shortname, validation_input_image, validation_type)
+            ) = self.validate_prompt(prompt, decorated_shortname, validation_input_image, validation_type)
             validation_images.update(stitched_validation_images)
             if isinstance(self.model, VideoModelFoundation):
-                self._save_videos(validation_images, shortname, prompt)
+                self._save_videos(validation_images, decorated_shortname, prompt)
             else:
-                self._save_images(validation_images, shortname, prompt)
+                self._save_images(validation_images, decorated_shortname, prompt)
             logger.debug(f"Completed generating image: {prompt}")
             self.validation_images = validation_images
             self.evaluation_result = self.evaluate_images(checkpoint_validation_images)
-            self._log_validations_to_webhook(validation_images, shortname, prompt)
+            self._log_validations_to_webhook(validation_images, decorated_shortname, prompt)
             idx += 1
         try:
             self._log_validations_to_trackers(validation_images)
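Reviewer note: for reference before the new module, these are the config shapes its `_iter_config_entries`/`_normalize_run_entry` accept (field names are real; repo ids and paths are illustrative):

    # 1. a bare list of entries, each a string or a run object
    ["org/style-lora:style.safetensors"]
    # 2. a single run object with one adapter
    {"label": "detail", "path": "org/detail-lora", "strength": 1.2}
    # 3. a wrapper with multiple runs, each stacking one or more adapters
    {"runs": [{"label": "stacked", "adapters": [{"path": "org/style-lora", "scale": 0.8}]}]}
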
diff --git a/simpletuner/helpers/training/validation_adapters.py b/simpletuner/helpers/training/validation_adapters.py
new file mode 100644
index 000000000..e4cec1535
--- /dev/null
+++ b/simpletuner/helpers/training/validation_adapters.py
@@ -0,0 +1,238 @@
+from __future__ import annotations
+
+import os
+import re
+from dataclasses import dataclass
+from typing import Any, Iterable, List, Sequence, Tuple
+
+DEFAULT_LORA_WEIGHT_NAME = "pytorch_lora_weights.safetensors"
+VALID_ADAPTER_MODES = {"adapter_only", "comparison", "none"}
+
+
+def _slugify(value: str) -> str:
+    slug = re.sub(r"[^a-zA-Z0-9]+", "_", value.strip().lower())
+    slug = slug.strip("_")
+    return slug or "adapter"
+
+
+def _stem_from_path(path_value: str) -> str:
+    basename = os.path.basename(path_value.rstrip("/"))
+    stem, _ = os.path.splitext(basename)
+    return stem or basename or "adapter"
+
+
+@dataclass(frozen=True)
+class ValidationAdapterSpec:
+    """Represents a single adapter load instruction for validation sampling."""
+
+    is_local: bool
+    location: str
+    weight_name: str | None
+    strength: float
+    adapter_name: str | None = None
+
+    @property
+    def repo_id(self) -> str | None:
+        return None if self.is_local else self.location
+
+    @property
+    def path(self) -> str | None:
+        return self.location if self.is_local else None
+
+
+@dataclass(frozen=True)
+class ValidationAdapterRun:
+    """Represents one validation pass that may enable zero or more adapters."""
+
+    label: str | None
+    slug: str
+    adapters: Tuple[ValidationAdapterSpec, ...]
+    is_base: bool = False
+
+    @classmethod
+    def base(cls) -> "ValidationAdapterRun":
+        return cls(label=None, slug="", adapters=tuple(), is_base=True)
+
+
+def _extract_repo_and_weight(raw_value: str) -> Tuple[str, str]:
+    repo_id = raw_value
+    weight_name = DEFAULT_LORA_WEIGHT_NAME
+    if ":" in raw_value:
+        repo_id, weight_name = raw_value.split(":", 1)
+    return repo_id.strip(), weight_name.strip() or DEFAULT_LORA_WEIGHT_NAME
+
+
+def _build_adapter_spec(raw_value: str, strength: float, adapter_name: str | None = None) -> ValidationAdapterSpec:
+    if raw_value is None:
+        raise ValueError("Adapter path cannot be None.")
+    cleaned = str(raw_value).strip()
+    if cleaned == "":
+        raise ValueError("Adapter path cannot be empty.")
+    if adapter_name is not None:
+        adapter_name = adapter_name.strip() or None
+
+    expanded = os.path.abspath(os.path.expanduser(cleaned))
+    path_exists = os.path.exists(expanded)
+    drive, _ = os.path.splitdrive(expanded)
+    if path_exists or drive:
+        return ValidationAdapterSpec(
+            is_local=True,
+            location=expanded,
+            weight_name=None,
+            strength=float(strength),
+            adapter_name=adapter_name,
+        )
+
+    repo_id, weight_name = _extract_repo_and_weight(cleaned)
+    return ValidationAdapterSpec(
+        is_local=False,
+        location=repo_id,
+        weight_name=weight_name,
+        strength=float(strength),
+        adapter_name=adapter_name,
+    )
+
+
+def _norm_strength(value: Any, default: float = 1.0) -> float:
+    if value is None:
+        return float(default)
+    try:
+        scale = float(value)
+    except (TypeError, ValueError):
+        raise ValueError(f"Invalid adapter scale value: {value}") from None
+    if scale <= 0:
+        raise ValueError("Adapter scale must be greater than zero.")
+    return scale
+
+
+def _ensure_list(entry: Any, *, field_name: str) -> List[Any]:
+    if entry is None:
+        return []
+    if isinstance(entry, list):
+        return entry
+    raise ValueError(f"Expected a list for '{field_name}', got {type(entry).__name__}.")
+
+
+def _normalize_run_entry(entry: Any) -> Tuple[str | None, List[ValidationAdapterSpec]]:
+    if isinstance(entry, str):
+        spec = _build_adapter_spec(entry, 1.0)
+        return _stem_from_path(entry), [spec]
+
+    if not isinstance(entry, dict):
+        raise ValueError(f"Invalid adapter config entry: {entry!r}")
+
+    label = entry.get("label") or entry.get("name")
+    base_strength = _norm_strength(entry.get("strength", entry.get("scale")), 1.0)
+    base_adapter_name = entry.get("adapter_name")
+
+    if "path" in entry and "paths" not in entry and "adapters" not in entry:
+        spec = _build_adapter_spec(
+            entry["path"], entry.get("strength", entry.get("scale", base_strength)), base_adapter_name
+        )
+        return label or _stem_from_path(entry["path"]), [spec]
+
+    adapter_entries = entry.get("adapters")
+    if adapter_entries is None:
+        adapter_entries = entry.get("paths")
+    adapter_entries = _ensure_list(adapter_entries, field_name="paths")
+
+    specs: List[ValidationAdapterSpec] = []
+    for adapter in adapter_entries:
+        if isinstance(adapter, str):
+            specs.append(_build_adapter_spec(adapter, base_strength, base_adapter_name))
+            continue
+        if not isinstance(adapter, dict):
+            raise ValueError(f"Invalid adapter specification: {adapter!r}")
+        path_value = adapter.get("path")
+        if path_value is None:
+            raise ValueError(f"Adapter specification is missing 'path': {adapter!r}")
+        adapter_strength = _norm_strength(adapter.get("strength", adapter.get("scale")), base_strength)
+        adapter_name = adapter.get("adapter_name") or base_adapter_name
+        specs.append(_build_adapter_spec(path_value, adapter_strength, adapter_name))
+
+    if not specs:
+        raise ValueError("Adapter run must include at least one adapter path.")
+
+    if label is None and specs:
+        first = specs[0]
+        label = _stem_from_path(first.path or first.repo_id or "adapter")
+
+    return label, specs
+
+
+def _iter_config_entries(config: Any) -> Iterable[Any]:
+    if config is None:
+        return []
+    if isinstance(config, dict):
+        if "runs" in config:
+            return config["runs"]
+        return [config]
+    if isinstance(config, list):
+        return config
+    raise ValueError("validation_adapter_config must be a list or a dict containing 'runs'.")
+
+
+def build_validation_adapter_runs(
+    adapter_path: str | None,
+    adapter_config: Any,
+    *,
+    adapter_name: str | None = None,
+    adapter_strength: float = 1.0,
+    adapter_mode: str | None = None,
+) -> List[ValidationAdapterRun]:
+    """
+    Build adapter run definitions from CLI inputs.
+    """
+
+    mode = (adapter_mode or "adapter_only").strip().lower()
+    if mode not in VALID_ADAPTER_MODES:
+        raise ValueError(f"Invalid adapter mode '{adapter_mode}'. Expected one of {sorted(VALID_ADAPTER_MODES)}")
+
+    runs: List[ValidationAdapterRun] = []
+    seen_slugs: set[str] = set()
+
+    def _make_run(label: str | None, specs: Sequence[ValidationAdapterSpec]) -> ValidationAdapterRun:
+        primary_spec = specs[0]
+        if primary_spec.is_local:
+            fallback_label = _stem_from_path(primary_spec.location)
+        else:
+            fallback_label = _stem_from_path(primary_spec.repo_id or "adapter")
+        display_label = label or primary_spec.adapter_name or fallback_label
+        slug = _slugify(display_label)
+        original_slug = slug
+        counter = 2
+        while slug in seen_slugs:
+            slug = f"{original_slug}_{counter}"
+            counter += 1
+        seen_slugs.add(slug)
+        return ValidationAdapterRun(
+            label=display_label,
+            slug=slug,
+            adapters=tuple(specs),
+            is_base=False,
+        )
+
+    if adapter_path and mode != "none":
+        specs = [_build_adapter_spec(adapter_path, adapter_strength, adapter_name)]
+        preferred_label = adapter_name or _stem_from_path(adapter_path)
+        runs.append(_make_run(preferred_label, specs))
+
+    for entry in _iter_config_entries(adapter_config):
+        label, specs = _normalize_run_entry(entry)
+        if not specs:
+            continue
+        runs.append(_make_run(label, specs))
+
+    include_base = True
+    if adapter_path and mode == "adapter_only" and adapter_config in (None, [], {}):
+        include_base = False
+
+    ordered_runs: List[ValidationAdapterRun] = []
+    if include_base:
+        ordered_runs.append(ValidationAdapterRun.base())
+    ordered_runs.extend(runs)
+
+    if not ordered_runs:
+        ordered_runs.append(ValidationAdapterRun.base())
+
+    return ordered_runs
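Reviewer note: round-tripping an illustrative config through the builder; the expected values below are inferred from the code, not from a test run, and the repo ids are made up:

    from simpletuner.helpers.training.validation_adapters import build_validation_adapter_runs

    config = [
        "org/style-lora:style.safetensors",
        {"label": "stacked", "adapters": [
            {"path": "org/style-lora", "scale": 0.8},
            {"path": "org/detail-lora"},
        ]},
    ]
    runs = build_validation_adapter_runs(None, config)
    assert runs[0].is_base and len(runs) == 3        # base pass + one run per entry
    assert runs[2].adapters[0].strength == 0.8
    assert runs[2].adapters[1].strength == 1.0       # inherits the run default
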
diff --git a/simpletuner/helpers/utils/offloading.py b/simpletuner/helpers/utils/offloading.py
index 908c6c290..ddf0de204 100644
--- a/simpletuner/helpers/utils/offloading.py
+++ b/simpletuner/helpers/utils/offloading.py
@@ -5,12 +5,13 @@
 
 try:
     from diffusers.hooks import apply_group_offloading
-    from diffusers.hooks.group_offloading import _is_group_offload_enabled
+    from diffusers.hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
 
     _DIFFUSERS_GROUP_OFFLOAD_AVAILABLE = True
 except ImportError:  # pragma: no cover - handled by runtime checks
     apply_group_offloading = None  # type: ignore[assignment]
     _is_group_offload_enabled = None  # type: ignore[assignment]
+    _maybe_remove_and_reapply_group_offloading = None  # type: ignore[assignment]
     _DIFFUSERS_GROUP_OFFLOAD_AVAILABLE = False
 
 
@@ -103,3 +104,33 @@ def enable_group_offload_on_components(
             offload_device=offload_device,
             **kwargs,
         )
+
+
+def unpack_offload_state(offload_state):
+    """
+    Normalize the value returned by diffusers' _optionally_disable_offloading helper.
+    """
+
+    if isinstance(offload_state, tuple):
+        padded = list(offload_state) + [False] * (3 - len(offload_state))
+        return bool(padded[0]), bool(padded[1]), bool(padded[2])
+
+    return bool(offload_state), False, False
+
+
+def restore_offload_state(_pipeline, is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload):
+    """
+    Re-apply the appropriate offloading hooks depending on prior state.
+    """
+
+    if _pipeline is None:
+        return
+
+    if is_model_cpu_offload and hasattr(_pipeline, "enable_model_cpu_offload"):
+        _pipeline.enable_model_cpu_offload()
+    elif is_sequential_cpu_offload and hasattr(_pipeline, "enable_sequential_cpu_offload"):
+        _pipeline.enable_sequential_cpu_offload()
+    elif is_group_offload and _maybe_remove_and_reapply_group_offloading and hasattr(_pipeline, "components"):
+        for component in _pipeline.components.values():
+            if isinstance(component, torch.nn.Module):
+                _maybe_remove_and_reapply_group_offloading(component)
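Reviewer note: the call pattern these helpers support, as used by the pipeline mixins earlier in this diff (names match the mixin code):

    offload_state = cls._optionally_disable_offloading(_pipeline)  # 2- or 3-tuple, version-dependent
    is_model, is_seq, is_group = unpack_offload_state(offload_state)
    ...  # load LoRA weights while the offload hooks are detached
    restore_offload_state(_pipeline, is_model, is_seq, is_group)
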
diff --git a/simpletuner/simpletuner_sdk/server/services/field_registry/sections/validation.py b/simpletuner/simpletuner_sdk/server/services/field_registry/sections/validation.py
index d008e08b3..530b419ea 100644
--- a/simpletuner/simpletuner_sdk/server/services/field_registry/sections/validation.py
+++ b/simpletuner/simpletuner_sdk/server/services/field_registry/sections/validation.py
@@ -751,3 +751,94 @@ def register_validation_fields(registry: "FieldRegistry") -> None:
             subsection="advanced",
         )
     )
+
+    registry._add_field(
+        ConfigField(
+            name="validation_adapter_path",
+            arg_name="--validation_adapter_path",
+            ui_label="Validation Adapter Path",
+            field_type=FieldType.TEXT,
+            tab="validation",
+            section="validation_adapters",
+            default_value=None,
+            placeholder="repo/id:weights.safetensors or /path/to/adapter.safetensors",
+            help_text="Temporarily load a single LoRA adapter during validation from a local file or Hugging Face repo.",
+            tooltip="Formats: 'org/repo:weights.safetensors', 'org/repo' (defaults to pytorch_lora_weights.safetensors) or a local path.",
+            importance=ImportanceLevel.ADVANCED,
+            order=1,
+        )
+    )
+
+    registry._add_field(
+        ConfigField(
+            name="validation_adapter_name",
+            arg_name="--validation_adapter_name",
+            ui_label="Validation Adapter Name",
+            field_type=FieldType.TEXT,
+            tab="validation",
+            section="validation_adapters",
+            default_value=None,
+            placeholder="my_adapter_name",
+            help_text="Optional adapter identifier to use when loading LoRA weights for validation.",
+            tooltip="If left blank, SimpleTuner generates a unique adapter name automatically.",
+            importance=ImportanceLevel.ADVANCED,
+            order=2,
+        )
+    )
+
+    registry._add_field(
+        ConfigField(
+            name="validation_adapter_strength",
+            arg_name="--validation_adapter_strength",
+            ui_label="Validation Adapter Strength",
+            field_type=FieldType.NUMBER,
+            tab="validation",
+            section="validation_adapters",
+            default_value=1.0,
+            help_text="Strength multiplier applied when activating the validation adapter.",
+            tooltip="Values greater than 1 increase the LoRA influence; values between 0 and 1 reduce it.",
+            importance=ImportanceLevel.ADVANCED,
+            order=3,
+            validation_rules=[
+                ValidationRule(ValidationRuleType.MIN, value=0, message="Strength must be greater than 0"),
+            ],
+        )
+    )
+
+    registry._add_field(
+        ConfigField(
+            name="validation_adapter_mode",
+            arg_name="--validation_adapter_mode",
+            ui_label="Validation Adapter Comparison",
+            field_type=FieldType.SELECT,
+            tab="validation",
+            section="validation_adapters",
+            default_value="adapter_only",
+            choices=[
+                {"value": "adapter_only", "label": "Adapter Only"},
+                {"value": "comparison", "label": "Comparison"},
+                {"value": "none", "label": "Disabled"},
+            ],
+            help_text="Select whether to sample only the adapter, compare against the base model, or skip loading it.",
+            tooltip="Comparison renders both with and without the adapter so you can review differences.",
+            importance=ImportanceLevel.ADVANCED,
+            order=4,
+        )
+    )
+
+    registry._add_field(
+        ConfigField(
+            name="validation_adapter_config",
+            arg_name="--validation_adapter_config",
+            ui_label="Validation Adapter Config",
+            field_type=FieldType.TEXT,
+            tab="validation",
+            section="validation_adapters",
+            default_value=None,
+            placeholder="/path/to/validation_adapters.json",
+            help_text="JSON file or inline JSON describing multiple adapter combinations to evaluate during validation.",
+            tooltip="Each entry can define 'label' and a list of adapter paths so multiple validation runs are automated.",
+            importance=ImportanceLevel.EXPERIMENTAL,
+            order=5,
+        )
+    )
diff --git a/simpletuner/simpletuner_sdk/server/services/field_service.py b/simpletuner/simpletuner_sdk/server/services/field_service.py
index 6c250fe82..3ca62d483 100644
--- a/simpletuner/simpletuner_sdk/server/services/field_service.py
+++ b/simpletuner/simpletuner_sdk/server/services/field_service.py
@@ -611,6 +611,26 @@ class FieldService:
                 subsection_override="advanced",
                 order=31,
             ),
+            SectionLayout(
+                id="validation_adapters",
+                title="Validation Adapters",
+                icon="fas fa-layer-group",
+                match_section="validation_adapters",
+                match_subsections=(None,),
+                subsection_override="",
+                order=35,
+            ),
+            SectionLayout(
+                id="validation_adapters_advanced",
+                title="",
+                icon="",
+                advanced=True,
+                parent="validation_adapters",
+                match_section="validation_adapters",
+                match_subsections=("advanced",),
+                subsection_override="advanced",
+                order=36,
+            ),
             SectionLayout(
                 id="validation_options",
                 title="Validation Options",
diff --git a/simpletuner/static/css/trainer.css b/simpletuner/static/css/trainer.css
index 0beef8fad..d805c7675 100644
--- a/simpletuner/static/css/trainer.css
+++ b/simpletuner/static/css/trainer.css
@@ -1244,7 +1244,6 @@ optgroup option {
     inset: 0;
     background: rgba(12, 15, 25, 0.55);
     border-radius: 0.75rem;
-    display: flex;
     align-items: center;
     justify-content: center;
     padding: 1.25rem;
help_text="Strength multiplier applied when activating the validation adapter.", + tooltip="Values greater than 1 increase the LoRA influence; values between 0 and 1 reduce it.", + importance=ImportanceLevel.ADVANCED, + order=3, + validation_rules=[ + ValidationRule(ValidationRuleType.MIN, value=0, message="Strength must be greater than 0"), + ], + ) + ) + + registry._add_field( + ConfigField( + name="validation_adapter_mode", + arg_name="--validation_adapter_mode", + ui_label="Validation Adapter Comparison", + field_type=FieldType.SELECT, + tab="validation", + section="validation_adapters", + default_value="adapter_only", + choices=[ + {"value": "adapter_only", "label": "Adapter Only"}, + {"value": "comparison", "label": "Comparison"}, + {"value": "none", "label": "Disabled"}, + ], + help_text="Select whether to sample only the adapter, compare against the base model, or skip loading it.", + tooltip="Comparison renders both with and without the adapter so you can review differences.", + importance=ImportanceLevel.ADVANCED, + order=4, + ) + ) + + registry._add_field( + ConfigField( + name="validation_adapter_config", + arg_name="--validation_adapter_config", + ui_label="Validation Adapter Config", + field_type=FieldType.TEXT, + tab="validation", + section="validation_adapters", + default_value=None, + placeholder="/path/to/validation_adapters.json", + help_text="JSON file or inline JSON describing multiple adapter combinations to evaluate during validation.", + tooltip="Each entry can define 'label' and a list of adapter paths so multiple validation runs are automated.", + importance=ImportanceLevel.EXPERIMENTAL, + order=5, + ) + ) diff --git a/simpletuner/simpletuner_sdk/server/services/field_service.py b/simpletuner/simpletuner_sdk/server/services/field_service.py index 6c250fe82..3ca62d483 100644 --- a/simpletuner/simpletuner_sdk/server/services/field_service.py +++ b/simpletuner/simpletuner_sdk/server/services/field_service.py @@ -611,6 +611,26 @@ class FieldService: subsection_override="advanced", order=31, ), + SectionLayout( + id="validation_adapters", + title="Validation Adapters", + icon="fas fa-layer-group", + match_section="validation_adapters", + match_subsections=(None,), + subsection_override="", + order=35, + ), + SectionLayout( + id="validation_adapters_advanced", + title="", + icon="", + advanced=True, + parent="validation_adapters", + match_section="validation_adapters", + match_subsections=("advanced",), + subsection_override="advanced", + order=36, + ), SectionLayout( id="validation_options", title="Validation Options", diff --git a/simpletuner/static/css/trainer.css b/simpletuner/static/css/trainer.css index 0beef8fad..d805c7675 100644 --- a/simpletuner/static/css/trainer.css +++ b/simpletuner/static/css/trainer.css @@ -1244,7 +1244,6 @@ optgroup option { inset: 0; background: rgba(12, 15, 25, 0.55); border-radius: 0.75rem; - display: flex; align-items: center; justify-content: center; padding: 1.25rem; diff --git a/simpletuner/templates/form_tab.html b/simpletuner/templates/form_tab.html index 7404511e5..d4b5e4e7c 100644 --- a/simpletuner/templates/form_tab.html +++ b/simpletuner/templates/form_tab.html @@ -529,12 +529,80 @@