diff --git a/README.md b/README.md index c028ce7e8..8164298aa 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,7 @@ SimpleTuner provides comprehensive training support across multiple diffusion mo - **Gradient checkpointing** - Configurable intervals for memory/speed optimization - **Loss functions** - L2, Huber, Smooth L1 with scheduling support - **SNR weighting** - Min-SNR gamma weighting for improved training dynamics +- **Group offloading** - Diffusers v0.33+ module-group CPU/disk staging with optional CUDA streams ### Model-Specific Features @@ -99,6 +100,7 @@ SimpleTuner provides comprehensive training support across multiple diffusion mo - **T5 masked training** - Enhanced fine details for Flux and compatible models - **QKV fusion** - Memory and speed optimizations (Flux, Lumina2) - **TREAD integration** - Selective token routing for Wan and Flux models +- **Wan 2.x I2V** - High/low stage presets plus a 2.1 time-embedding fallback (see Wan quickstart) - **Classifier-free guidance** - Optional CFG reintroduction for distilled models ### Quickstart Guides diff --git a/documentation/DATALOADER.md b/documentation/DATALOADER.md index d42d27c98..d11a268aa 100644 --- a/documentation/DATALOADER.md +++ b/documentation/DATALOADER.md @@ -49,8 +49,8 @@ Here is the most basic example of a dataloader configuration file, as `multidata ### `dataset_type` -- **Values:** `image` | `video` | `text_embeds` | `image_embeds` | `conditioning` -- **Description:** `image` and `video` datasets contain your training data. `text_embeds` contain the outputs of the text encoder cache, and `image_embeds` contain the VAE outputs, if the model uses one. When a dataset is marked as `conditioning`, it is possible to pair it to your `image` dataset via [the conditioning_data option](#conditioning_data) +- **Values:** `image` | `video` | `text_embeds` | `image_embeds` | `conditioning_image_embeds` | `conditioning` +- **Description:** `image` and `video` datasets contain your training data. `text_embeds` contain the outputs of the text encoder cache, `image_embeds` contain the VAE latents (when a model uses one), and `conditioning_image_embeds` store cached conditioning image embeddings (such as CLIP vision features). When a dataset is marked as `conditioning`, it is possible to pair it to your `image` dataset via [the conditioning_data option](#conditioning_data) - **Note:** Text and image embed datasets are defined differently than image datasets are. A text embed dataset stores ONLY the text embed objects. An image dataset stores the training data. - **Note:** Don't combine images and video in a **single** dataset. Split them out. @@ -69,6 +69,22 @@ Here is the most basic example of a dataloader configuration file, as `multidata - **Only applies to `dataset_type=image`** - If unset, the VAE outputs will be stored on the image backend. Otherwise, you may set this to the `id` of an `image_embeds` dataset, and the VAE outputs will be stored there instead. Allows associating the image_embed dataset to the image data. +### `conditioning_image_embeds` + +- **Applies to `dataset_type=image` and `dataset_type=video`** +- When a model reports `requires_conditioning_image_embeds`, set this to the `id` of a `conditioning_image_embeds` dataset to store cached conditioning image embeddings (for example, CLIP vision features for Wan 2.2 I2V). If unset, SimpleTuner writes the cache to `cache/conditioning_image_embeds/` by default, guaranteeing it no longer collides with the VAE cache. 
+- Models that need these embeds must expose an image encoder through their primary pipeline. If the model cannot supply the encoder, preprocessing will fail early instead of silently generating empty files. + +#### `cache_dir_conditioning_image_embeds` + +- **Optional override for the conditioning image embed cache destination.** +- Set this when you want to pin the cache to a specific filesystem location or have a dedicated remote backend (`dataset_type=conditioning_image_embeds`). When omitted, the cache path described above is used automatically. + +#### `conditioning_image_embed_batch_size` + +- **Optional override for the batch size used while generating conditioning image embeds.** +- Defaults to the `conditioning_image_embed_batch_size` trainer argument or the VAE batch size when not explicitly provided. + ### `type` - **Values:** `aws` | `local` | `csv` | `huggingface` @@ -430,7 +446,8 @@ In order, the lines behave as follows: "probability": 1.0, "repeats": 0, "text_embeds": "alt-embed-cache", - "image_embeds": "vae-embeds-example" + "image_embeds": "vae-embeds-example", + "conditioning_image_embeds": "conditioning-embeds-example" }, { "id": "another-special-name-for-another-backend", @@ -451,6 +468,12 @@ In order, the lines behave as follows: "dataset_type": "image_embeds", "disabled": false, }, + { + "id": "conditioning-embeds-example", + "type": "local", + "dataset_type": "conditioning_image_embeds", + "disabled": false + }, { "id": "an example backend for text embeds.", "dataset_type": "text_embeds", diff --git a/documentation/OPTIONS.md b/documentation/OPTIONS.md index e341d19f0..ed387dbc7 100644 --- a/documentation/OPTIONS.md +++ b/documentation/OPTIONS.md @@ -52,6 +52,40 @@ Where `foo` is your config environment - or just use `config/config.json` if you - **What**: Offloads text encoder weights to CPU when VAE caching is going. - **Why**: This is useful for large models like HiDream and Wan 2.1, which can OOM when loading the VAE cache. This option does not impact quality of training, but for very large text encoders or slow CPUs, it can extend startup time substantially with many datasets. This is disabled by default due to this reason. +- **Tip**: Complements the group offloading feature below for especially memory-constrained systems. + +### `--enable_group_offload` + +- **What**: Enables diffusers' grouped module offloading so model blocks can be staged on CPU (or disk) between forward passes. +- **Why**: Dramatically reduces peak VRAM usage on large transformers (Flux, Wan, Auraflow, LTXVideo, Cosmos2Image) with minimal performance impact when used with CUDA streams. +- **Notes**: + - Mutually exclusive with `--enable_model_cpu_offload` — pick one strategy per run. + - Requires diffusers **v0.33.0** or newer. + +### `--group_offload_type` + +- **Choices**: `block_level` (default), `leaf_level` +- **What**: Controls how layers are grouped. `block_level` balances VRAM savings with throughput, while `leaf_level` maximises savings at the cost of more CPU transfers. + +### `--group_offload_blocks_per_group` + +- **What**: When using `block_level`, the number of transformer blocks to bundle into a single offload group. +- **Default**: `1` +- **Why**: Increasing this number reduces transfer frequency (faster) but keeps more parameters resident on the accelerator (uses more VRAM). + +### `--group_offload_use_stream` + +- **What**: Uses a dedicated CUDA stream to overlap host/device transfers with compute. 
+- **Default**: `False` +- **Notes**: + - Automatically falls back to CPU-style transfers on non-CUDA backends (Apple MPS, ROCm, CPU). + - Recommended when training on NVIDIA GPUs with spare copy engine capacity. + +### `--group_offload_to_disk_path` + +- **What**: Directory path used to spill grouped parameters to disk instead of RAM. +- **Why**: Useful for extremely tight CPU RAM budgets (e.g., workstation with large NVMe drive). +- **Tip**: Use a fast local SSD; network filesystems will significantly slow training. ### `--pretrained_model_name_or_path` diff --git a/documentation/QUICKSTART.md b/documentation/QUICKSTART.md index 1f0c6c721..46addece1 100644 --- a/documentation/QUICKSTART.md +++ b/documentation/QUICKSTART.md @@ -23,9 +23,11 @@ For the complete and most accurate feature matrix, please see the [main README.m | [Lumina2](/documentation/quickstart/LUMINA2.md) | 2B | ✓ | ✓ | ✓ | optional (int8) | bf16 | ✓ | ✓ | | | [Cosmos2](/documentation/quickstart/COSMOS2IMAGE.md) | 2B | ✓ | ✓ | ✓ | not recommended | bf16 | ✓ | ✓ | | | [LTX Video](/documentation/quickstart/LTXVIDEO.md)| ~2.5 B | ✓ | ✓ | ✓ | optional (int8,  fp8) | bf16 | ✓ | ✓ | | -| [Wan 2.1](/documentation/quickstart/WAN.md) | 1.3B-14B | ✓ | ✓ | ✓* | optional (int8) | bf16 | ✓ | ✓ | | +| [Wan 2.x](/documentation/quickstart/WAN.md) | 1.3B-14B | ✓ | ✓ | ✓* | optional (int8) | bf16 | ✓ | ✓ | | | [Qwen Image](/documentation/quickstart/QWEN_IMAGE.md) | 20B | ✓ | ✓ | ✓* | required (int8, nf4) | bf16 | ✓ (required) | ✓ | | **Note:** The above table provides a simplified overview. For the complete and most accurate feature matrix with detailed specifications, please see the [main README.md](../README.md#model-architecture-support). +> ℹ️ The Wan quickstart covers 2.1 training plus the 2.2 high/low stage presets and the new time-embedding compatibility toggle. + > ⚠️ These tutorials are a work-in-progress. They contain full end-to-end instructions for a basic training session. diff --git a/documentation/quickstart/AURAFLOW.md b/documentation/quickstart/AURAFLOW.md index 0c4fce297..3b4b4a3ca 100644 --- a/documentation/quickstart/AURAFLOW.md +++ b/documentation/quickstart/AURAFLOW.md @@ -10,6 +10,23 @@ Auraflow v0.3 was released as a 6B parameter MMDiT that uses Pile T5 for its enc This model is somewhat slow for inference, but trains at a decent speed. +### Memory offloading (optional) + +Auraflow benefits greatly from the new grouped offloading path. Add the following to your training flags if you are limited to a single 24G (or smaller) GPU: + +```bash +--enable_group_offload \ +--group_offload_type block_level \ +--group_offload_blocks_per_group 1 \ +--group_offload_use_stream \ +# optional: spill offloaded weights to disk instead of RAM +# --group_offload_to_disk_path /fast-ssd/simpletuner-offload +``` + +- Streams are automatically disabled on non-CUDA backends, so the command is safe to reuse on ROCm and MPS. +- Do not combine this with `--enable_model_cpu_offload`. +- Disk offloading trades throughput for lower host RAM pressure; keep it on a local SSD for best results. + ### Prerequisites Make sure that you have python installed; SimpleTuner does well with 3.10 through 3.12. 
diff --git a/documentation/quickstart/COSMOS2IMAGE.md b/documentation/quickstart/COSMOS2IMAGE.md index 88d69e6e1..502e0cfaa 100644 --- a/documentation/quickstart/COSMOS2IMAGE.md +++ b/documentation/quickstart/COSMOS2IMAGE.md @@ -10,6 +10,23 @@ Cosmos2 Predict (Image) is a vision transformer-based model that uses flow match A 24GB GPU is recommended as the minimum for comfortable training without extensive optimizations. +### Memory offloading (optional) + +To squeeze Cosmos2 into smaller GPUs, enable grouped offloading: + +```bash +--enable_group_offload \ +--group_offload_type block_level \ +--group_offload_blocks_per_group 1 \ +--group_offload_use_stream \ +# optional: spill offloaded weights to disk instead of RAM +# --group_offload_to_disk_path /fast-ssd/simpletuner-offload +``` + +- Streams are only honoured on CUDA; other devices fall back automatically. +- Do not combine this with `--enable_model_cpu_offload`. +- Disk staging is optional and helps when system RAM is the bottleneck. + ### Prerequisites Make sure that you have python installed; SimpleTuner does well with 3.10 through 3.12. diff --git a/documentation/quickstart/FLUX.md b/documentation/quickstart/FLUX.md index 292bec6ac..9f5c77c0b 100644 --- a/documentation/quickstart/FLUX.md +++ b/documentation/quickstart/FLUX.md @@ -26,6 +26,23 @@ Luckily, these are readily available through providers such as [LambdaLabs](http **Unlike other models, Apple GPUs do not currently work for training Flux.** +### Memory offloading (optional) + +Flux supports grouped module offloading via diffusers v0.33+. This dramatically reduces VRAM pressure when you are bottlenecked by the transformer weights. You can enable it by adding the following flags to `TRAINER_EXTRA_ARGS` (or the WebUI Hardware page): + +```bash +--enable_group_offload \ +--group_offload_type block_level \ +--group_offload_blocks_per_group 1 \ +--group_offload_use_stream \ +# optional: spill offloaded weights to disk instead of RAM +# --group_offload_to_disk_path /fast-ssd/simpletuner-offload +``` + +- `--group_offload_use_stream` is only effective on CUDA devices; SimpleTuner automatically disables streams on ROCm, MPS and CPU backends. +- Do **not** combine this with `--enable_model_cpu_offload` — the two strategies are mutually exclusive. +- When using `--group_offload_to_disk_path`, prefer a fast local SSD/NVMe target. + ## Prerequisites Make sure that you have python installed; SimpleTuner does well with 3.10 through 3.12. diff --git a/documentation/quickstart/LTXVIDEO.md b/documentation/quickstart/LTXVIDEO.md index 5f9b353d5..17d84aeb0 100644 --- a/documentation/quickstart/LTXVIDEO.md +++ b/documentation/quickstart/LTXVIDEO.md @@ -14,6 +14,23 @@ You'll need: Apple silicon systems work great with LTX so far, albeit at a lower resolution due to limits inside the MPS backend used by Pytorch. +### Memory offloading (optional) + +If you are close to the VRAM limit, enable grouped offloading in your config: + +```bash +--enable_group_offload \ +--group_offload_type block_level \ +--group_offload_blocks_per_group 1 \ +--group_offload_use_stream \ +# optional: spill offloaded weights to disk instead of RAM +# --group_offload_to_disk_path /fast-ssd/simpletuner-offload +``` + +- CUDA users benefit from `--group_offload_use_stream`; other backends ignore it automatically. +- Skip `--group_offload_to_disk_path` unless system RAM is <64 GB — disk staging is slower but keeps runs stable. +- Disable `--enable_model_cpu_offload` when using group offloading. 
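+
+If you keep your settings in `config/config.json` rather than passing CLI flags, the same options map to JSON keys. The snippet below is a minimal sketch assuming the usual one-to-one mapping between a `--flag_name` argument and a `"flag_name"` config key; check `documentation/OPTIONS.md` for the authoritative names:
+
+```json
+{
+  "enable_group_offload": true,
+  "group_offload_type": "block_level",
+  "group_offload_blocks_per_group": 1,
+  "group_offload_use_stream": true
+}
+```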
+ ### Prerequisites Make sure that you have python installed; SimpleTuner does well with 3.10 through 3.12. diff --git a/documentation/quickstart/WAN.md b/documentation/quickstart/WAN.md index cc196f772..8be7a0f5e 100644 --- a/documentation/quickstart/WAN.md +++ b/documentation/quickstart/WAN.md @@ -29,10 +29,30 @@ Currently, image-to-video training is not supported for Wan, but T2V LoRA and Ly - Resolution: 1280x720 --> +#### Image to Video (Wan 2.2) + +Recent Wan 2.2 I2V checkpoints work with the same training flow: + +- High stage: https://huggingface.co/Wan-AI/Wan2.2-I2V-14B-Diffusers/tree/main/high_noise_model +- Low stage: https://huggingface.co/Wan-AI/Wan2.2-I2V-14B-Diffusers/tree/main/low_noise_model + +You can target the stage you want with the `model_flavour` and `wan_validation_load_other_stage` settings outlined later in this guide. + You'll need: - **a realistic minimum** is 16GB or, a single 3090 or V100 GPU - **ideally** multiple 4090, A6000, L40S, or better +If you encounter shape mismatches in the time embedding layers when running Wan 2.2 checkpoints, enable the new +`wan_force_2_1_time_embedding` flag. This forces the transformer to fall back to Wan 2.1 style time embeddings and +resolves the compatibility issue. + +#### Stage presets & validation + +- `model_flavour=i2v-14b-2.2-high` targets the Wan 2.2 high-noise stage. +- `model_flavour=i2v-14b-2.2-low` targets the low-noise stage (same checkpoints, different subfolder). +- Toggle `wan_validation_load_other_stage=true` to load the opposite stage alongside the one you train for validation renders. +- Leave the flavour unset (or use `t2v-480p-1.3b-2.1`) for the standard Wan 2.1 text-to-video run. + Apple silicon systems do not work super well with Wan 2.1 so far, something like 10 minutes for a single training step can be expected.. ### Prerequisites @@ -112,6 +132,23 @@ simpletuner configure > ⚠️ For users located in countries where Hugging Face Hub is not readily accessible, you should add `HF_ENDPOINT=https://hf-mirror.com` to your `~/.bashrc` or `~/.zshrc` depending on which `$SHELL` your system uses. +### Memory offloading (optional) + +Wan is one of the heaviest models SimpleTuner supports. Enable grouped offloading if you are close to the VRAM ceiling: + +```bash +--enable_group_offload \ +--group_offload_type block_level \ +--group_offload_blocks_per_group 1 \ +--group_offload_use_stream \ +# optional: spill offloaded weights to disk instead of RAM +# --group_offload_to_disk_path /fast-ssd/simpletuner-offload +``` + +- Only CUDA devices honour `--group_offload_use_stream`; ROCm/MPS fall back automatically. +- Leave disk staging commented out unless CPU memory is the bottleneck. +- `--enable_model_cpu_offload` is mutually exclusive with group offload. + If you prefer to manually configure: @@ -432,6 +469,30 @@ Create a `--data_backend_config` (`config/multidatabackend.json`) document conta ] ``` +- Wan 2.2 image-to-video runs create CLIP conditioning caches. 
In the **video** dataset entry, point at a dedicated backend and (optionally) override the cache path: + +```json + { + "id": "disney-black-and-white", + "type": "local", + "dataset_type": "video", + "conditioning_image_embeds": "disney-conditioning", + "cache_dir_conditioning_image_embeds": "cache/conditioning_image_embeds/disney-black-and-white" + } +``` + +- Define the conditioning backend once and reuse it across datasets if needed (full object shown here for clarity): + +```json + { + "id": "disney-conditioning", + "type": "local", + "dataset_type": "conditioning_image_embeds", + "cache_dir": "cache/conditioning_image_embeds/disney-conditioning", + "disabled": false + } +``` + - In the `video` subsection, we have the following keys we can set: - `num_frames` (optional, int) is how many seconds of data we'll train on. - At 15 fps, 75 frames is 5 seconds of video, standard output. This should be your target. @@ -488,6 +549,8 @@ simpletuner train simpletuner train ``` +> ℹ️ Append `--model_flavour i2v-14b-2.2-high` (or `low`) and, if desired, `--wan_validation_load_other_stage` inside `TRAINER_EXTRA_ARGS` or your CLI invocation when you train Wan 2.2. Add `--wan_force_2_1_time_embedding` only when the checkpoint reports a time-embedding shape mismatch. + **Option 3 (Legacy method - still works):** ```bash ./train.sh diff --git a/setup.py b/setup.py index efaffaa76..8ef64e6fd 100644 --- a/setup.py +++ b/setup.py @@ -69,9 +69,7 @@ def build_rocm_wheel_url(package: str, version: str, rocm_version: str) -> str: py_tag = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_tag = _rocm_platform_tag() filename = f"{package}-{version}%2Brocm{rocm_version}-{py_tag}-{py_tag}-{platform_tag}.whl" - base_url = os.environ.get( - "SIMPLETUNER_ROCM_BASE_URL", f"https://download.pytorch.org/whl/rocm{rocm_version}" - ) + base_url = os.environ.get("SIMPLETUNER_ROCM_BASE_URL", f"https://download.pytorch.org/whl/rocm{rocm_version}") return f"{package} @ {base_url}/{filename}" @@ -86,6 +84,7 @@ def get_cuda_dependencies(): "torchao>=0.12.0", "nvidia-cudnn-cu12", "nvidia-nccl-cu12", + "nvidia-ml-py>=12.555", "lm-eval>=0.4.4", ] @@ -183,7 +182,7 @@ def _collect_package_files(*directories: str): "wandb>=0.21.0", "requests>=2.32.4", "pillow>=11.3.0", - "trainingsample>=0.2.1", + "trainingsample>=0.2.10", "accelerate>=1.5.2", "safetensors>=0.5.3", "compel>=2.1.1", @@ -218,7 +217,6 @@ def _collect_package_files(*directories: str): "imageio[pyav]>=2.37.0", "hf-xet>=1.1.5", "peft-singlora>=0.2.0", - "trainingsample>=0.2.1", "cryptography>=41.0.0", ] diff --git a/simpletuner/cli.py b/simpletuner/cli.py index dc644c985..2409859c8 100644 --- a/simpletuner/cli.py +++ b/simpletuner/cli.py @@ -14,6 +14,8 @@ from pathlib import Path from typing import List, Optional +from simpletuner.simpletuner_sdk.server.utils.paths import get_config_directory, get_template_directory + def find_config_file() -> Optional[str]: """Find config file in current directory or config/ subdirectory.""" @@ -609,6 +611,13 @@ def cmd_server(args) -> int: os.environ["SIMPLETUNER_SSL_KEYFILE"] = ssl_config["keyfile"] os.environ["SIMPLETUNER_SSL_CERTFILE"] = ssl_config["certfile"] + # Ensure template resolution points to packaged templates unless overridden + os.environ.setdefault("TEMPLATE_DIR", str(get_template_directory())) + + # Ensure a configuration directory exists and record it for downstream services + config_dir = get_config_directory() + os.environ.setdefault("SIMPLETUNER_CONFIG_DIR", str(config_dir)) + try: import uvicorn @@ 
-622,12 +631,6 @@ def cmd_server(args) -> int: # Create app with specified mode app = create_app(mode=server_mode, ssl_no_verify=ssl_no_verify) - # Create necessary directories - os.makedirs("static/css", exist_ok=True) - os.makedirs("static/js", exist_ok=True) - os.makedirs("templates", exist_ok=True) - os.makedirs("configs", exist_ok=True) - # Configure uvicorn SSL uvicorn_config = {"app": app, "host": host, "port": port, "reload": reload, "log_level": "info"} diff --git a/simpletuner/configure.py b/simpletuner/configure.py index 086ce96d3..b898b4ab5 100644 --- a/simpletuner/configure.py +++ b/simpletuner/configure.py @@ -36,6 +36,7 @@ "lumina2", "cosmos2image", "qwen_image", + "chroma", ], "lora": [ "flux", @@ -51,6 +52,7 @@ "hidream", "lumina2", "qwen_image", + "chroma", ], "controlnet": [ "sdxl", @@ -62,6 +64,7 @@ "pixart_sigma", "sd3", "kolors", + "chroma", ], } diff --git a/simpletuner/examples/qwen_image.peft-lora/config.json b/simpletuner/examples/qwen_image.peft-lora/config.json index 13c27ae72..6dd07f274 100644 --- a/simpletuner/examples/qwen_image.peft-lora/config.json +++ b/simpletuner/examples/qwen_image.peft-lora/config.json @@ -15,8 +15,8 @@ "lora_alpha": 8, "lora_rank": 8, "lora_type": "standard", - "lr_scheduler": "constant", - "lr_warmup_steps": 100, + "lr_scheduler": "constant_with_warmup", + "lr_warmup_steps": 10, "max_grad_norm": 0.01, "max_train_steps": 100, "minimum_image_size": 0, diff --git a/simpletuner/examples/wan-2.2-i2v-a14b-high.peft-lora+TREAD/config.json b/simpletuner/examples/wan-2.2-i2v-a14b-high.peft-lora+TREAD/config.json new file mode 100644 index 000000000..da1ecc48b --- /dev/null +++ b/simpletuner/examples/wan-2.2-i2v-a14b-high.peft-lora+TREAD/config.json @@ -0,0 +1,68 @@ +{ + "aspect_bucket_rounding": 2, + "attention_mechanism": "diffusers", + "base_model_precision": "int8-torchao", + "caption_dropout_probability": 0.1, + "checkpointing_steps": 100, + "checkpoints_total_limit": 5, + "compress_disk_cache": true, + "data_backend_config": "config/examples/multidatabackend-small-video-480p.json", + "delete_problematic_images": false, + "disable_benchmark": true, + "disable_bucket_pruning": true, + "ema_update_interval": 2, + "grad_clip_method": "value", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": true, + "hub_model_id": "simpletuner-controlnet-wan-i2v-a14b-high", + "ignore_final_epochs": true, + "learning_rate": 6e-5, + "lora_rank": 16, + "lora_type": "standard", + "lr_scheduler": "constant_with_warmup", + "lr_warmup_steps": 50, + "max_grad_norm": 0.1, + "max_train_steps": 100, + "minimum_image_size": 0, + "mixed_precision": "bf16", + "model_family": "wan", + "model_flavour": "i2v-14b-2.2-high", + "model_type": "lora", + "num_train_epochs": 0, + "offload_during_startup": true, + "optimizer": "optimi-lion", + "output_dir": "output/examples/wan-2.2-i2v-a14b-high.peft-lora+TREAD", + "push_checkpoints_to_hub": false, + "push_to_hub": false, + "quantize_via": "cpu", + "report_to": "none", + "resolution": 480, + "resolution_type": "pixel_area", + "resume_from_checkpoint": "latest", + "seed": 42, + "tracker_project_name": "lora-training", + "tracker_run_name": "example-training-run-wan2.2-i2v-high", + "train_batch_size": 2, + "tread_config": { + "routes": [ + { "selection_ratio": 0.1, "start_layer_idx": 2, "end_layer_idx": 8 }, + { "selection_ratio": 0.25, "start_layer_idx": 9, "end_layer_idx": 11 }, + { "selection_ratio": 0.4, "start_layer_idx": 12, "end_layer_idx": 15 }, + { "selection_ratio": 0.25, "start_layer_idx": 16, 
"end_layer_idx": 23 }, + { "selection_ratio": 0.1, "start_layer_idx": 24, "end_layer_idx": -2 } + ] + }, + "use_ema": false, + "vae_batch_size": 1, + "vae_enable_tiling": true, + "vae_enable_slicing": true, + "validation_guidance": 3.5, + "validation_negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "validation_num_inference_steps": 40, + "validation_num_video_frames": 81, + "validation_prompt": "A bustling street scene transitions into a cinematic sequence where a futuristic motorcycle accelerates through neon-soaked alleys, dust and light scattering in slow motion as the camera swings from wide establishing shots to intimate close-ups.", + "validation_prompt_library": false, + "validation_resolution": "832x480", + "validation_seed": 42, + "validation_steps": 50 +} diff --git a/simpletuner/examples/wan-2.2-i2v-a14b-low.peft-lora+TREAD/config.json b/simpletuner/examples/wan-2.2-i2v-a14b-low.peft-lora+TREAD/config.json new file mode 100644 index 000000000..855e55aac --- /dev/null +++ b/simpletuner/examples/wan-2.2-i2v-a14b-low.peft-lora+TREAD/config.json @@ -0,0 +1,68 @@ +{ + "aspect_bucket_rounding": 2, + "attention_mechanism": "diffusers", + "base_model_precision": "int8-torchao", + "caption_dropout_probability": 0.1, + "checkpointing_steps": 100, + "checkpoints_total_limit": 5, + "compress_disk_cache": true, + "data_backend_config": "config/examples/multidatabackend-small-video-480p.json", + "delete_problematic_images": false, + "disable_benchmark": true, + "disable_bucket_pruning": true, + "ema_update_interval": 2, + "grad_clip_method": "value", + "gradient_accumulation_steps": 1, + "gradient_checkpointing": true, + "hub_model_id": "simpletuner-controlnet-wan-i2v-a14b-low", + "ignore_final_epochs": true, + "learning_rate": 6e-5, + "lora_rank": 16, + "lora_type": "standard", + "lr_scheduler": "constant_with_warmup", + "lr_warmup_steps": 50, + "max_grad_norm": 0.1, + "max_train_steps": 100, + "minimum_image_size": 0, + "mixed_precision": "bf16", + "model_family": "wan", + "model_flavour": "i2v-14b-2.2-low", + "model_type": "lora", + "num_train_epochs": 0, + "offload_during_startup": true, + "optimizer": "optimi-lion", + "output_dir": "output/examples/wan-2.2-i2v-a14b-low.peft-lora+TREAD", + "push_checkpoints_to_hub": false, + "push_to_hub": false, + "quantize_via": "cpu", + "report_to": "none", + "resolution": 480, + "resolution_type": "pixel_area", + "resume_from_checkpoint": "latest", + "seed": 42, + "tracker_project_name": "lora-training", + "tracker_run_name": "example-training-run-wan2.2-i2v-low", + "train_batch_size": 2, + "tread_config": { + "routes": [ + { "selection_ratio": 0.1, "start_layer_idx": 2, "end_layer_idx": 8 }, + { "selection_ratio": 0.25, "start_layer_idx": 9, "end_layer_idx": 11 }, + { "selection_ratio": 0.4, "start_layer_idx": 12, "end_layer_idx": 15 }, + { "selection_ratio": 0.25, "start_layer_idx": 16, "end_layer_idx": 23 }, + { "selection_ratio": 0.1, "start_layer_idx": 24, "end_layer_idx": -2 } + ] + }, + "use_ema": false, + "vae_batch_size": 1, + "vae_enable_tiling": true, + "vae_enable_slicing": true, + "validation_guidance": 3.5, + "validation_negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "validation_num_inference_steps": 40, + "validation_num_video_frames": 81, + "validation_prompt": "A serene dawn breaks over a coastal village before 
cutting to a tethered drone camera orbiting a surfer as they launch from the first wave of the morning, water spray glistening in the low-angle light.", + "validation_prompt_library": false, + "validation_resolution": "832x480", + "validation_seed": 42, + "validation_steps": 50 +} diff --git a/simpletuner/helpers/caching/image_embed.py b/simpletuner/helpers/caching/image_embed.py new file mode 100644 index 000000000..f3c768cc3 --- /dev/null +++ b/simpletuner/helpers/caching/image_embed.py @@ -0,0 +1,257 @@ +import logging +import os +from hashlib import sha256 +from typing import List, Tuple + +import numpy as np +import torch +from PIL import Image + +from simpletuner.helpers.training import image_file_extensions +from simpletuner.helpers.training.multi_process import rank_info, should_log +from simpletuner.helpers.training.state_tracker import StateTracker + +try: + from simpletuner.helpers.webhooks.mixin import WebhookMixin +except Exception: # pragma: no cover - optional dependency guard + + class WebhookMixin: # type: ignore + """Fallback mixin used when webhook dependencies are unavailable.""" + + def set_webhook_handler(self, webhook_handler): + self.webhook_handler = webhook_handler + + +logger = logging.getLogger("ConditioningImageEmbedCache") +if should_log(): + logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) +else: + logger.setLevel("ERROR") + + +class ImageEmbedCache(WebhookMixin): + def __init__( + self, + id: str, + dataset_type: str, + model, + accelerator, + metadata_backend, + image_data_backend, + cache_data_backend=None, + instance_data_dir: str = "", + cache_dir: str = "", + write_batch_size: int = 16, + read_batch_size: int = 16, + embed_batch_size: int = 4, + hash_filenames: bool = True, + ): + self.id = id + self.dataset_type = dataset_type + self.model = model + self.accelerator = accelerator + self.metadata_backend = metadata_backend + self.image_data_backend = image_data_backend + self.cache_data_backend = cache_data_backend if cache_data_backend is not None else image_data_backend + self.instance_data_dir = instance_data_dir or "" + self.cache_dir = cache_dir or "" + if self.cache_data_backend and self.cache_data_backend.type in ["local", "huggingface"] and self.cache_dir: + self.cache_dir = os.path.abspath(self.cache_dir) + self.cache_data_backend.create_directory(self.cache_dir) + self.write_batch_size = write_batch_size + self.read_batch_size = read_batch_size + self.embed_batch_size = self._ensure_positive_batch_size(embed_batch_size, default=1) + self.hash_filenames = hash_filenames + self.rank_info = rank_info() + + self.webhook_handler = None + + self.image_path_to_embed_path: dict[str, str] = {} + self.embed_path_to_image_path: dict[str, str] = {} + + self.embedder = None + + def debug_log(self, msg: str): + logger.debug(f"{self.rank_info}{msg}") + + def set_webhook_handler(self, webhook_handler): + self.webhook_handler = webhook_handler + + @staticmethod + def _ensure_positive_batch_size(value, default: int) -> int: + try: + batch_size = int(value) + except (TypeError, ValueError): + return default + return batch_size if batch_size > 0 else default + + def _ensure_embedder(self): + if self.embedder is not None: + return + + provider_factory = getattr(self.model, "get_conditioning_image_embedder", None) + if not callable(provider_factory): + raise ValueError( + "Model does not expose a conditioning image embed provider. " + "Ensure the active model implements 'get_conditioning_image_embedder'." 
+ ) + + embedder = provider_factory() + if embedder is None: + raise ValueError("Model reported support for conditioning image embeddings but did not return a provider.") + if not hasattr(embedder, "encode") or not callable(embedder.encode): + raise ValueError("Conditioning image embed provider must implement an 'encode(images)' method.") + self.embedder = embedder + + def generate_embed_filename(self, filepath: str) -> Tuple[str, str]: + if filepath.endswith(".pt"): + return filepath, os.path.basename(filepath) + base_filename = os.path.splitext(os.path.basename(filepath))[0] + if self.hash_filenames: + base_filename = sha256(str(base_filename).encode()).hexdigest() + base_filename = f"{base_filename}.pt" + + subfolders = "" + if self.instance_data_dir: + subfolders = os.path.dirname(filepath).replace(self.instance_data_dir, "", 1) + subfolders = subfolders.lstrip(os.sep) + + if subfolders: + full_filename = os.path.join(self.cache_dir, subfolders, base_filename) + else: + full_filename = os.path.join(self.cache_dir, base_filename) + return full_filename, base_filename + + def build_embed_filename_map(self, all_image_files: List[str]) -> None: + self.image_path_to_embed_path.clear() + self.embed_path_to_image_path.clear() + + for image_file in all_image_files: + cache_filename, _ = self.generate_embed_filename(image_file) + if self.cache_data_backend.type == "local": + cache_filename = os.path.abspath(cache_filename) + self.image_path_to_embed_path[image_file] = cache_filename + self.embed_path_to_image_path[cache_filename] = image_file + + def discover_all_files(self) -> List[str]: + all_image_files = StateTracker.get_image_files(data_backend_id=self.id) or StateTracker.set_image_files( + self.image_data_backend.list_files( + instance_data_dir=self.instance_data_dir, + file_extensions=image_file_extensions, + ), + data_backend_id=self.id, + ) + StateTracker.get_conditioning_image_embed_files(self.id) or StateTracker.set_conditioning_image_embed_files( + self.cache_data_backend.list_files( + instance_data_dir=self.cache_dir, + file_extensions=["pt"], + ), + data_backend_id=self.id, + ) + self.debug_log(f"ConditioningImageEmbedCache discover_all_files found {len(all_image_files)} sources") + return all_image_files + + def discover_unprocessed_files(self) -> List[str]: + if not self.image_path_to_embed_path: + return [] + + pending = [] + for image_path, embed_path in self.image_path_to_embed_path.items(): + test_path = embed_path + if self.cache_data_backend.type == "local": + test_path = os.path.abspath(embed_path) + if not self.cache_data_backend.exists(test_path): + pending.append(image_path) + + return pending + + def _load_image_for_embedding(self, filepath: str) -> Image.Image: + sample = self.image_data_backend.read_image(filepath) + if isinstance(sample, Image.Image): + return sample.convert("RGB") + if isinstance(sample, np.ndarray): + if sample.ndim == 4: + first_frame = sample[0] + elif sample.ndim == 3: + first_frame = sample + else: + raise ValueError(f"Unsupported numpy shape for conditioning embed: {sample.shape}") + if first_frame.dtype != np.uint8: + first_frame = np.clip(first_frame, 0, 255).astype(np.uint8) + return Image.fromarray(first_frame).convert("RGB") + raise ValueError(f"Unsupported sample type for conditioning embed: {type(sample)}") + + def _encode_batch(self, filepaths: List[str]) -> Tuple[List[str], List[torch.Tensor]]: + self._ensure_embedder() + valid_paths: List[str] = [] + images: List[Image.Image] = [] + for fp in filepaths: + try: + 
images.append(self._load_image_for_embedding(fp)) + valid_paths.append(fp) + except FileNotFoundError: + self.debug_log(f"Skipping missing file during conditioning embed generation: {fp}") + except ValueError as exc: + self.debug_log(f"Skipping unsupported sample {fp}: {exc}") + if not images: + return [], [] + with torch.no_grad(): + embeddings = self.embedder.encode(images) + if embeddings is None: + return [], [] + if isinstance(embeddings, (list, tuple)): + if all(torch.is_tensor(item) for item in embeddings): + embeddings = torch.stack(embeddings, dim=0) + else: + raise ValueError("Conditioning image embed provider returned a sequence containing non-tensors.") + elif isinstance(embeddings, np.ndarray): + embeddings = torch.from_numpy(embeddings) + if not torch.is_tensor(embeddings): + raise ValueError("Conditioning image embed provider returned non-tensor embeddings.") + embeds = embeddings.detach().cpu() + return valid_paths, [embeds[i] for i in range(embeds.shape[0])] + + def _write_embed(self, filepath: str, embedding: torch.Tensor) -> None: + cache_path = self.image_path_to_embed_path.get(filepath) + if cache_path is None: + cache_path, _ = self.generate_embed_filename(filepath) + self.image_path_to_embed_path[filepath] = cache_path + self.embed_path_to_image_path[cache_path] = filepath + + directory = os.path.dirname(cache_path) + if directory and self.cache_data_backend.type == "local": + os.makedirs(directory, exist_ok=True) + + self.cache_data_backend.torch_save(embedding, cache_path) + + current_cache = StateTracker.get_conditioning_image_embed_files(self.id) + if isinstance(current_cache, dict): + current_cache[cache_path] = True + + def process_files(self, filepaths: List[str]) -> None: + if not filepaths: + return + for idx in range(0, len(filepaths), self.embed_batch_size): + batch_paths = filepaths[idx : idx + self.embed_batch_size] + valid_paths, embeddings = self._encode_batch(batch_paths) + for fp, embed in zip(valid_paths, embeddings): + self._write_embed(fp, embed) + if self.cache_dir: + StateTracker.set_conditioning_image_embed_files( + self.cache_data_backend.list_files( + instance_data_dir=self.cache_dir, + file_extensions=["pt"], + ), + data_backend_id=self.id, + ) + + def retrieve_from_cache(self, filepath: str) -> torch.Tensor: + if filepath not in self.image_path_to_embed_path: + cache_path, _ = self.generate_embed_filename(filepath) + self.image_path_to_embed_path[filepath] = cache_path + self.embed_path_to_image_path[cache_path] = filepath + + cache_path = self.image_path_to_embed_path[filepath] + if not self.cache_data_backend.exists(cache_path): + self.process_files([filepath]) + return self.cache_data_backend.torch_load(cache_path) diff --git a/simpletuner/helpers/caching/text_embeds.py b/simpletuner/helpers/caching/text_embeds.py index c543def45..25b94c310 100644 --- a/simpletuner/helpers/caching/text_embeds.py +++ b/simpletuner/helpers/caching/text_embeds.py @@ -247,11 +247,12 @@ def compute_embeddings_for_prompts( ): if self.model.TEXT_ENCODER_CONFIGURATION == {}: # This is a model that doesn't use text encoders. + self.debug_log(f"Model type {self.model_type} does not use text encoders, skipping text embed caching.") self.disabled = True return None - logger.debug("Initialising text embed calculator...") + self.debug_log("Initialising text embed calculator...") if not self.batch_write_thread.is_alive(): - logger.debug("Restarting background write thread.") + self.debug_log("Restarting background write thread.") # Start the thread again. 
self.process_write_batches = True self.batch_write_thread = Thread(target=self.batch_write_embeddings) diff --git a/simpletuner/helpers/caching/vae.py b/simpletuner/helpers/caching/vae.py index 0d4199d51..215ed86fc 100644 --- a/simpletuner/helpers/caching/vae.py +++ b/simpletuner/helpers/caching/vae.py @@ -536,7 +536,7 @@ def encode_images(self, images, filepaths, load_from_cache=True): ) processed_images = self.prepare_video_latents(processed_images) processed_images = self.model.pre_vae_encode_transform_sample(processed_images) - latents_uncached = self.vae.encode(processed_images) + latents_uncached = self.model.encode_with_vae(self.vae, processed_images) latents_uncached = self.model.post_vae_encode_transform_sample(latents_uncached) if StateTracker.get_model_family() in ["wan"]: diff --git a/simpletuner/helpers/configuration/env_file.py b/simpletuner/helpers/configuration/env_file.py index dcae3769b..185f232a0 100644 --- a/simpletuner/helpers/configuration/env_file.py +++ b/simpletuner/helpers/configuration/env_file.py @@ -175,8 +175,6 @@ def load_env(): def load_env_config(): mapped_args = [] ignored_accelerate_kwargs = [ - "--num_processes", - "--num_machines", "--dynamo_backend", ] for env_var, arg_name in env_to_args_map.items(): diff --git a/simpletuner/helpers/data_backend/config/__init__.py b/simpletuner/helpers/data_backend/config/__init__.py index f360b7e1f..d83a60eb1 100644 --- a/simpletuner/helpers/data_backend/config/__init__.py +++ b/simpletuner/helpers/data_backend/config/__init__.py @@ -6,11 +6,35 @@ from .image_embed import ImageEmbedBackendConfig from .text_embed import TextEmbedBackendConfig +try: # pragma: no cover - graceful fallback when optional module missing + from .conditioning_image_embed import ConditioningImageEmbedBackendConfig +except ModuleNotFoundError: # pragma: no cover - legacy environments + + class ConditioningImageEmbedBackendConfig(ImageEmbedBackendConfig): # type: ignore[misc] + """Fallback configuration that mirrors ImageEmbed when the specialised class is unavailable.""" + + def __post_init__(self): + super().__post_init__() + self.dataset_type = "conditioning_image_embeds" + + @classmethod + def from_dict(cls, backend_dict: dict, args: dict) -> "ConditioningImageEmbedBackendConfig": + config = super().from_dict(backend_dict, args) + config.dataset_type = "conditioning_image_embeds" + return config + + def to_dict(self) -> dict: + payload = super().to_dict() + payload["dataset_type"] = "conditioning_image_embeds" + return payload + + __all__ = [ "BaseBackendConfig", "ImageBackendConfig", "TextEmbedBackendConfig", "ImageEmbedBackendConfig", + "ConditioningImageEmbedBackendConfig", "validators", "create_backend_config", ] @@ -23,6 +47,8 @@ def create_backend_config(backend_dict: dict, args: dict) -> BaseBackendConfig: return TextEmbedBackendConfig.from_dict(backend_dict, args) elif dataset_type == "image_embeds": return ImageEmbedBackendConfig.from_dict(backend_dict, args) + elif dataset_type == "conditioning_image_embeds": + return ConditioningImageEmbedBackendConfig.from_dict(backend_dict, args) elif dataset_type in ["image", "conditioning", "eval", "video"]: return ImageBackendConfig.from_dict(backend_dict, args) else: diff --git a/simpletuner/helpers/data_backend/factory.py b/simpletuner/helpers/data_backend/factory.py index badeed2bd..aaa03cc97 100644 --- a/simpletuner/helpers/data_backend/factory.py +++ b/simpletuner/helpers/data_backend/factory.py @@ -70,6 +70,7 @@ def _coerce_bucket_keys(indices: Dict[Any, Iterable]) -> Dict[Any, list]: 
import torch from tqdm import tqdm +from simpletuner.helpers.caching.image_embed import ImageEmbedCache from simpletuner.helpers.caching.text_embeds import TextEmbeddingCache from simpletuner.helpers.caching.vae import VAECache from simpletuner.helpers.data_backend.aws import S3DataBackend @@ -222,6 +223,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict: elif backend.get("dataset_type", None) == "image_embeds": # no overrides for image embed backends return output + elif backend.get("dataset_type", None) == "conditioning_image_embeds": + return output else: if "caption_filter_list" in backend: raise ValueError( @@ -908,7 +911,7 @@ def get_aws_backend( metrics_logger = logging.getLogger("DataBackendMetrics") if should_log(): - metrics_logger.setLevel(os.environ.get("SIMPLETUNER_METRICS_LOG_LEVEL", "INFO")) + metrics_logger.setLevel(os.environ.get("SIMPLETUNER_METRICS_LOG_LEVEL", "WARNING")) else: metrics_logger.setLevel(logging.ERROR) @@ -950,6 +953,7 @@ def __init__(self, args: Any, accelerator: Any, text_encoders: Any, tokenizers: self.text_embed_backends = {} self.image_embed_backends = {} + self.conditioning_image_embed_backends = {} self.data_backends = {} self.default_text_embed_backend_id = None @@ -958,7 +962,12 @@ def __init__(self, args: Any, accelerator: Any, text_encoders: Any, tokenizers: "factory_type": "new", "initialization_time": 0, "memory_usage": {"start": 0, "peak": 0, "end": 0}, - "backend_counts": {"text_embeds": 0, "image_embeds": 0, "data_backends": 0}, + "backend_counts": { + "text_embeds": 0, + "image_embeds": 0, + "conditioning_image_embeds": 0, + "data_backends": 0, + }, "configuration_time": 0, "build_time": 0, } @@ -1365,6 +1374,55 @@ def configure_image_embed_backends(self, data_backend_config: List[Dict[str, Any init_backend["vaecache"].discover_all_files() self.image_embed_backends[init_backend["id"]] = init_backend + self.metrics["backend_counts"]["image_embeds"] = len(self.image_embed_backends) + + def configure_conditioning_image_embed_backends(self, data_backend_config: List[Dict[str, Any]]) -> None: + """Configure conditioning image embedding backends.""" + for backend in data_backend_config: + dataset_type = backend.get("dataset_type", None) + if dataset_type is None or dataset_type != "conditioning_image_embeds": + continue + + if backend.get("disabled", False) or backend.get("disable", False): + info_log(f"Skipping disabled conditioning image embed backend {backend['id']} in config file.") + continue + + info_log(f'Configuring conditioning image embed backend: {backend["id"]}') + + config = create_backend_config(backend, vars(self.args)) + config.validate(vars(self.args)) + + init_backend = init_backend_config(backend, self.args, self.accelerator) + existing_config = StateTracker.get_data_backend_config(init_backend["id"]) + if existing_config is not None and existing_config != {}: + raise ValueError(f"You can only have one backend named {init_backend['id']}") + StateTracker.set_data_backend_config(init_backend["id"], init_backend["config"]) + + if backend["type"] == "local": + config = create_backend_config(backend, vars(self.args)) + builder = create_backend_builder(backend["type"], self.accelerator, self.args) + init_backend["data_backend"] = builder.build(config) + elif backend["type"] == "aws": + config = create_backend_config(backend, vars(self.args)) + builder = create_backend_builder(backend["type"], self.accelerator, self.args) + init_backend["data_backend"] = builder.build(config) + elif backend["type"] == "csv": 
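+                # Only local and aws backends are supported for the .pt embed cache, so reject csv with a clear error.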
+ raise ValueError("Cannot use CSV backend for conditioning image embed storage.") + else: + raise ValueError(f"Unknown data backend type: {backend['type']}") + + init_backend["cache_dir"] = backend.get("cache_dir", self.args.cache_dir_vae) + preserve_data_backend_cache = backend.get("preserve_data_backend_cache", False) + if not preserve_data_backend_cache and self.accelerator.is_local_main_process: + StateTracker.delete_cache_files( + data_backend_id=init_backend["id"], + preserve_data_backend_cache=preserve_data_backend_cache, + ) + + self.conditioning_image_embed_backends[init_backend["id"]] = init_backend + + self.metrics["backend_counts"]["conditioning_image_embeds"] = len(self.conditioning_image_embed_backends) + def _prevalidate_backend_ids(self, data_backend_config: List[Dict[str, Any]]) -> None: """Validate that data backends provide unique, non-empty identifiers before configuration.""" try: @@ -1537,6 +1595,18 @@ def _get_image_embed_backend(self, backend: Dict[str, Any], init_backend: Dict[s image_embed_data_backend = self.image_embed_backends[image_embed_backend_id] return image_embed_data_backend + def _get_conditioning_image_embed_backend(self, backend: Dict[str, Any], init_backend: Dict[str, Any]) -> Dict[str, Any]: + """Get the conditioning image embed backend or use the main backend.""" + conditioning_embed_backend_id = backend.get("conditioning_image_embeds", None) + conditioning_embed_backend = init_backend + if conditioning_embed_backend_id is not None: + if conditioning_embed_backend_id not in self.conditioning_image_embed_backends: + raise ValueError( + f"Could not find conditioning image embed backend ID in multidatabackend config: {conditioning_embed_backend_id}" + ) + conditioning_embed_backend = self.conditioning_image_embed_backends[conditioning_embed_backend_id] + return conditioning_embed_backend + def _configure_metadata_backend(self, backend: Dict[str, Any], init_backend: Dict[str, Any]) -> None: """Configure the metadata backend.""" info_log(f"(id={init_backend['id']}) Loading bucket manager.") @@ -1613,9 +1683,7 @@ def _configure_metadata_backend(self, backend: Dict[str, Any], init_backend: Dic metadata_backend = init_backend["metadata_backend"] if isinstance(getattr(metadata_backend, "aspect_ratio_bucket_indices", None), dict): - metadata_backend.aspect_ratio_bucket_indices = _coerce_bucket_keys( - metadata_backend.aspect_ratio_bucket_indices - ) + metadata_backend.aspect_ratio_bucket_indices = _coerce_bucket_keys(metadata_backend.aspect_ratio_bucket_indices) if hasattr(metadata_backend, "_mock_children"): children = getattr(metadata_backend, "_mock_children", None) if isinstance(children, dict): @@ -1961,6 +2029,86 @@ def _handle_auto_generated_dataset(self, backend: Dict[str, Any], init_backend: ) generator.generate_dataset() + def _configure_conditioning_image_embed_cache( + self, + backend: Dict[str, Any], + init_backend: Dict[str, Any], + conditioning_image_embed_backend: Dict[str, Any], + ) -> None: + """Configure conditioning image embed cache for the backend.""" + cache_dir = backend.get("cache_dir_conditioning_image_embeds") + conditioning_backend_dir = conditioning_image_embed_backend.get("cache_dir") + if not cache_dir and conditioning_backend_dir and conditioning_backend_dir != backend.get("cache_dir_vae"): + cache_dir = conditioning_backend_dir + if not cache_dir: + default_root = getattr(self.args, "cache_dir", os.path.join(os.getcwd(), "cache")) + cache_dir = os.path.join(default_root, "conditioning_image_embeds", init_backend["id"]) + + 
info_log(f"(id={init_backend['id']}) Creating conditioning image embed cache: cache_dir={cache_dir}") + + conditioning_embed_batch_size = backend.get( + "conditioning_image_embed_batch_size", + getattr(self.args, "conditioning_image_embed_batch_size", self.args.vae_batch_size), + ) + + init_backend["conditioning_image_embed_cache"] = ImageEmbedCache( + id=init_backend["id"], + dataset_type=init_backend["dataset_type"], + model=self.model, + accelerator=self.accelerator, + metadata_backend=init_backend["metadata_backend"], + image_data_backend=init_backend["data_backend"], + cache_data_backend=conditioning_image_embed_backend["data_backend"], + instance_data_dir=init_backend.get("instance_data_dir", ""), + cache_dir=cache_dir, + write_batch_size=backend.get("write_batch_size", self.args.write_batch_size), + read_batch_size=backend.get("read_batch_size", self.args.read_batch_size), + embed_batch_size=conditioning_embed_batch_size, + hash_filenames=init_backend["config"].get("hash_filenames", True), + ) + init_backend["conditioning_image_embed_cache"].set_webhook_handler(StateTracker.get_webhook_handler()) + + if self.accelerator.is_local_main_process: + try: + init_backend["conditioning_image_embed_cache"].discover_all_files() + except FileNotFoundError: + warning_log( + f"(id={init_backend['id']}) Skipping conditioning image embed cache discovery because data directory was not found: {init_backend.get('instance_data_dir')}" + ) + return + if self._is_multi_process(): + self.accelerator.wait_for_everyone() + + all_image_files = StateTracker.get_image_files( + data_backend_id=init_backend["id"], + retry_limit=3, + ) + if all_image_files is None: + from simpletuner.helpers.training import image_file_extensions + + logger.debug("No image file cache available, retrieving fresh for conditioning embeds") + try: + all_image_files = init_backend["data_backend"].list_files( + instance_data_dir=init_backend["instance_data_dir"], + file_extensions=image_file_extensions, + ) + except FileNotFoundError: + warning_log( + f"(id={init_backend['id']}) Skipping conditioning embed cache file discovery because data directory was not found: {init_backend.get('instance_data_dir')}" + ) + return + all_image_files = StateTracker.set_image_files(all_image_files, data_backend_id=init_backend["id"]) + + init_backend["conditioning_image_embed_cache"].build_embed_filename_map(all_image_files=all_image_files) + + if not self.args.vae_cache_ondemand: + pending_files = init_backend["conditioning_image_embed_cache"].discover_unprocessed_files() + logger.info(f"Conditioning image embed cache has {len(pending_files)} unprocessed files.") + if pending_files: + init_backend["conditioning_image_embed_cache"].process_files(pending_files) + if self._is_multi_process(): + self.accelerator.wait_for_everyone() + def _configure_vae_cache( self, backend: Dict[str, Any], @@ -2212,16 +2360,19 @@ def _configure_single_data_backend( data_backend_is_mock = hasattr(init_backend["data_backend"], "_mock_children") image_embed_data_backend = self._get_image_embed_backend(backend, init_backend) + conditioning_image_embed_backend = self._get_conditioning_image_embed_backend(backend, init_backend) self._configure_metadata_backend(backend, init_backend) metadata_backend_is_mock = hasattr(init_backend["metadata_backend"], "_mock_children") + # Register early so downstream steps (e.g. caption handling) can locate the metadata backend. 
+ StateTracker.register_data_backend(init_backend) + if data_backend_is_mock and not metadata_backend_is_mock: info_log( f"(id={init_backend['id']}) Detected mocked data backend without mocked metadata backend; skipping runtime setup steps." ) - StateTracker.register_data_backend(init_backend) self.data_backends[init_backend["id"]] = init_backend return @@ -2255,6 +2406,17 @@ def _configure_single_data_backend( conditioning_type, ) + if ( + self.model.requires_conditioning_image_embeds() + and init_backend.get("dataset_type") in ["image", "video"] + and conditioning_type not in ["mask"] + ): + self._configure_conditioning_image_embed_cache( + backend, + init_backend, + conditioning_image_embed_backend, + ) + self._handle_error_scanning_and_metadata(backend, init_backend, conditioning_type) if ( @@ -2305,6 +2467,7 @@ def configure(self, data_backend_config: Optional[List[Dict[str, Any]]] = None) - 'data_backends': Dictionary of configured data backends - 'text_embed_backends': Dictionary of text embedding backends - 'image_embed_backends': Dictionary of image embedding backends + - 'conditioning_image_embed_backends': Dictionary of conditioning image embedding backends - 'default_text_embed_backend_id': ID of default text embedding backend Example: @@ -2319,12 +2482,14 @@ def configure(self, data_backend_config: Optional[List[Dict[str, Any]]] = None) self.configure_text_embed_backends(data_backend_config) self.configure_image_embed_backends(data_backend_config) + self.configure_conditioning_image_embed_backends(data_backend_config) self.configure_data_backends(data_backend_config) result = { "data_backends": StateTracker.get_data_backends(), "text_embed_backends": self.text_embed_backends, "image_embed_backends": self.image_embed_backends, + "conditioning_image_embed_backends": self.conditioning_image_embed_backends, "default_text_embed_backend_id": self.default_text_embed_backend_id, } @@ -2334,6 +2499,7 @@ def configure(self, data_backend_config: Optional[List[Dict[str, Any]]] = None) "backends_configured": len(result["data_backends"]), "text_embed_backends": len(result["text_embed_backends"]), "image_embed_backends": len(result["image_embed_backends"]), + "conditioning_image_embed_backends": len(result["conditioning_image_embed_backends"]), }, ) @@ -2478,6 +2644,7 @@ def configure_multi_databackend_new( factory.configure_text_embed_backends(data_backend_config) factory.configure_image_embed_backends(data_backend_config) + factory.configure_conditioning_image_embed_backends(data_backend_config) factory.configure_data_backends(data_backend_config) factory._log_performance_metrics("implementation_complete") @@ -2486,6 +2653,7 @@ def configure_multi_databackend_new( "data_backends": StateTracker.get_data_backends(), "text_embed_backends": factory.text_embed_backends, "image_embed_backends": factory.image_embed_backends, + "conditioning_image_embed_backends": factory.conditioning_image_embed_backends, "default_text_embed_backend_id": factory.default_text_embed_backend_id, } diff --git a/simpletuner/helpers/data_backend/huggingface.py b/simpletuner/helpers/data_backend/huggingface.py index 621b96fed..99587405f 100644 --- a/simpletuner/helpers/data_backend/huggingface.py +++ b/simpletuner/helpers/data_backend/huggingface.py @@ -269,9 +269,15 @@ def _metadata_with_trainingsample(self, sample: Any) -> Dict[str, Any]: video_path, temp_path = self._prepare_video_source(sample) if not video_path: return metadata - + reader = None capture = None try: + try: + reader = VideoReader(video_path, "video") 
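+                # Merge whatever torchvision's VideoReader could read; any failure is logged and we fall back to the trainingsample probe below.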
+ metadata.update(self._metadata_from_video_reader(reader)) + except Exception as exc: + logger.debug("Failed to extract metadata with torchvision VideoReader: %s", exc) + import trainingsample as tsr capture = tsr.PyVideoCapture(video_path) @@ -316,6 +322,11 @@ def _metadata_with_trainingsample(self, sample: Any) -> Dict[str, Any]: capture.release() except Exception: pass + if reader is not None: + try: + reader.close() + except Exception: + pass if temp_path: try: os.remove(temp_path) @@ -772,14 +783,35 @@ def get_abs_path(self, sample_path: str) -> str: def read_image(self, filepath: str, delete_problematic_images: bool = False): try: + if self.dataset_type == "video": + index = self._get_index_from_path(filepath) + if index is None: + logger.error("Unable to resolve dataset index for %s", filepath) + return None + item = self.get_dataset_item(index) + if item is None: + return None + sample = item.get(self.video_column) + if sample is None: + logger.error("Dataset item %s missing '%s' column", filepath, self.video_column) + return None + video_path, temp_path = self._prepare_video_source(sample) + if not video_path: + logger.error("Unable to prepare video source for %s", filepath) + return None + try: + return load_video(video_path) + finally: + if temp_path: + try: + os.remove(temp_path) + except OSError: + pass + image_data = self.read(filepath, as_byteIO=True) if image_data is None: return None - loader = load_image - if self.dataset_type == "video": - loader = load_video - image = loader(image_data) - return image + return load_image(image_data) except Exception as e: logger.error(f"Error opening image {filepath}: {e}") if delete_problematic_images: diff --git a/simpletuner/helpers/image_manipulation/load.py b/simpletuner/helpers/image_manipulation/load.py index 5496383a4..899045e0d 100644 --- a/simpletuner/helpers/image_manipulation/load.py +++ b/simpletuner/helpers/image_manipulation/load.py @@ -190,12 +190,14 @@ def load_video(vid_data: Union[bytes, IO[Any], str]) -> np.ndarray: else: raise TypeError("Unsupported type for vid_data. Expected str, bytes, or file-like object.") - # Open the video using VideoCapture. cap = tsr.PyVideoCapture(video_path) if not cap.is_opened(): if tmp_path: os.remove(tmp_path) - raise ValueError("Failed to open video.") + raise ValueError( + f"Failed to open video with trainingsample at '{video_path}'. Ensure trainingsample was built with video " + "support (ffmpeg) and that the asset is a supported format." + ) frames = [] while True: @@ -212,7 +214,10 @@ def load_video(vid_data: Union[bytes, IO[Any], str]) -> np.ndarray: os.remove(tmp_path) if not frames: - raise ValueError("No frames were read from the video.") + raise ValueError( + "No frames were read from the video using trainingsample. Verify ffmpeg support is available and the " + "video is not corrupted." 
+ ) # Stack frames into a numpy array: shape (num_frames, height, width, channels) video_array = np.stack(frames, axis=0) diff --git a/simpletuner/helpers/log_format.py b/simpletuner/helpers/log_format.py index d23a7e6d2..220e85d62 100644 --- a/simpletuner/helpers/log_format.py +++ b/simpletuner/helpers/log_format.py @@ -121,6 +121,10 @@ def format(self, record): torchdistlogger.setLevel("WARNING") torch_utils_logger = logging.getLogger("diffusers.utils.torch_utils") torch_utils_logger.setLevel("ERROR") +starlette_sse_logger = logging.getLogger("sse_starlette.sse") +starlette_sse_logger.setLevel("WARNING") +py_multipart_logger = logging.getLogger("python_multipart.multipart") +py_multipart_logger.setLevel("WARNING") # Suppress specific PIL warning warnings.filterwarnings( diff --git a/simpletuner/helpers/models/auraflow/controlnet.py b/simpletuner/helpers/models/auraflow/controlnet.py index 266ed2309..ec7373bf4 100644 --- a/simpletuner/helpers/models/auraflow/controlnet.py +++ b/simpletuner/helpers/models/auraflow/controlnet.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention_processor import ( Attention, @@ -69,7 +69,6 @@ class AuraFlowControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri "AuraFlowPatchEmbed", ] - @register_to_config def __init__( self, sample_size: int = 64, @@ -89,7 +88,23 @@ def __init__( super().__init__() default_out_channels = in_channels - self.out_channels = out_channels if out_channels is not None else default_out_channels + effective_out_channels = out_channels if out_channels is not None else default_out_channels + self.register_to_config( + sample_size=sample_size, + patch_size=patch_size, + in_channels=in_channels, + num_mmdit_layers=num_mmdit_layers, + num_single_dit_layers=num_single_dit_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + caption_projection_dim=caption_projection_dim, + out_channels=effective_out_channels, + pos_embed_max_size=pos_embed_max_size, + num_layers=num_layers, + extra_conditioning_channels=extra_conditioning_channels, + ) + self.out_channels = effective_out_channels self.inner_dim = num_attention_heads * attention_head_dim # limit blocks if num_layers specified diff --git a/simpletuner/helpers/models/auraflow/model.py b/simpletuner/helpers/models/auraflow/model.py index c3625ea3e..0bcf0b9ca 100644 --- a/simpletuner/helpers/models/auraflow/model.py +++ b/simpletuner/helpers/models/auraflow/model.py @@ -314,6 +314,14 @@ def controlnet_predict(self, prepared_batch: dict) -> dict: return {"model_prediction": model_pred} + def get_group_offload_components(self, pipeline): + components = dict(super().get_group_offload_components(pipeline)) + if "transformer" not in components and getattr(self, "model", None) is not None: + components["transformer"] = self.unwrap_model(self.model) + if self.config.controlnet and "controlnet" not in components and getattr(self, "controlnet", None) is not None: + components["controlnet"] = self.unwrap_model(self.controlnet) + return components + def get_lora_target_layers(self): if self.config.model_type == "lora" and (self.config.controlnet or self.config.control): controlnet_block_modules = [f"controlnet_blocks.{i}" for i in range(36)] diff --git 
a/simpletuner/helpers/models/auraflow/transformer.py b/simpletuner/helpers/models/auraflow/transformer.py index bd3434b22..23b2af67c 100644 --- a/simpletuner/helpers/models/auraflow/transformer.py +++ b/simpletuner/helpers/models/auraflow/transformer.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention_processor import ( Attention, @@ -327,7 +327,6 @@ class AuraFlowTransformer2DModel(PatchableModule, ModelMixin, ConfigMixin, PeftA _tread_router: Optional[TREADRouter] = None _tread_routes: Optional[List[Dict[str, Any]]] = None - @register_to_config def __init__( self, sample_size: int = 64, @@ -344,7 +343,21 @@ def __init__( ): super().__init__() default_out_channels = in_channels - self.out_channels = out_channels if out_channels is not None else default_out_channels + effective_out_channels = out_channels if out_channels is not None else default_out_channels + self.register_to_config( + sample_size=sample_size, + patch_size=patch_size, + in_channels=in_channels, + num_mmdit_layers=num_mmdit_layers, + num_single_dit_layers=num_single_dit_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + caption_projection_dim=caption_projection_dim, + out_channels=effective_out_channels, + pos_embed_max_size=pos_embed_max_size, + ) + self.out_channels = effective_out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = AuraFlowPatchEmbed( diff --git a/simpletuner/helpers/models/chroma/controlnet.py b/simpletuner/helpers/models/chroma/controlnet.py index 2cdf5612b..7e455f252 100644 --- a/simpletuner/helpers/models/chroma/controlnet.py +++ b/simpletuner/helpers/models/chroma/controlnet.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import PeftAdapterMixin from diffusers.models.attention_processor import AttentionProcessor from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module @@ -17,6 +17,7 @@ ChromaCombinedTimestepTextProjEmbeddings, ChromaSingleTransformerBlock, ChromaTransformerBlock, + adjust_rotary_embedding_dim, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -31,7 +32,6 @@ class ChromaControlNetOutput(BaseOutput): class ChromaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True - @register_to_config def __init__( self, patch_size: int = 1, @@ -48,6 +48,20 @@ def __init__( conditioning_embedding_channels: Optional[int] = None, ): super().__init__() + self.register_to_config( + patch_size=patch_size, + in_channels=in_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + axes_dims_rope=axes_dims_rope, + approximator_num_channels=approximator_num_channels, + approximator_hidden_dim=approximator_hidden_dim, + approximator_layers=approximator_layers, + conditioning_embedding_channels=conditioning_embedding_channels, + ) self.out_channels = in_channels self.inner_dim = num_attention_heads * 
attention_head_dim @@ -269,8 +283,15 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) + target_rotary_dim = getattr(self.config, "attention_head_dim", None) + if target_rotary_dim is not None: + image_rotary_emb = adjust_rotary_embedding_dim(image_rotary_emb, int(target_rotary_dim)) txt_len = encoder_hidden_states.shape[1] + if txt_len > 0 and image_rotary_emb[0].shape[0] >= txt_len: + image_only_rotary_emb = tuple(r[txt_len:] for r in image_rotary_emb) + else: + image_only_rotary_emb = image_rotary_emb block_samples: Tuple[torch.Tensor, ...] = () current_hidden_states = hidden_states @@ -326,7 +347,7 @@ def forward( current_hidden_states = block( hidden_states=current_hidden_states, temb=temb, - image_rotary_emb=image_rotary_emb, + image_rotary_emb=image_only_rotary_emb, attention_mask=attention_mask, joint_attention_kwargs=joint_attention_kwargs, ) diff --git a/simpletuner/helpers/models/chroma/model.py b/simpletuner/helpers/models/chroma/model.py index 54d01711f..1dd9646f7 100644 --- a/simpletuner/helpers/models/chroma/model.py +++ b/simpletuner/helpers/models/chroma/model.py @@ -5,7 +5,7 @@ import torch from diffusers import AutoencoderKL from torch.nn import functional as F -from transformers import T5EncoderModel, T5TokenizerFast +from transformers import AutoTokenizer, T5EncoderModel from simpletuner.helpers.configuration.registry import ( ConfigRegistry, @@ -60,7 +60,7 @@ class Chroma(ImageModelFoundation): TEXT_ENCODER_CONFIGURATION = { "text_encoder": { "name": "T5 XXL v1.1", - "tokenizer": T5TokenizerFast, + "tokenizer": AutoTokenizer, "tokenizer_subfolder": "tokenizer", "model": T5EncoderModel, }, @@ -129,7 +129,7 @@ def _encode_prompts(self, prompts: List[str], is_negative_prompt: bool = False): negative_prompt=None, device=self.accelerator.device, num_images_per_prompt=1, - max_sequence_length=int(self.config.tokenizer_max_length), + max_sequence_length=int(self.config.tokenizer_max_length or 512), do_classifier_free_guidance=False, ) if getattr(self.config, "t5_padding", "unmodified") == "zero": @@ -152,20 +152,42 @@ def collate_prompt_embeds(self, text_encoder_output: List[Dict[str, torch.Tensor } def convert_text_embed_for_pipeline(self, text_embedding: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + prompt_embeds = text_embedding["prompt_embeds"] + attention_mask = text_embedding["attention_masks"] + + if prompt_embeds.dim() == 2: + prompt_embeds = prompt_embeds.unsqueeze(0) + if attention_mask.dim() == 3 and attention_mask.size(1) == 1: + attention_mask = attention_mask.squeeze(1) + if attention_mask.dim() == 1: + attention_mask = attention_mask.unsqueeze(0) + return { - "prompt_embeds": text_embedding["prompt_embeds"].unsqueeze(0), - "prompt_attention_mask": text_embedding["attention_masks"].unsqueeze(0), + "prompt_embeds": prompt_embeds, + "prompt_attention_mask": attention_mask, } def convert_negative_text_embed_for_pipeline(self, text_embedding: Dict[str, torch.Tensor], prompt: str) -> dict: - if self.config.validation_guidance_real is None or self.config.validation_guidance_real <= 1.0: - return {} - return { - "negative_prompt_embeds": text_embedding["prompt_embeds"].unsqueeze(0), - "negative_prompt_attention_mask": text_embedding["attention_masks"].unsqueeze(0), - "guidance_scale_real": float(self.config.validation_guidance_real), + neg_embeds = text_embedding["prompt_embeds"] + neg_mask = text_embedding["attention_masks"] + + if neg_embeds.dim() == 2: + neg_embeds = neg_embeds.unsqueeze(0) + if 
neg_mask.dim() == 3 and neg_mask.size(1) == 1: + neg_mask = neg_mask.squeeze(1) + if neg_mask.dim() == 1: + neg_mask = neg_mask.unsqueeze(0) + + result = { + "negative_prompt_embeds": neg_embeds, + "negative_prompt_attention_mask": neg_mask, } + if self.config.validation_guidance_real is not None and self.config.validation_guidance_real > 1.0: + result["guidance_scale_real"] = float(self.config.validation_guidance_real) + + return result + def get_lora_target_layers(self): if self.config.lora_type.lower() == "standard": if getattr(self.config, "flux_lora_target", None) is None: @@ -326,6 +348,24 @@ def model_predict(self, prepared_batch): if attention_mask.dim() == 3 and attention_mask.size(1) == 1: attention_mask = attention_mask.squeeze(1) + # Match pipeline behaviour by extending the attention mask to cover image tokens + seq_length = packed_noisy_latents.shape[1] + attention_mask = attention_mask.to(device=self.accelerator.device) + if attention_mask.dtype != torch.bool: + attention_mask = attention_mask > 0 + attention_mask = torch.cat( + [ + attention_mask, + torch.ones( + attention_mask.shape[0], + seq_length, + device=attention_mask.device, + dtype=attention_mask.dtype, + ), + ], + dim=1, + ) + transformer_kwargs = { "hidden_states": packed_noisy_latents, "timestep": timesteps, @@ -420,6 +460,23 @@ def controlnet_predict(self, prepared_batch: dict) -> dict: if attention_mask.dim() == 3 and attention_mask.size(1) == 1: attention_mask = attention_mask.squeeze(1) + seq_length = packed_noisy_latents.shape[1] + attention_mask = attention_mask.to(device=self.accelerator.device) + if attention_mask.dtype != torch.bool: + attention_mask = attention_mask > 0 + attention_mask = torch.cat( + [ + attention_mask, + torch.ones( + attention_mask.shape[0], + seq_length, + device=attention_mask.device, + dtype=attention_mask.dtype, + ), + ], + dim=1, + ) + conditioning_scale = getattr(self.config, "controlnet_conditioning_scale", 1.0) controlnet_block_samples, controlnet_single_block_samples = self.controlnet( diff --git a/simpletuner/helpers/models/chroma/transformer.py b/simpletuner/helpers/models/chroma/transformer.py index 1dcc01214..403b529f4 100644 --- a/simpletuner/helpers/models/chroma/transformer.py +++ b/simpletuner/helpers/models/chroma/transformer.py @@ -6,7 +6,7 @@ import numpy as np import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention import AttentionMixin, FeedForward from diffusers.models.cache_utils import CacheMixin @@ -24,6 +24,27 @@ from simpletuner.helpers.training.tread import TREADRouter +def adjust_rotary_embedding_dim( + rotary_emb: Tuple[torch.Tensor, torch.Tensor], + target_dim: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Ensure rotary embedding tensors match the expected head dimension.""" + if target_dim <= 0: + return rotary_emb + cos, sin = rotary_emb + current_dim = cos.shape[-1] + if current_dim == target_dim: + return cos, sin + if current_dim > target_dim: + return cos[..., :target_dim], sin[..., :target_dim] + + pad = target_dim - current_dim + pad_shape = cos.shape[:-1] + (pad,) + cos_padded = torch.cat([cos, cos.new_zeros(pad_shape)], dim=-1) + sin_padded = torch.cat([sin, sin.new_zeros(pad_shape)], dim=-1) + return cos_padded, sin_padded + + class ChromaAdaLayerNormZeroPruned(nn.Module): r""" Norm layer adaptive 
layer norm zero (adaLN-Zero). @@ -361,7 +382,6 @@ class ChromaTransformer2DModel( _tread_router: Optional[TREADRouter] = None _tread_routes: Optional[List[Dict[str, Any]]] = None - @register_to_config def __init__( self, patch_size: int = 1, @@ -378,7 +398,22 @@ def __init__( approximator_layers: int = 5, ): super().__init__() - self.out_channels = out_channels or in_channels + effective_out_channels = out_channels or in_channels + self.register_to_config( + patch_size=patch_size, + in_channels=in_channels, + out_channels=effective_out_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + axes_dims_rope=axes_dims_rope, + approximator_num_channels=approximator_num_channels, + approximator_hidden_dim=approximator_hidden_dim, + approximator_layers=approximator_layers, + ) + self.out_channels = effective_out_channels self.inner_dim = num_attention_heads * attention_head_dim self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) @@ -509,6 +544,9 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) + target_rotary_dim = getattr(self.config, "attention_head_dim", None) + if target_rotary_dim is not None: + image_rotary_emb = adjust_rotary_embedding_dim(image_rotary_emb, int(target_rotary_dim)) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") diff --git a/simpletuner/helpers/models/common.py b/simpletuner/helpers/models/common.py index 8c3da1984..12f31ca9f 100644 --- a/simpletuner/helpers/models/common.py +++ b/simpletuner/helpers/models/common.py @@ -26,6 +26,7 @@ from simpletuner.helpers.training.min_snr_gamma import compute_snr from simpletuner.helpers.training.multi_process import _get_rank from simpletuner.helpers.training.wrappers import unwrap_model +from simpletuner.helpers.utils.offloading import enable_group_offload_on_components logger = logging.getLogger(__name__) from simpletuner.helpers.training.multi_process import should_log @@ -66,6 +67,7 @@ def get_model_config_path(model_family: str, model_path: str): class PipelineTypes(Enum): IMG2IMG = "img2img" TEXT2IMG = "text2img" + IMG2VIDEO = "img2video" CONTROLNET = "controlnet" CONTROL = "control" @@ -97,6 +99,45 @@ class ModelTypes(Enum): TEXT_ENCODER = "text_encoder" +class PipelineConditioningImageEmbedder: + """Wraps a Diffusers pipeline to expose a simple conditioning image encode interface.""" + + def __init__(self, pipeline, image_encoder, image_processor, device=None, weight_dtype=None): + if image_encoder is None or image_processor is None: + raise ValueError("PipelineConditioningImageEmbedder requires both an image encoder and image processor.") + self.pipeline = pipeline + self.image_encoder = image_encoder + self.image_processor = image_processor + self.device = device if device is not None else torch.device("cpu") + if isinstance(weight_dtype, str): + weight_dtype = getattr(torch, weight_dtype, None) + self.weight_dtype = weight_dtype + + if self.weight_dtype is not None: + self.image_encoder.to(self.device, dtype=self.weight_dtype) + else: + self.image_encoder.to(self.device) + self.image_encoder.eval() + + def encode(self, images): + inputs = self.image_processor(images=images, return_tensors="pt") + inputs = {k: v.to(self.device) for k, v in inputs.items()} + outputs = 
self.image_encoder(**inputs, output_hidden_states=True) + embeddings = None + hidden_states = getattr(outputs, "hidden_states", None) + if isinstance(hidden_states, (list, tuple)) and len(hidden_states) > 1: + embeddings = hidden_states[-2] + elif getattr(outputs, "last_hidden_state", None) is not None: + embeddings = outputs.last_hidden_state + elif torch.is_tensor(outputs): + embeddings = outputs + if embeddings is None: + raise ValueError("Image encoder did not return hidden states suitable for conditioning embeds.") + if self.weight_dtype is not None: + embeddings = embeddings.to(self.weight_dtype) + return embeddings + + class ModelFoundation(ABC): """ Base class that contains all the universal logic: @@ -181,6 +222,9 @@ def requires_conditioning_dataset(self) -> bool: def requires_conditioning_latents(self) -> bool: return False + def requires_conditioning_image_embeds(self) -> bool: + return False + def requires_validation_edit_captions(self) -> bool: """ Some edit / in-painting models want the *reference* image plus the @@ -516,6 +560,13 @@ def pre_vae_encode_transform_sample(self, sample): """ return sample + def encode_with_vae(self, vae, samples): + """ + Hook for models to customize VAE encoding behaviour (e.g. applying flavour-specific patches). + By default this simply forwards to the provided VAE. + """ + return vae.encode(samples) + def post_vae_encode_transform_sample(self, sample): """ Post-encode transform for the sample after passing it to the VAE. @@ -801,18 +852,16 @@ def _load_pipeline(self, pipeline_type: str = PipelineTypes.TEXT2IMG, load_base_ """ active_pipelines = getattr(self, "pipelines", {}) if pipeline_type in active_pipelines: + pipeline_instance = active_pipelines[pipeline_type] setattr( - active_pipelines[pipeline_type], + pipeline_instance, self.MODEL_TYPE.value, self.unwrap_model(model=self.model), ) if self.config.controlnet: - setattr( - active_pipelines[pipeline_type], - "controlnet", - self.unwrap_model(model=self.controlnet), - ) - return active_pipelines[pipeline_type] + setattr(pipeline_instance, "controlnet", self.unwrap_model(model=self.controlnet)) + self._configure_pipeline_offloading(pipeline_instance) + return pipeline_instance pipeline_kwargs = { "pretrained_model_name_or_path": self._model_config_path(), @@ -858,12 +907,240 @@ def _load_pipeline(self, pipeline_type: str = PipelineTypes.TEXT2IMG, load_base_ if self.config.controlnet and pipeline_type is PipelineTypes.CONTROLNET: pipeline_kwargs["controlnet"] = self.controlnet + optional_components = getattr(pipeline_class, "_optional_components", []) + require_conditioning_components = bool(self.requires_conditioning_image_embeds()) + if ( + "image_encoder" in optional_components + and "image_encoder" not in pipeline_kwargs + and getattr(self, "config", None) is not None + ): + repo_id = ( + getattr( + self.config, + "image_encoder_pretrained_model_name_or_path", + None, + ) + or self._model_config_path() + ) + processor_repo_id = ( + getattr( + self.config, + "image_processor_pretrained_model_name_or_path", + None, + ) + or repo_id + ) + explicit_encoder_source = getattr(self.config, "image_encoder_pretrained_model_name_or_path", None) + explicit_processor_source = getattr(self.config, "image_processor_pretrained_model_name_or_path", None) + image_encoder = None + image_processor = None + try: + from transformers import CLIPImageProcessor, CLIPVisionModel # type: ignore + except Exception as exc: # pragma: no cover - optional dependency guard + raise ValueError( + "Model requires 
conditioning image embeds but transformers is unavailable " + "to load the image encoder components." + ) from exc + + def _dedupe_subfolders(values): + seen = set() + result = [] + for value in values: + if not value or value in seen: + continue + seen.add(value) + result.append(value) + return result + + encoder_subfolders = [] + config_encoder_subfolder = getattr(self.config, "image_encoder_subfolder", None) + if isinstance(config_encoder_subfolder, (list, tuple, set)): + encoder_subfolders.extend(config_encoder_subfolder) + elif config_encoder_subfolder: + encoder_subfolders.append(config_encoder_subfolder) + encoder_subfolders.extend(("image_encoder", "vision_encoder")) + encoder_subfolders = _dedupe_subfolders(encoder_subfolders) + + loader_errors: list[tuple[str, Exception]] = [] + encoder_revision = getattr(self.config, "image_encoder_revision", getattr(self.config, "revision", None)) + for subfolder in encoder_subfolders: + try: + image_encoder = CLIPVisionModel.from_pretrained( + repo_id, + subfolder=subfolder, + use_safetensors=True, + revision=encoder_revision, + ) + break + except Exception as exc: # pragma: no cover - defensive + loader_errors.append((subfolder, exc)) + if image_encoder is None: + loader_error_text = ( + ", ".join(f"{repo_id}/{subfolder}: {error}" for subfolder, error in loader_errors) + if loader_errors + else "no matching subfolders were found." + ) + message = ( + "Unable to automatically load image encoder required for conditioning embeddings from " + f"'{repo_id}'. Attempts failed with: {loader_error_text}" + ) + if explicit_encoder_source: + raise ValueError(message) from (loader_errors[-1][1] if loader_errors else None) + log_fn = logger.warning if require_conditioning_components else logger.debug + log_fn( + "%s Set `image_encoder_pretrained_model_name_or_path` (and optionally " + "`image_encoder_subfolder`) in your config to provide the weights manually.", + message, + ) + else: + pipeline_kwargs["image_encoder"] = image_encoder + + processor_errors: list[tuple[str, Exception]] = [] + processor_subfolders = [] + config_processor_subfolder = getattr(self.config, "image_processor_subfolder", None) + if isinstance(config_processor_subfolder, (list, tuple, set)): + processor_subfolders.extend(config_processor_subfolder) + elif config_processor_subfolder: + processor_subfolders.append(config_processor_subfolder) + processor_subfolders.extend(("image_processor", "feature_extractor")) + processor_subfolders = _dedupe_subfolders(processor_subfolders) + processor_revision = getattr(self.config, "image_processor_revision", getattr(self.config, "revision", None)) + for subfolder in processor_subfolders: + try: + image_processor = CLIPImageProcessor.from_pretrained( + processor_repo_id, + subfolder=subfolder, + revision=processor_revision, + ) + break + except Exception as exc: # pragma: no cover - defensive + processor_errors.append((subfolder, exc)) + if image_processor is None: + processor_error_text = ( + ", ".join(f"{processor_repo_id}/{subfolder}: {error}" for subfolder, error in processor_errors) + if processor_errors + else "no matching subfolders were found." + ) + message = ( + "Unable to automatically load image processor required for conditioning embeddings from " + f"'{processor_repo_id}'. 
Attempts failed with: {processor_error_text}" + ) + if explicit_processor_source: + raise ValueError(message) from (processor_errors[-1][1] if processor_errors else None) + log_fn = logger.warning if require_conditioning_components else logger.debug + log_fn( + "%s Set `image_processor_pretrained_model_name_or_path` (and optionally " + "`image_processor_subfolder`) in your config to provide the processor configuration.", + message, + ) + else: + pipeline_kwargs["image_processor"] = image_processor + logger.debug(f"Initialising {pipeline_class.__name__} with components: {pipeline_kwargs}") - self.pipelines[pipeline_type] = pipeline_class.from_pretrained( - **pipeline_kwargs, - ) + try: + pipeline_instance = pipeline_class.from_pretrained(**pipeline_kwargs) + except (OSError, EnvironmentError, ValueError) as exc: + alt_repo = getattr(self.config, "pretrained_model_name_or_path", None) + current_repo = pipeline_kwargs.get("pretrained_model_name_or_path") + if alt_repo and isinstance(alt_repo, str) and alt_repo != current_repo: + logger.warning( + "Pipeline load failed from resolved config path '%s' (%s). Retrying with repository id '%s'.", + current_repo, + exc, + alt_repo, + ) + alt_kwargs = dict(pipeline_kwargs) + alt_kwargs["pretrained_model_name_or_path"] = alt_repo + pipeline_instance = pipeline_class.from_pretrained(**alt_kwargs) + else: + raise + self.pipelines[pipeline_type] = pipeline_instance + self._configure_pipeline_offloading(pipeline_instance) + + return pipeline_instance + + def get_conditioning_image_embedder(self): + """Return an adapter capable of encoding conditioning images, or None if unavailable.""" + if not self.requires_conditioning_image_embeds(): + return None + + return self._get_conditioning_image_embedder() + + def _get_conditioning_image_embedder(self): + """Subclass hook for providing conditioning image embedder (default: unsupported).""" + return None + + def get_group_offload_components(self, pipeline: DiffusionPipeline): + """ + Return the component mapping used for group offloading. + Sub-classes can override to prune or extend the mapping. + """ + return getattr(pipeline, "components", {}) + + def get_group_offload_exclusions(self, pipeline: DiffusionPipeline): + """ + Names of components that should be excluded from group offloading. + """ + return () - return self.pipelines[pipeline_type] + def _resolve_group_offload_device(self, pipeline: DiffusionPipeline): + pipeline_device = getattr(pipeline, "device", None) + if pipeline_device is not None: + return torch.device(pipeline_device) + if hasattr(self.accelerator, "device"): + return torch.device(self.accelerator.device) + return torch.device("cpu") + + def _resolve_group_offload_disk_path(self): + raw_path = getattr(self.config, "group_offload_to_disk_path", None) + if not raw_path: + return None + expanded = os.path.expanduser(raw_path) + return expanded + + def _configure_pipeline_offloading(self, pipeline: DiffusionPipeline): + if pipeline is None: + return + + enable_group_offload = bool(getattr(self.config, "enable_group_offload", False)) + enable_model_cpu_offload = bool(getattr(self.config, "enable_model_cpu_offload", False)) + + if enable_group_offload and enable_model_cpu_offload: + logger.warning( + "Both group offload and model CPU offload requested; prioritising group offload. " + "Disable one of the options to silence this warning." 
+ ) + + if enable_group_offload: + try: + device = self._resolve_group_offload_device(pipeline) + use_stream = bool(getattr(self.config, "group_offload_use_stream", False)) + if use_stream: + if device.type != "cuda" or not torch.cuda.is_available(): + use_stream = False + enable_group_offload_on_components( + self.get_group_offload_components(pipeline), + device=device, + offload_type=getattr(self.config, "group_offload_type", "block_level"), + number_blocks_per_group=getattr(self.config, "group_offload_blocks_per_group", 1), + use_stream=use_stream, + offload_to_disk_path=self._resolve_group_offload_disk_path(), + exclude=self.get_group_offload_exclusions(pipeline), + ) + logger.info("Group offloading enabled for pipeline components.") + except ImportError as error: + logger.warning("Group offloading unavailable: %s", error) + except ValueError as error: + logger.warning("Group offloading validation error: %s", error) + except Exception as error: + logger.warning("Failed to configure group offloading: %s", error) + return + + if enable_model_cpu_offload and hasattr(pipeline, "enable_model_cpu_offload"): + try: + pipeline.enable_model_cpu_offload() + except RuntimeError as error: + logger.warning("Model CPU offload unavailable: %s", error) def get_pipeline(self, pipeline_type: str = PipelineTypes.TEXT2IMG, load_base_model: bool = True) -> DiffusionPipeline: possibly_cached_pipeline = self._load_pipeline(pipeline_type, load_base_model) @@ -971,6 +1248,8 @@ def prepare_batch_conditions(self, batch: dict, state: dict) -> dict: batch["conditioning_pixel_values"] = batch["conditioning_pixel_values"][0] if isinstance(batch.get("conditioning_latents"), list) and len(batch["conditioning_latents"]) > 0: batch["conditioning_latents"] = batch["conditioning_latents"][0] + if isinstance(batch.get("conditioning_image_embeds"), list) and len(batch["conditioning_image_embeds"]) > 0: + batch["conditioning_image_embeds"] = batch["conditioning_image_embeds"][0] return batch def prepare_batch(self, batch: dict, state: dict) -> dict: @@ -1017,6 +1296,10 @@ def prepare_batch(self, batch: dict, state: dict) -> dict: if encoder_attention_mask is not None and hasattr(encoder_attention_mask, "to"): batch["encoder_attention_mask"] = encoder_attention_mask.to(**target_device_kwargs) + conditioning_image_embeds = batch.get("conditioning_image_embeds") + if conditioning_image_embeds is not None and hasattr(conditioning_image_embeds, "to"): + batch["conditioning_image_embeds"] = conditioning_image_embeds.to(**target_device_kwargs) + # Sample noise noise = torch.randn_like(batch["latents"]) bsz = batch["latents"].shape[0] diff --git a/simpletuner/helpers/models/cosmos/model.py b/simpletuner/helpers/models/cosmos/model.py index 9beef0408..39b594d58 100644 --- a/simpletuner/helpers/models/cosmos/model.py +++ b/simpletuner/helpers/models/cosmos/model.py @@ -130,6 +130,12 @@ def _encode_prompts(self, prompts: list, is_negative_prompt: bool = False): return prompt_embeds + def get_group_offload_components(self, pipeline): + components = dict(super().get_group_offload_components(pipeline)) + if "transformer" not in components and getattr(self, "model", None) is not None: + components["transformer"] = self.unwrap_model(self.model) + return components + def pre_vae_encode_transform_sample(self, sample): """ We have to boost the thing from image to video w/ single frame. 
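The offload wiring introduced in `common.py` above, together with the per-model `get_group_offload_components` overrides (Auraflow, Cosmos, Flux, LTXVideo), can also be exercised outside the trainer. The snippet below is a minimal, illustrative sketch — not part of this patch — that mirrors `_configure_pipeline_offloading`: the helper name and its keyword arguments are taken from the call above, while the `apply_group_offload` wrapper and its config attribute lookups are assumptions for demonstration only.

```python
# Illustrative sketch only: applies the grouped-offload helper to a loaded
# pipeline the same way _configure_pipeline_offloading does above.
import torch

from simpletuner.helpers.utils.offloading import enable_group_offload_on_components


def apply_group_offload(pipeline, config, accelerator):
    """Hypothetical wrapper mirroring the trainer's offload wiring."""
    device = torch.device(getattr(pipeline, "device", None) or accelerator.device)
    use_stream = bool(getattr(config, "group_offload_use_stream", False))
    if use_stream and (device.type != "cuda" or not torch.cuda.is_available()):
        # Streams only help when the onload target is an available CUDA device.
        use_stream = False

    enable_group_offload_on_components(
        dict(getattr(pipeline, "components", {})),  # component name -> module mapping
        device=device,                              # onload target
        offload_type=getattr(config, "group_offload_type", "block_level"),
        number_blocks_per_group=getattr(config, "group_offload_blocks_per_group", 1),
        use_stream=use_stream,
        offload_to_disk_path=getattr(config, "group_offload_to_disk_path", None),
        exclude=(),                                 # component names to leave resident
    )
```

Models that wrap their transformer (or ControlNet) during training override `get_group_offload_components` so the unwrapped modules are always present in the mapping handed to the helper.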
diff --git a/simpletuner/helpers/models/cosmos/transformer.py b/simpletuner/helpers/models/cosmos/transformer.py index 027d1cb16..477c54eec 100644 --- a/simpletuner/helpers/models/cosmos/transformer.py +++ b/simpletuner/helpers/models/cosmos/transformer.py @@ -20,7 +20,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin from diffusers.loaders.peft import PeftAdapterMixin from diffusers.models.attention import FeedForward @@ -495,7 +495,6 @@ class CosmosTransformer3DModel(PatchableModule, ModelMixin, ConfigMixin, FromOri _no_split_modules = ["CosmosTransformerBlock"] _keep_in_fp32_modules = ["learnable_pos_embed"] - @register_to_config def __init__( self, in_channels: int = 16, @@ -513,6 +512,21 @@ def __init__( extra_pos_embed_type: Optional[str] = "learnable", ) -> None: super().__init__() + self.register_to_config( + in_channels=in_channels, + out_channels=out_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_ratio=mlp_ratio, + text_embed_dim=text_embed_dim, + adaln_lora_dim=adaln_lora_dim, + max_size=max_size, + patch_size=patch_size, + rope_scale=rope_scale, + concat_padding_mask=concat_padding_mask, + extra_pos_embed_type=extra_pos_embed_type, + ) hidden_size = num_attention_heads * attention_head_dim # 1. Patch Embedding diff --git a/simpletuner/helpers/models/flux/model.py b/simpletuner/helpers/models/flux/model.py index 83920b978..48407d367 100644 --- a/simpletuner/helpers/models/flux/model.py +++ b/simpletuner/helpers/models/flux/model.py @@ -228,6 +228,14 @@ def requires_conditioning_validation_inputs(self) -> bool: return True return False + def get_group_offload_components(self, pipeline): + components = dict(super().get_group_offload_components(pipeline)) + if "transformer" not in components and getattr(self, "model", None) is not None: + components["transformer"] = self.unwrap_model(self.model) + if self.config.controlnet and "controlnet" not in components and getattr(self, "controlnet", None) is not None: + components["controlnet"] = self.unwrap_model(self.controlnet) + return components + def _format_text_embedding(self, text_embedding: torch.Tensor): """ Models can optionally format the stored text embedding, eg. 
in a dict, or diff --git a/simpletuner/helpers/models/flux/transformer.py b/simpletuner/helpers/models/flux/transformer.py index e82fd5f22..d97bc5b25 100644 --- a/simpletuner/helpers/models/flux/transformer.py +++ b/simpletuner/helpers/models/flux/transformer.py @@ -10,7 +10,7 @@ import torch.nn as nn import torch.nn.functional as F from diffusers import FluxTransformer2DModel as OriginalFluxTransformer2DModel -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention import FeedForward from diffusers.models.attention_processor import Attention, AttentionProcessor @@ -396,7 +396,6 @@ class FluxTransformer2DModel(PatchableModule, ModelMixin, ConfigMixin, PeftAdapt _tread_router: Optional[TREADRouter] = None _tread_routes: Optional[List[Dict[str, Any]]] = None - @register_to_config def __init__( self, patch_size: int = 1, @@ -411,6 +410,18 @@ def __init__( axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() + self.register_to_config( + patch_size=patch_size, + in_channels=in_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + pooled_projection_dim=pooled_projection_dim, + guidance_embeds=guidance_embeds, + axes_dims_rope=axes_dims_rope, + ) self.out_channels = in_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim diff --git a/simpletuner/helpers/models/hidream/controlnet.py b/simpletuner/helpers/models/hidream/controlnet.py index 89472912c..9fe00830a 100644 --- a/simpletuner/helpers/models/hidream/controlnet.py +++ b/simpletuner/helpers/models/hidream/controlnet.py @@ -8,7 +8,7 @@ import PIL.Image import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, PeftAdapterMixin from diffusers.models.autoencoders import AutoencoderKL @@ -69,7 +69,6 @@ class HiDreamControlNetOutput(BaseOutput): class HiDreamControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True - @register_to_config def __init__( self, patch_size: int = 2, @@ -90,6 +89,24 @@ def __init__( ): super().__init__() + self.register_to_config( + patch_size=patch_size, + in_channels=in_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + pooled_projection_dim=pooled_projection_dim, + guidance_embeds=guidance_embeds, + max_seq_length=max_seq_length, + conditioning_embedding_channels=conditioning_embedding_channels, + axes_dims_rope=axes_dims_rope, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + aux_loss_alpha=aux_loss_alpha, + ) + self.inner_dim = num_attention_heads * attention_head_dim self.max_seq = max_seq_length self.config.patch_size = patch_size diff --git a/simpletuner/helpers/models/hidream/schedule.py b/simpletuner/helpers/models/hidream/schedule.py index de27db2ec..87cabaaa4 100644 --- a/simpletuner/helpers/models/hidream/schedule.py +++ b/simpletuner/helpers/models/hidream/schedule.py @@ -7,7 +7,7 
@@ import numpy as np import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from diffusers.utils import deprecate, is_scipy_available @@ -71,7 +71,6 @@ class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 - @register_to_config def __init__( self, num_train_timesteps: int = 1000, @@ -91,6 +90,25 @@ def __init__( steps_offset: int = 0, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" ): + super().__init__() + self.register_to_config( + num_train_timesteps=num_train_timesteps, + solver_order=solver_order, + prediction_type=prediction_type, + shift=shift, + use_dynamic_shifting=use_dynamic_shifting, + thresholding=thresholding, + dynamic_thresholding_ratio=dynamic_thresholding_ratio, + sample_max_value=sample_max_value, + predict_x0=predict_x0, + solver_type=solver_type, + lower_order_final=lower_order_final, + disable_corrector=disable_corrector, + solver_p=solver_p, + timestep_spacing=timestep_spacing, + steps_offset=steps_offset, + final_sigmas_type=final_sigmas_type, + ) if solver_type not in ["bh1", "bh2"]: if solver_type in ["midpoint", "heun", "logrho"]: diff --git a/simpletuner/helpers/models/hidream/transformer.py b/simpletuner/helpers/models/hidream/transformer.py index e7d1300fa..4d4ea1c24 100644 --- a/simpletuner/helpers/models/hidream/transformer.py +++ b/simpletuner/helpers/models/hidream/transformer.py @@ -4,7 +4,7 @@ import numpy as np import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin @@ -839,7 +839,6 @@ class HiDreamImageTransformer2DModel(PatchableModule, ModelMixin, ConfigMixin, P _supports_gradient_checkpointing = True _no_split_modules = ["HiDreamImageBlock"] - @register_to_config def __init__( self, patch_size: Optional[int] = None, @@ -859,7 +858,25 @@ def __init__( aux_loss_alpha: float = 0.0, ): super().__init__() - self.out_channels = out_channels or in_channels + effective_out_channels = out_channels or in_channels + self.register_to_config( + patch_size=patch_size, + in_channels=in_channels, + out_channels=effective_out_channels, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + text_emb_dim=text_emb_dim, + num_routed_experts=num_routed_experts, + num_activated_experts=num_activated_experts, + axes_dims_rope=axes_dims_rope, + max_resolution=max_resolution, + llama_layers=llama_layers, + aux_loss_alpha=aux_loss_alpha, + ) + self.out_channels = effective_out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.llama_layers = llama_layers diff --git a/simpletuner/helpers/models/kolors/controlnet.py b/simpletuner/helpers/models/kolors/controlnet.py index 1526bc253..c6870e50f 100644 --- a/simpletuner/helpers/models/kolors/controlnet.py +++ b/simpletuner/helpers/models/kolors/controlnet.py @@ -15,7 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -from 
diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -200,7 +200,6 @@ class conditioning with `class_embed_type` equal to `None`. _supports_gradient_checkpointing = True - @register_to_config def __init__( self, in_channels: int = 4, @@ -269,6 +268,42 @@ def __init__( if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + transformer_layers_per_block_tuple = tuple(transformer_layers_per_block) + + self.register_to_config( + in_channels=in_channels, + conditioning_channels=conditioning_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + down_block_types=down_block_types, + mid_block_type=mid_block_type, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + downsample_padding=downsample_padding, + mid_block_scale_factor=mid_block_scale_factor, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block_tuple, + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + use_linear_projection=use_linear_projection, + class_embed_type=class_embed_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + num_class_embeds=num_class_embeds, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + global_pool_conditions=global_pool_conditions, + addition_embed_type_num_heads=addition_embed_type_num_heads, + ) # input conv_in_kernel = 3 diff --git a/simpletuner/helpers/models/ltxvideo/autoencoder.py b/simpletuner/helpers/models/ltxvideo/autoencoder.py index 6b4744719..aaca90d65 100644 --- a/simpletuner/helpers/models/ltxvideo/autoencoder.py +++ b/simpletuner/helpers/models/ltxvideo/autoencoder.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin from diffusers.models.activations import get_activation from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution @@ -1079,7 +1079,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): _supports_gradient_checkpointing = True - @register_to_config def __init__( self, in_channels: int = 3, @@ -1112,6 +1111,31 @@ def __init__( temporal_compression_ratio: int = None, ) -> None: super().__init__() + self.register_to_config( + in_channels=in_channels, + out_channels=out_channels, + latent_channels=latent_channels, + block_out_channels=block_out_channels, + down_block_types=down_block_types, + decoder_block_out_channels=decoder_block_out_channels, + layers_per_block=layers_per_block, + decoder_layers_per_block=decoder_layers_per_block, + spatio_temporal_scaling=spatio_temporal_scaling, + 
decoder_spatio_temporal_scaling=decoder_spatio_temporal_scaling, + decoder_inject_noise=decoder_inject_noise, + downsample_type=downsample_type, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, + timestep_conditioning=timestep_conditioning, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + scaling_factor=scaling_factor, + encoder_causal=encoder_causal, + decoder_causal=decoder_causal, + spatial_compression_ratio=spatial_compression_ratio, + temporal_compression_ratio=temporal_compression_ratio, + ) self.encoder = LTXVideoEncoder3d( in_channels=in_channels, diff --git a/simpletuner/helpers/models/ltxvideo/model.py b/simpletuner/helpers/models/ltxvideo/model.py index e3e49a327..35d47fd1e 100644 --- a/simpletuner/helpers/models/ltxvideo/model.py +++ b/simpletuner/helpers/models/ltxvideo/model.py @@ -160,6 +160,12 @@ def _encode_prompts(self, prompts: list, is_negative_prompt: bool = False): negative_prompt_attention_mask, ) + def get_group_offload_components(self, pipeline): + components = dict(super().get_group_offload_components(pipeline)) + if "transformer" not in components and getattr(self, "model", None) is not None: + components["transformer"] = self.unwrap_model(self.model) + return components + def model_predict(self, prepared_batch): if prepared_batch["noisy_latents"].shape[1] != 128: raise ValueError( diff --git a/simpletuner/helpers/models/ltxvideo/transformer.py b/simpletuner/helpers/models/ltxvideo/transformer.py index 074a2a377..bbc63938d 100644 --- a/simpletuner/helpers/models/ltxvideo/transformer.py +++ b/simpletuner/helpers/models/ltxvideo/transformer.py @@ -19,7 +19,7 @@ import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward from diffusers.models.attention_dispatch import dispatch_attention_fn @@ -410,7 +410,6 @@ class LTXVideoTransformer3DModel( _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["LTXVideoTransformerBlock"] - @register_to_config def __init__( self, in_channels: int = 128, @@ -431,7 +430,26 @@ def __init__( ) -> None: super().__init__() - out_channels = out_channels or in_channels + effective_out_channels = out_channels or in_channels + self.register_to_config( + in_channels=in_channels, + out_channels=effective_out_channels, + patch_size=patch_size, + patch_size_t=patch_size_t, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + num_layers=num_layers, + activation_fn=activation_fn, + qk_norm=qk_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + caption_channels=caption_channels, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + ) + + out_channels = effective_out_channels inner_dim = num_attention_heads * attention_head_dim self.proj_in = nn.Linear(in_channels, inner_dim) diff --git a/simpletuner/helpers/models/pixart/controlnet.py b/simpletuner/helpers/models/pixart/controlnet.py index 617fd8f39..919fbe812 100644 --- a/simpletuner/helpers/models/pixart/controlnet.py +++ b/simpletuner/helpers/models/pixart/controlnet.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Optional import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config +from 
diffusers.configuration_utils import ConfigMixin from diffusers.loaders import PeftAdapterMixin from diffusers.models import PixArtTransformer2DModel from diffusers.models.attention import BasicTransformerBlock @@ -100,7 +100,6 @@ def forward( class PixArtSigmaControlNetAdapterModel(ModelMixin, ConfigMixin, PeftAdapterMixin): # N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer - @register_to_config def __init__( self, num_layers: int = 13, @@ -111,6 +110,13 @@ def __init__( ) -> None: super().__init__() + self.register_to_config( + num_layers=num_layers, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + self.num_layers = num_layers self.num_attention_heads = num_attention_heads self.attention_head_dim = attention_head_dim diff --git a/simpletuner/helpers/models/pixart/transformer.py b/simpletuner/helpers/models/pixart/transformer.py index 237a5d3cf..daf72003c 100644 --- a/simpletuner/helpers/models/pixart/transformer.py +++ b/simpletuner/helpers/models/pixart/transformer.py @@ -15,7 +15,7 @@ import numpy as np import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import PeftAdapterMixin from diffusers.models.attention import BasicTransformerBlock from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0 @@ -86,7 +86,6 @@ class PixArtTransformer2DModel(PatchableModule, ModelMixin, ConfigMixin, PeftAda _tread_router: Optional[TREADRouter] = None _tread_routes: Optional[List[Dict[str, Any]]] = None - @register_to_config def __init__( self, num_attention_heads: int = 16, @@ -123,15 +122,41 @@ def __init__( f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None." ) - # Set some common variables used across the board. - self.attention_head_dim = attention_head_dim - self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim - self.out_channels = in_channels if out_channels is None else out_channels if use_additional_conditions is None: if sample_size == 128: use_additional_conditions = True else: use_additional_conditions = False + + effective_out_channels = in_channels if out_channels is None else out_channels + self.register_to_config( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=effective_out_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + patch_size=patch_size, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + interpolation_scale=interpolation_scale, + use_additional_conditions=use_additional_conditions, + caption_channels=caption_channels, + attention_type=attention_type, + ) + + # Set some common variables used across the board. 
+ self.attention_head_dim = attention_head_dim + self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + self.out_channels = effective_out_channels self.use_additional_conditions = use_additional_conditions self.gradient_checkpointing = False diff --git a/simpletuner/helpers/models/qwen_image/model.py b/simpletuner/helpers/models/qwen_image/model.py index a77af23d6..17b082315 100644 --- a/simpletuner/helpers/models/qwen_image/model.py +++ b/simpletuner/helpers/models/qwen_image/model.py @@ -23,8 +23,6 @@ class QwenImage(ImageModelFoundation): NAME = "Qwen-Image" - MODEL_DESCRIPTION = "Qwen's multimodal image generation model" - ENABLED_IN_WIZARD = True PREDICTION_TYPE = PredictionTypes.FLOW_MATCHING MODEL_TYPE = ModelTypes.TRANSFORMER AUTOENCODER_CLASS = AutoencoderKLQwenImage @@ -63,7 +61,9 @@ def __init__(self, config: dict, accelerator): self.vae_scale_factor = 8 def setup_training_noise_schedule(self): - # load flow matching scheduler for qwen image + """ + Loads the noise schedule for Qwen Image (flow matching). + """ from diffusers import FlowMatchEulerDiscreteScheduler scheduler_config = { @@ -89,12 +89,23 @@ def setup_training_noise_schedule(self): return self.config, self.noise_schedule def _encode_prompts(self, prompts: list, is_negative_prompt: bool = False): + """ + Encode prompts using Qwen's text encoder. + + Args: + prompts: List of text prompts to encode + is_negative_prompt: Whether these are negative prompts + + Returns: + Tuple of (prompt_embeds, prompt_embeds_mask) + """ if self.text_encoders is None or len(self.text_encoders) == 0: self.load_text_encoder() text_encoder = self.text_encoders[0] tokenizer = self.tokenizers[0] + # Move to device if needed if text_encoder.device != self.accelerator.device: text_encoder.to(self.accelerator.device) @@ -111,6 +122,15 @@ def _encode_prompts(self, prompts: list, is_negative_prompt: bool = False): return prompt_embeds, prompt_embeds_mask def _format_text_embedding(self, text_embedding: torch.Tensor): + """ + Format the text embeddings for Qwen Image. + + Args: + text_embedding: The embedding tuple from _encode_prompts + + Returns: + Dictionary with formatted embeddings + """ prompt_embeds, prompt_embeds_mask = text_embedding return { @@ -119,6 +139,9 @@ def _format_text_embedding(self, text_embedding: torch.Tensor): } def convert_text_embed_for_pipeline(self, text_embedding: torch.Tensor) -> dict: + """ + Convert text embeddings for pipeline use. + """ attention_mask = text_embedding.get("attention_masks", None) if attention_mask is not None and attention_mask.dim() == 1: attention_mask = attention_mask.unsqueeze(0) @@ -133,6 +156,9 @@ def convert_text_embed_for_pipeline(self, text_embedding: torch.Tensor) -> dict: } def convert_negative_text_embed_for_pipeline(self, text_embedding: torch.Tensor, prompt: str) -> dict: + """ + Convert negative text embeddings for pipeline use. + """ attention_mask = text_embedding.get("attention_masks", None) if attention_mask is not None and attention_mask.dim() == 1: attention_mask = attention_mask.unsqueeze(0) @@ -147,6 +173,9 @@ def convert_negative_text_embed_for_pipeline(self, text_embedding: torch.Tensor, } def model_predict(self, prepared_batch): + """ + Perform a forward pass with the Qwen Image model. 
+ """ latent_model_input = prepared_batch["noisy_latents"] timesteps = prepared_batch["timesteps"] @@ -157,52 +186,54 @@ def model_predict(self, prepared_batch): else: batch_size, num_channels, latent_height, latent_width = latent_model_input.shape - # get pipeline class for static methods + # Get the pipeline class to use its static methods pipeline_class = self.PIPELINE_CLASSES[PipelineTypes.TEXT2IMG] - # _unpack_latents expects pixel-space dims, converts latent->pixel + # Note: _unpack_latents expects pixel-space dimensions and will apply vae_scale_factor + # So we need to convert our latent dimensions back to pixel space pixel_height = latent_height * self.vae_scale_factor pixel_width = latent_width * self.vae_scale_factor - # pack latents - latent_model_input = pipeline_class._pack_latents( + # Pack latents using the official method + flat_latents = pipeline_class._pack_latents( latent_model_input, batch_size, num_channels, latent_height, latent_width, ) + latent_model_input = flat_latents - # prepare text embeddings + # Prepare text embeddings prompt_embeds = prepared_batch["prompt_embeds"].to( device=self.accelerator.device, dtype=self.config.weight_dtype, ) - # get attention mask + # Get attention mask prompt_embeds_mask = prepared_batch.get("encoder_attention_mask") if prompt_embeds_mask is not None: prompt_embeds_mask = prompt_embeds_mask.to(self.accelerator.device, dtype=torch.int64) if prompt_embeds_mask.dim() == 3 and prompt_embeds_mask.size(1) == 1: prompt_embeds_mask = prompt_embeds_mask.squeeze(1) - # image shapes for patchification (latent dims / 2) + # Prepare image shapes - using the LATENT dimensions divided by 2 (for patchification) img_shapes = [(1, latent_height // 2, latent_width // 2)] * batch_size - # normalize timesteps to [0,1] + # Prepare timesteps (normalize to 0-1 range) timesteps = ( torch.tensor(prepared_batch["timesteps"]).expand(batch_size).to(device=self.accelerator.device) / 1000.0 # Normalize to [0, 1] ) - # text sequence lengths + # Get text sequence lengths txt_seq_lens = ( prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else [prompt_embeds.shape[1]] * batch_size ) - # forward pass + # Forward pass through transformer noise_pred = self.model( hidden_states=latent_model_input.to(self.accelerator.device, self.config.weight_dtype), timestep=timesteps, @@ -214,25 +245,32 @@ def model_predict(self, prepared_batch): return_dict=False, )[0] - # unpack noise prediction if the transformer returned packed tokens - if noise_pred.dim() == 3: - noise_pred = pipeline_class._unpack_latents(noise_pred, pixel_height, pixel_width, self.vae_scale_factor) + # Unpack the noise prediction back to original shape + noise_pred = pipeline_class._unpack_latents(noise_pred, pixel_height, pixel_width, self.vae_scale_factor) - # remove extra dimension from _unpack_latents - if noise_pred.dim() == 5: - noise_pred = noise_pred.squeeze(2) # Remove the frame dimension + # Remove the extra dimension that _unpack_latents adds + if noise_pred.dim() == 5: + noise_pred = noise_pred.squeeze(2) # Remove the frame dimension return {"model_prediction": noise_pred} def pre_vae_encode_transform_sample(self, sample): - # qwen vae expects 5D input + """ + Pre-encode transform for the sample before passing it to the VAE. + Qwen Image VAE expects 5D input (adds frame dimension). 
+ """ + # Add frame dimension for Qwen VAE if needed if sample.dim() == 4: sample = sample.unsqueeze(2) # (B, C, H, W) -> (B, C, 1, H, W) return sample def post_vae_encode_transform_sample(self, sample): - # normalize latents and remove frame dimension - # qwen vae normalization, remove frame dimension + """ + Post-encode transform for Qwen Image VAE output. + Normalizes latents and removes frame dimension. + """ + # Qwen Image VAE normalization + # Remove frame dimension if present sample_latents = sample.latent_dist.sample() if sample_latents.dim() == 5: sample_latents = sample_latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W) @@ -250,6 +288,9 @@ def post_vae_encode_transform_sample(self, sample): return sample_latents def check_user_config(self): + """ + Check and validate user configuration for Qwen Image. + """ super().check_user_config() # Qwen Image specific checks diff --git a/simpletuner/helpers/models/qwen_image/transformer.py b/simpletuner/helpers/models/qwen_image/transformer.py index 10a5c16ab..81000a6a8 100644 --- a/simpletuner/helpers/models/qwen_image/transformer.py +++ b/simpletuner/helpers/models/qwen_image/transformer.py @@ -1,4 +1,4 @@ -# Copyright 2025 Qwen-Image Team, The HuggingFace Team, and 2025 bghira. All rights reserved. +# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ import functools import math -from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -31,45 +30,14 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm -from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from diffusers.utils.torch_utils import maybe_allow_in_graph from simpletuner.helpers.training.tread import TREADRouter -from simpletuner.helpers.utils.patching import MutableModuleList, PatchableModule logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def _enable_safe_half_full(): - if not torch.backends.mps.is_available(): - return - - if getattr(torch.full, "__wrapped_safe_half_full__", False): - return - - original_full = torch.full - - def safe_full(*args, **kwargs): - dtype = kwargs.get("dtype") - if dtype is torch.float16: - try: - return original_full(*args, **kwargs) - except RuntimeError as exc: # pragma: no cover - if "cannot be converted to type at::Half" in str(exc): - kwargs_fp32 = dict(kwargs) - kwargs_fp32["dtype"] = torch.float32 - result = original_full(*args, **kwargs_fp32) - return result.to(torch.float16) - raise - return original_full(*args, **kwargs) - - safe_full.__wrapped_safe_half_full__ = True - torch.full = safe_full - - -_enable_safe_half_full() - - def get_timestep_embedding( timesteps: torch.Tensor, embedding_dim: int, @@ -78,15 +46,36 @@ def get_timestep_embedding( scale: float = 1, max_period: int = 10000, ) -> torch.Tensor: - # sinusoidal timestep embeddings from DDPM + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + Args + timesteps (torch.Tensor): + a 1-D floating point Tensor of N indices, one per batch element. 
These may be fractional. + embedding_dim (int): + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + Scaling factor applied to the embeddings. + max_period (int): + Controls the maximum frequency of the embeddings + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. + """ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + if not timesteps.is_floating_point(): + raise TypeError("`get_timestep_embedding` expects floating-point `timesteps`. Call `.float()` on the input.") half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] + timesteps = timesteps.float() + emb = timesteps[:, None] * emb[None, :] # scale embeddings emb = scale * emb @@ -110,10 +99,25 @@ def apply_rotary_emb_qwen( use_real: bool = True, use_real_unbind_dim: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings + to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are + reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting + tensors contain rotary embeddings and are returned as real tensors. + + Args: + x (`torch.Tensor`): + Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply + freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
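        Example (illustrative sketch; the shapes are arbitrary and the frequencies are random
        stand-ins for the `torch.polar` output produced by `QwenEmbedRope`, exercising the
        complex-frequency path with `use_real=False`):

        ```python
        import torch

        B, S, H, D = 1, 16, 24, 128
        x = torch.randn(B, S, H, D)  # query or key tokens
        # complex frequencies shaped [S, D/2], as QwenEmbedRope would supply
        freqs = torch.polar(torch.ones(S, D // 2), torch.randn(S, D // 2))

        out = apply_rotary_emb_qwen(x, freqs, use_real=False)
        assert out.shape == (B, S, H, D)  # the rotation preserves the input shape
        ```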
+ """ if use_real: cos, sin = freqs_cis # [S, D] - cos = cos.unsqueeze(0).unsqueeze(2).to(device=x.device, dtype=x.dtype) - sin = sin.unsqueeze(0).unsqueeze(2).to(device=x.device, dtype=x.dtype) + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: # Used for flux, cogvideox, hunyuan-dit @@ -126,74 +130,41 @@ def apply_rotary_emb_qwen( else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - out = (x * cos + x_rotated * sin).to(x.dtype) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out else: x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - freq_shape = freqs_cis.shape[-1] - if freq_shape != x_rotated.shape[-1]: - freqs_cis = freqs_cis[..., : x_rotated.shape[-1]] - freqs_cis = freqs_cis.to(x_rotated.device) - freqs_cis = freqs_cis.unsqueeze(0).unsqueeze(2) + freqs_cis = freqs_cis.unsqueeze(1) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) return x_out.type_as(x) -class QwenTimestepProjEmbeddings(PatchableModule): +class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.timestep_embedder.time_embed_dim = embedding_dim - self.time_embed_dim = embedding_dim - def forward(self, timestep, *states, guidance=None, hidden_states=None): + def forward(self, timestep, hidden_states): timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D) - target_tensor: Optional[torch.Tensor] = None - guidance_tensor = guidance if isinstance(guidance, torch.Tensor) else None - - collected_states: List[torch.Tensor] = list(states) - if hidden_states is not None: - collected_states.append(hidden_states) - - for state in collected_states: - if not isinstance(state, torch.Tensor): - continue - if state.dim() == 1 and guidance_tensor is None: - guidance_tensor = state - continue - target_tensor = state - break - - if target_tensor is None and guidance_tensor is not None: - target_tensor = guidance_tensor - - if target_tensor is None: - target_tensor = timesteps_emb - - conditioning = timesteps_emb.to(device=target_tensor.device, dtype=target_tensor.dtype) - - if guidance_tensor is not None: - guidance_embed = guidance_tensor.to(device=conditioning.device, dtype=conditioning.dtype) - guidance_embed = guidance_embed.unsqueeze(-1).expand_as(conditioning) - conditioning = conditioning + guidance_embed + conditioning = timesteps_emb return conditioning -class QwenEmbedRope(PatchableModule): - def __init__(self, theta: int, axes_dim: List[int], scale_rope: bool = False): +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): super().__init__() self.theta = theta self.axes_dim = axes_dim - self._current_max_len = 1024 - pos_index = torch.arange(self._current_max_len) - neg_index = torch.arange(self._current_max_len).flip(0) * -1 - 1 + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 self.pos_freqs = torch.cat( [ self.rope_params(pos_index, self.axes_dim[0], self.theta), @@ -210,79 +181,47 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope: bool = False): ], dim=1, ) + self.rope_cache 
= {} - # DO NOT USE REGISTER BUFFER HERE; COMPLEX NUMBERS MAY LOSE THEIR IMAGINARY PART + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART self.scale_rope = scale_rope def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ assert dim % 2 == 0 freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs - def _expand_pos_freqs_if_needed(self, required_len: int) -> None: - if required_len <= self._current_max_len: - return - - new_max_len = max(required_len, int((required_len + 511) // 512) * 512) - - if required_len > 512: - logger.warning( - "QwenImage model was trained on prompts up to 512 tokens. " - f"Current prompt requires {required_len} tokens, which may lead to unpredictable behavior. " - "Consider using shorter prompts for better results." - ) - - device = self.pos_freqs.device - dtype = self.pos_freqs.dtype - pos_index = torch.arange(new_max_len, device=device) - neg_index = torch.arange(new_max_len, device=device).flip(0) * -1 - 1 - - self.pos_freqs = torch.cat( - [ - self.rope_params(pos_index, self.axes_dim[0], self.theta), - self.rope_params(pos_index, self.axes_dim[1], self.theta), - self.rope_params(pos_index, self.axes_dim[2], self.theta), - ], - dim=1, - ).to(device=device, dtype=dtype) - - self.neg_freqs = torch.cat( - [ - self.rope_params(neg_index, self.axes_dim[0], self.theta), - self.rope_params(neg_index, self.axes_dim[1], self.theta), - self.rope_params(neg_index, self.axes_dim[2], self.theta), - ], - dim=1, - ).to(device=device, dtype=dtype) - - self._current_max_len = new_max_len - self._compute_video_freqs.cache_clear() - def forward(self, video_fhw, txt_seq_lens, device): + """ + Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: + txt_length: [bs] a list of 1 integers representing the length of the text + """ if self.pos_freqs.device != device: self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) - # Normalise input so that we always iterate over a list of (frame, height, width) tuples - if isinstance(video_fhw, (tuple, list)): - # ``video_fhw`` can be provided either as a single triple or as a list of triples. 
- if len(video_fhw) == 0: - video_fhw = [] - elif isinstance(video_fhw[0], (list, tuple)) and len(video_fhw[0]) == 3: - video_fhw = [tuple(v) for v in video_fhw] - elif len(video_fhw) == 3: - video_fhw = [tuple(video_fhw)] - else: - video_fhw = [tuple(video_fhw)] - else: + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): video_fhw = [video_fhw] vid_freqs = [] max_vid_index = 0 for idx, fhw in enumerate(video_fhw): frame, height, width = fhw - video_freq = self._compute_video_freqs(frame, height, width, idx) + rope_key = f"{idx}_{height}_{width}" + + if not torch.compiler.is_compiling(): + if rope_key not in self.rope_cache: + self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) + video_freq = self.rope_cache[rope_key] + else: + video_freq = self._compute_video_freqs(frame, height, width, idx) video_freq = video_freq.to(device) vid_freqs.append(video_freq) @@ -292,14 +231,12 @@ def forward(self, video_fhw, txt_seq_lens, device): max_vid_index = max(height, width, max_vid_index) max_len = max(txt_seq_lens) - required_len = max_vid_index + max_len - self._expand_pos_freqs_if_needed(required_len) txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs - @functools.lru_cache(maxsize=128) + @functools.lru_cache(maxsize=None) def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) @@ -320,12 +257,16 @@ def _compute_video_freqs(self, frame, height, width, idx=0): class QwenDoubleStreamAttnProcessor2_0: - # joint attention for text and image streams + """ + Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor + implements joint attention computation where text and image streams are processed together. + """ _attention_backend = None def __init__(self): - if not hasattr(F, "scaled_dot_product_attention") or not callable(F.scaled_dot_product_attention): + sdpa = getattr(F, "scaled_dot_product_attention", None) + if sdpa is None or not callable(sdpa): raise ImportError( "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
) @@ -417,7 +358,7 @@ def __call__( @maybe_allow_in_graph -class QwenImageTransformerBlock(PatchableModule): +class QwenImageTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 ): @@ -460,6 +401,7 @@ def __init__( self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" shift, scale, gate = mod_params.chunk(3, dim=-1) return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) @@ -473,15 +415,8 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # Get modulation parameters for both streams - if temb.shape[-1] == self.dim: - mod_dtype = self.img_mod[1].weight.dtype - img_mod_params = self.img_mod(temb.to(mod_dtype)).to(temb.dtype) - txt_mod_params = self.txt_mod(temb.to(mod_dtype)).to(temb.dtype) - elif temb.shape[-1] == 6 * self.dim: - img_mod_params = temb.to(hidden_states.dtype) - txt_mod_params = temb.to(encoder_hidden_states.dtype) - else: - raise ValueError(f"Expected modulation embedding of size {self.dim} or {6 * self.dim}, got {temb.shape[-1]}") + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] # Split modulation parameters for norm1 and norm2 img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] @@ -495,47 +430,43 @@ def forward( txt_normed = self.txt_norm1(encoder_hidden_states) txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) - # joint attention: compute QKV, apply norm/rope, concat, split - attn_inputs = { - "hidden_states": img_modulated, - "encoder_hidden_states": txt_modulated, - "encoder_hidden_states_mask": encoder_hidden_states_mask, - "image_rotary_emb": image_rotary_emb, - } - if joint_attention_kwargs: - attn_inputs.update(joint_attention_kwargs) - - attn_output = self.attn(**attn_inputs) - - if not hasattr(self.attn, "call_args"): - self.attn.call_args = SimpleNamespace(args=(), kwargs={k: v for k, v in attn_inputs.items()}) + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. 
Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) - # attention processor returns (img_output, txt_output) + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided img_attn_output, txt_attn_output = attn_output - # apply gates and residual + # Apply attention gates and add residual (like in Megatron) hidden_states = hidden_states + img_gate1 * img_attn_output encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output # Process image stream - norm2 + MLP img_normed2 = self.img_norm2(hidden_states) img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) - img_mlp_output = self.img_mlp(img_modulated2.to(torch.float32)).to(img_modulated2.dtype) + img_mlp_output = self.img_mlp(img_modulated2) hidden_states = hidden_states + img_gate2 * img_mlp_output # Process text stream - norm2 + MLP txt_normed2 = self.txt_norm2(encoder_hidden_states) txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) - txt_mlp_output = self.txt_mlp(txt_modulated2.to(torch.float32)).to(txt_modulated2.dtype) + txt_mlp_output = self.txt_mlp(txt_modulated2) encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output - # clip for fp16 overflow prevention - if torch.isnan(encoder_hidden_states).any() or torch.isinf(encoder_hidden_states).any(): - encoder_hidden_states = torch.nan_to_num(encoder_hidden_states, nan=0.0, posinf=65504, neginf=-65504) + # Clip to prevent overflow for fp16 if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - if torch.isnan(hidden_states).any() or torch.isinf(hidden_states).any(): - hidden_states = torch.nan_to_num(hidden_states, nan=0.0, posinf=65504, neginf=-65504) if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) @@ -543,9 +474,32 @@ def forward( class QwenImageTransformer2DModel( - PatchableModule, ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin ): - # qwen dual-stream transformer model + """ + The Transformer model introduced in Qwen. + + Args: + patch_size (`int`, defaults to `2`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `60`): + The number of layers of dual stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `3584`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + guidance_embeds (`bool`, defaults to `False`): + Whether to use guidance embeddings for guidance-distilled variant of the model. 
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ _supports_gradient_checkpointing = True _no_split_modules = ["QwenImageTransformerBlock"] @@ -576,10 +530,11 @@ def __init__( self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) self._patch_area = patch_size * patch_size - self.img_in = nn.Linear(in_channels, self.inner_dim) + self._img_in_features = in_channels + self.img_in = nn.Linear(self._img_in_features, self.inner_dim) self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) - self.transformer_blocks = MutableModuleList( + self.transformer_blocks = nn.ModuleList( [ QwenImageTransformerBlock( dim=self.inner_dim, @@ -600,71 +555,51 @@ def __init__( self._tread_routes = None def set_router(self, router: TREADRouter, routes: Optional[List[Dict]] = None): + """Set TREAD router and routes for token reduction during training.""" self._tread_router = router self._tread_routes = routes - def _flatten_image_latents(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, int, int]: - if hidden_states.ndim != 4: + def _tokenize_hidden_states(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Optional[int], Optional[int]]: + """ + Ensure hidden states are flattened into patch tokens. + + Returns the tokenized hidden states along with optional patch grid sizes. + """ + if hidden_states.ndim == 3: + # Already tokenized: (batch, tokens, features) return hidden_states, None, None + if hidden_states.ndim != 4: + raise ValueError(f"Expected hidden_states to be 3D tokens or 4D latent map, got shape {hidden_states.shape}.") + batch_size, channels, height, width = hidden_states.shape patch_size = self.config.patch_size if height % patch_size != 0 or width % patch_size != 0: raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size}).") - patches = torch.nn.functional.unfold( - hidden_states, - kernel_size=patch_size, - stride=patch_size, - ) - patches = patches.transpose(1, 2) - if patches.shape[-1] != self._img_in_features: + grid_h = height // patch_size + grid_w = width // patch_size + token_dim = channels * patch_size * patch_size + if token_dim != self._img_in_features: raise ValueError( - f"Flattened latent features ({patches.shape[-1]}) do not match expected in_channels " - f"({self._img_in_features}). Ensure `in_channels` equals latent_channels * patch_size^2." + f"Token dimension mismatch: expected {self._img_in_features}, " + f"but got {token_dim} from latents (channels={channels}, patch_size={patch_size})." 
) - return patches, height // patch_size, width // patch_size - - def _unflatten_image_latents( - self, - hidden_states: torch.Tensor, - img_shapes: Optional[List[Tuple[int, int, int]]], - patch_grid: Tuple[int, int], - ) -> torch.Tensor: - if hidden_states.ndim != 3: - return hidden_states - - batch_size = hidden_states.shape[0] - if not img_shapes: - raise ValueError("img_shapes must be provided to reconstruct image latents.") - - if len(img_shapes) == 1 and batch_size > 1: - img_shapes = img_shapes * batch_size - - patch_size = self.config.patch_size - out_channels = self.out_channels - expected_features = self._patch_area * out_channels - if hidden_states.shape[-1] != expected_features: - raise ValueError(f"Expected last dimension to be {expected_features}, got {hidden_states.shape[-1]}.") - - outputs: List[torch.Tensor] = [] - patch_height, patch_width = patch_grid - - for idx, sample in enumerate(hidden_states): - frames, latent_h, latent_w = img_shapes[idx] - tokens_expected = frames * latent_h * latent_w - if sample.shape[0] != tokens_expected: - raise ValueError( - f"Token count mismatch for sample {idx}: expected {tokens_expected}, got {sample.shape[0]}." - ) - - sample = sample.view(frames, latent_h, latent_w, patch_size, patch_size, out_channels) - sample = sample.permute(0, 5, 1, 3, 2, 4) - sample = sample.reshape(frames * out_channels, latent_h * patch_size, latent_w * patch_size) - outputs.append(sample) - output = torch.stack(outputs, dim=0) - return output + hidden_states = hidden_states.view( + batch_size, + channels, + grid_h, + patch_size, + grid_w, + patch_size, + ) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).reshape( + batch_size, + grid_h * grid_w, + token_dim, + ) + return hidden_states, grid_h, grid_w def forward( self, @@ -680,6 +615,30 @@ def forward( force_keep_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. 
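        Example (illustrative sketch only; `model` is assumed to be an already instantiated
        `QwenImageTransformer2DModel`, shapes follow the defaults documented above, and the
        normalised float timestep mirrors how the training code drives this forward pass):

        ```python
        import torch

        tokens = torch.randn(1, 4096, 64)                 # packed 2x2 latent patches (in_channels=64)
        text = torch.randn(1, 77, 3584)                   # joint_attention_dim features
        text_mask = torch.ones(1, 77, dtype=torch.int64)

        prediction = model(
            hidden_states=tokens,
            encoder_hidden_states=text,
            encoder_hidden_states_mask=text_mask,
            timestep=torch.tensor([0.5]),                 # normalised flow-matching timestep
            img_shapes=[(1, 64, 64)],                     # required when tokens are pre-packed
            txt_seq_lens=[77],
            return_dict=False,
        )[0]                                              # (1, 4096, patch_size**2 * out_channels)
        ```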
+ """ if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -687,7 +646,7 @@ def forward( lora_scale = 1.0 if USE_PEFT_BACKEND: - # weight lora layers + # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: @@ -695,17 +654,13 @@ def forward( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) - hidden_states, patch_h, patch_w = self._flatten_image_latents(hidden_states) + hidden_states, patch_h, patch_w = self._tokenize_hidden_states(hidden_states) if img_shapes is None: if patch_h is None or patch_w is None: - raise ValueError("img_shapes must be provided when hidden_states are already flattened.") + raise ValueError("img_shapes must be provided when hidden_states are already tokenized.") img_shapes = [(1, patch_h, patch_w)] * hidden_states.shape[0] - if patch_h is None or patch_w is None: - patch_h = img_shapes[0][1] - patch_w = img_shapes[0][2] - hidden_states = self.img_in(hidden_states) timestep = timestep.to(hidden_states.dtype) @@ -723,19 +678,19 @@ def forward( image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) - # tread routing setup + # TREAD initialization routes = self._tread_routes or [] router = self._tread_router use_routing = self.training and len(routes) > 0 and torch.is_grad_enabled() for index_block, block in enumerate(self.transformer_blocks): - # tread routing + # TREAD routing for this layer if use_routing: - # check layer routing + # Check if this layer should use routing for route in routes: start_idx = route["start_layer_idx"] end_idx = route["end_layer_idx"] - # handle negative indices + # Handle negative indices if start_idx < 0: start_idx = len(self.transformer_blocks) + start_idx if end_idx < 0: @@ -748,22 +703,13 @@ def forward( hidden_states = router.start_route(hidden_states, mask_info) break if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, - **ckpt_kwargs, ) else: @@ -776,13 +722,13 @@ def custom_forward(*inputs): joint_attention_kwargs=attention_kwargs, ) - # tread end routing + # TREAD end routing for this layer if use_routing: - # check end routing + # Check if this layer should end routing for route in routes: start_idx = route["start_layer_idx"] end_idx = route["end_layer_idx"] - # handle negative indices + # Handle negative indices if start_idx < 0: start_idx = len(self.transformer_blocks) + start_idx if end_idx < 0: @@ -801,13 +747,12 @@ def custom_forward(*inputs): interval_control = int(np.ceil(interval_control)) hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - # use image part from dual-stream + # Use only the image part (hidden_states) from the dual-stream blocks hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - output = self._unflatten_image_latents(output, 
img_shapes, (patch_h, patch_w)) if USE_PEFT_BACKEND: - # remove lora scale + # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: diff --git a/simpletuner/helpers/models/sana/transformer.py b/simpletuner/helpers/models/sana/transformer.py index 771a2a652..b090e59d3 100644 --- a/simpletuner/helpers/models/sana/transformer.py +++ b/simpletuner/helpers/models/sana/transformer.py @@ -15,7 +15,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import PeftAdapterMixin from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor2_0, SanaLinearAttnProcessor2_0 from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection @@ -289,7 +289,6 @@ class SanaTransformer2DModel(PatchableModule, ModelMixin, ConfigMixin, PeftAdapt _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"] _skip_layerwise_casting_patterns = ["patch_embed", "norm"] - @register_to_config def __init__( self, in_channels: int = 32, @@ -312,9 +311,6 @@ def __init__( ) -> None: super().__init__() - out_channels = out_channels or in_channels - inner_dim = num_attention_heads * attention_head_dim - # Normalise patch size to a 2D tuple for consistency with diffusers' PatchEmbed if isinstance(patch_size, int): patch_size_int = patch_size @@ -329,6 +325,30 @@ def __init__( patch_area = patch_size_tuple[0] * patch_size_tuple[1] + effective_out_channels = out_channels or in_channels + self.register_to_config( + in_channels=in_channels, + out_channels=effective_out_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + num_cross_attention_heads=num_cross_attention_heads, + cross_attention_head_dim=cross_attention_head_dim, + cross_attention_dim=cross_attention_dim, + caption_channels=caption_channels, + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_bias=attention_bias, + sample_size=sample_size, + patch_size=patch_size_int, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + interpolation_scale=interpolation_scale, + ) + + out_channels = effective_out_channels + inner_dim = num_attention_heads * attention_head_dim + # Ensure the normalised patch size is reflected in the stored config self.config.patch_size = patch_size_int diff --git a/simpletuner/helpers/models/sd3/expanded.py b/simpletuner/helpers/models/sd3/expanded.py index 329821473..606316624 100644 --- a/simpletuner/helpers/models/sd3/expanded.py +++ b/simpletuner/helpers/models/sd3/expanded.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention import FeedForward, _chunked_feed_forward from diffusers.models.attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0 @@ -199,7 +199,6 @@ class SD3TransformerQKNorm2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro _supports_gradient_checkpointing = True - @register_to_config def __init__( self, sample_size: int = 128, @@ -217,7 +216,22 @@ def __init__( ): super().__init__() default_out_channels = in_channels - self.out_channels = out_channels if out_channels is not None else 
default_out_channels + effective_out_channels = out_channels if out_channels is not None else default_out_channels + self.register_to_config( + sample_size=sample_size, + patch_size=patch_size, + in_channels=in_channels, + num_layers=num_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + caption_projection_dim=caption_projection_dim, + pooled_projection_dim=pooled_projection_dim, + out_channels=effective_out_channels, + pos_embed_max_size=pos_embed_max_size, + qk_norm=qk_norm, + ) + self.out_channels = effective_out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = PatchEmbed( diff --git a/simpletuner/helpers/models/sd3/transformer.py b/simpletuner/helpers/models/sd3/transformer.py index 0776daba7..63c3299d6 100644 --- a/simpletuner/helpers/models/sd3/transformer.py +++ b/simpletuner/helpers/models/sd3/transformer.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models.attention import JointTransformerBlock from diffusers.models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0 @@ -58,7 +58,6 @@ class SD3Transformer2DModel(PatchableModule, ModelMixin, ConfigMixin, PeftAdapte _tread_router: Optional[TREADRouter] = None _tread_routes: Optional[List[Dict[str, Any]]] = None - @register_to_config def __init__( self, sample_size: int = 128, @@ -77,7 +76,23 @@ def __init__( ): super().__init__() default_out_channels = in_channels - self.out_channels = out_channels if out_channels is not None else default_out_channels + effective_out_channels = out_channels if out_channels is not None else default_out_channels + self.register_to_config( + sample_size=sample_size, + patch_size=patch_size, + in_channels=in_channels, + num_layers=num_layers, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + joint_attention_dim=joint_attention_dim, + caption_projection_dim=caption_projection_dim, + pooled_projection_dim=pooled_projection_dim, + out_channels=effective_out_channels, + pos_embed_max_size=pos_embed_max_size, + dual_attention_layers=dual_attention_layers, + qk_norm=qk_norm, + ) + self.out_channels = effective_out_channels self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim self.pos_embed = PatchEmbed( diff --git a/simpletuner/helpers/models/wan/model.py b/simpletuner/helpers/models/wan/model.py index d5e48375d..edfa1bcdf 100644 --- a/simpletuner/helpers/models/wan/model.py +++ b/simpletuner/helpers/models/wan/model.py @@ -1,13 +1,23 @@ import logging import os import random +import threading +from functools import partial +from typing import Dict, Optional import torch -from diffusers import AutoencoderKLWan +from diffusers import AutoencoderKLWan, WanImageToVideoPipeline from torchvision import transforms -from transformers import T5TokenizerFast, UMT5EncoderModel +from transformers import CLIPImageProcessor, CLIPVisionModel, T5TokenizerFast, UMT5EncoderModel -from simpletuner.helpers.models.common import ModelTypes, PipelineTypes, PredictionTypes, VideoModelFoundation, VideoToTensor +from simpletuner.helpers.models.common import ( + ModelTypes, + PipelineConditioningImageEmbedder, + PipelineTypes, + PredictionTypes, + VideoModelFoundation, + 
VideoToTensor, +) from simpletuner.helpers.models.wan.pipeline import WanPipeline from simpletuner.helpers.models.wan.transformer import WanTransformer3DModel @@ -16,6 +26,7 @@ from simpletuner.helpers.training.multi_process import should_log from simpletuner.helpers.training.tread import TREADRouter +from simpletuner.helpers.training.wrappers import unwrap_model as accelerator_unwrap_model if should_log(): logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO")) @@ -23,6 +34,167 @@ logger.setLevel("ERROR") +def time_text_monkeypatch( + self, + timestep: torch.Tensor, + encoder_hidden_states, + encoder_hidden_states_image=None, + timestep_seq_len=None, +): + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (encoder_hidden_states.shape[0], timestep_seq_len)) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +@torch.no_grad() +def add_first_frame_conditioning( + latent_model_input: torch.Tensor, + first_frame: torch.Tensor, + vae: AutoencoderKLWan, +): + """ + Adds first-frame conditioning for Wan 2.1-style I2V models by concatenating + the encoded conditioning latents and mask alongside the noisy latents. + """ + device = latent_model_input.device + dtype = latent_model_input.dtype + vae_scale_factor_temporal = 2 ** sum(getattr(vae, "temperal_downsample", [])) + + _, _, num_latent_frames, latent_height, latent_width = latent_model_input.shape + num_frames = (num_latent_frames - 1) * 4 + 1 + + if first_frame.ndim == 3: + first_frame = first_frame.unsqueeze(0) + if first_frame.shape[0] != latent_model_input.shape[0]: + first_frame = first_frame.expand(latent_model_input.shape[0], -1, -1, -1) + + vae_scale_factor = vae.config.scale_factor_spatial + first_frame = torch.nn.functional.interpolate( + first_frame, + size=( + latent_model_input.shape[3] * vae_scale_factor, + latent_model_input.shape[4] * vae_scale_factor, + ), + mode="bilinear", + align_corners=False, + ) + first_frame = first_frame.unsqueeze(2) + + zero_frame = torch.zeros_like(first_frame) + video_condition = torch.cat( + [first_frame, *[zero_frame for _ in range(num_frames - 1)]], + dim=2, + ) + + latent_condition = vae.encode(video_condition.to(device=device, dtype=dtype)).latent_dist.sample() + latent_condition = latent_condition.to(device=device, dtype=dtype) + + latents_mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(device=device, dtype=dtype) + latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to( + device=device, dtype=dtype + ) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones( + latent_model_input.shape[0], + 1, + num_frames, + latent_height, + latent_width, + device=device, + dtype=dtype, + ) + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, 
repeats=vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + latent_model_input.shape[0], + -1, + vae_scale_factor_temporal, + latent_height, + latent_width, + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + + first_frame_condition = torch.concat([mask_lat_size, latent_condition], dim=1) + conditioned_latent = torch.cat([latent_model_input, first_frame_condition], dim=1) + + return conditioned_latent + + +@torch.no_grad() +def add_first_frame_conditioning_v22( + latent_model_input: torch.Tensor, + first_frame: torch.Tensor, + vae: AutoencoderKLWan, + last_frame: Optional[torch.Tensor] = None, +): + """ + Adds first (and optional last) frame conditioning for Wan 2.2-style models that + overwrite latent time steps rather than concatenating additional channels. + """ + device = latent_model_input.device + dtype = latent_model_input.dtype + bs, _, T, H, W = latent_model_input.shape + scale = vae.config.scale_factor_spatial + target_h = H * scale + target_w = W * scale + + if first_frame.ndim == 3: + first_frame = first_frame.unsqueeze(0) + if first_frame.shape[0] != bs: + first_frame = first_frame.expand(bs, -1, -1, -1) + + first_frame_up = torch.nn.functional.interpolate( + first_frame, + size=(target_h, target_w), + mode="bilinear", + align_corners=False, + ).unsqueeze(2) + encoded = vae.encode(first_frame_up.to(device=device, dtype=dtype)).latent_dist.sample().to(dtype) + + mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device=device, dtype=dtype) + std = 1.0 / torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device=device, dtype=dtype) + encoded = (encoded - mean) * std + + latent = latent_model_input.clone() + latent[:, :, : encoded.shape[2]] = encoded + + mask = torch.ones(bs, 1, T, H, W, device=device, dtype=dtype) + mask[:, :, : encoded.shape[2]] = 0.0 + + if last_frame is not None: + if last_frame.ndim == 3: + last_frame = last_frame.unsqueeze(0) + if last_frame.shape[0] != bs: + last_frame = last_frame.expand(bs, -1, -1, -1) + last_frame_up = torch.nn.functional.interpolate( + last_frame, + size=(target_h, target_w), + mode="bilinear", + align_corners=False, + ).unsqueeze(2) + last_encoded = vae.encode(last_frame_up.to(device=device, dtype=dtype)).latent_dist.sample().to(dtype) + last_encoded = (last_encoded - mean) * std + latent[:, :, -last_encoded.shape[2] :] = last_encoded + mask[:, :, -last_encoded.shape[2] :] = 0.0 + + return latent, mask + + class Wan(VideoModelFoundation): NAME = "Wan" MODEL_DESCRIPTION = "Video generation model (text-to-video)" @@ -41,6 +213,7 @@ class Wan(VideoModelFoundation): MODEL_SUBFOLDER = "transformer" PIPELINE_CLASSES = { PipelineTypes.TEXT2IMG: WanPipeline, + PipelineTypes.IMG2VIDEO: WanImageToVideoPipeline, # PipelineTypes.IMG2IMG: None, # PipelineTypes.CONTROLNET: None, } @@ -50,11 +223,60 @@ class Wan(VideoModelFoundation): HUGGINGFACE_PATHS = { "t2v-480p-1.3b-2.1": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", "t2v-480p-14b-2.1": "Wan-AI/Wan2.1-T2V-14B-Diffusers", + "i2v-14b-2.1": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", + "i2v-14b-2.1-720p": "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers", + "i2v-14b-2.2-high": "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers", + "i2v-14b-2.2-low": "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers", + "flf2v-14b-2.1": "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers", + "flf2v-14b-2.2-high": "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers", + "flf2v-14b-2.2-low": "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers", + "vace-1.3b-2.1": 
"Wan-AI/Wan2.1-VACE-1.3B-diffusers", + "vace-14b-2.1": "Wan-AI/Wan2.1-VACE-14B-diffusers", + "ti2v-5b-2.2": "Wan-AI/Wan2.2-TI2V-5B-Diffusers", # "i2v-480p-14b-2.1": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", # "i2v-720p-14b-2.1": "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers", } MODEL_LICENSE = "apache-2.0" + WAN_STAGE_OVERRIDES: Dict[str, Dict[str, object]] = { + "i2v-14b-2.2-high": { + "trained_stage": "high", + "stage_subfolder": "high_noise_model", + "other_stage_subfolder": "low_noise_model", + "flow_shift": 5.0, + "sample_steps": 40, + "boundary_ratio": 0.90, + "guidance": {"high": 3.5, "low": 3.5}, + }, + "i2v-14b-2.2-low": { + "trained_stage": "low", + "stage_subfolder": "low_noise_model", + "other_stage_subfolder": "high_noise_model", + "flow_shift": 5.0, + "sample_steps": 40, + "boundary_ratio": 0.90, + "guidance": {"high": 3.5, "low": 3.5}, + }, + "flf2v-14b-2.2-high": { + "trained_stage": "high", + "stage_subfolder": "high_noise_model", + "other_stage_subfolder": "low_noise_model", + "flow_shift": 5.0, + "sample_steps": 40, + "boundary_ratio": 0.90, + "guidance": {"high": 3.5, "low": 3.5}, + }, + "flf2v-14b-2.2-low": { + "trained_stage": "low", + "stage_subfolder": "low_noise_model", + "other_stage_subfolder": "high_noise_model", + "flow_shift": 5.0, + "sample_steps": 40, + "boundary_ratio": 0.90, + "guidance": {"high": 3.5, "low": 3.5}, + }, + } + TEXT_ENCODER_CONFIGURATION = { "text_encoder": { "name": "UMT5", @@ -65,6 +287,595 @@ class Wan(VideoModelFoundation): }, } + I2V_FLAVOURS = frozenset( + { + "i2v-14b-2.1", + "i2v-14b-2.1-720p", + "i2v-14b-2.2-high", + "i2v-14b-2.2-low", + } + ) + I2V_CLIP_CONDITIONED_FLAVOURS = frozenset( + { + "i2v-14b-2.1", + "i2v-14b-2.1-720p", + } + ) + FLF2V_FLAVOURS = frozenset( + { + "flf2v-14b-2.1", + "flf2v-14b-2.2-high", + "flf2v-14b-2.2-low", + } + ) + TI2V_FLAVOURS = frozenset( + { + "ti2v-5b-2.2", + } + ) + EXPAND_TIMESTEP_FLAVOURS = frozenset( + { + "ti2v-5b-2.2", + } + ) + + def __init__(self, config, accelerator): + super().__init__(config, accelerator) + self._wan_cached_stage_modules: Dict[str, WanTransformer3DModel] = {} + self._conditioning_image_embedder = None + self._wan_logged_missing_img_encoder = False + self._wan_vae_patch_lock = threading.Lock() + if not hasattr(self.config, "wan_force_2_1_time_embedding"): + self.config.wan_force_2_1_time_embedding = False + self._wan_expand_timesteps = False + + def requires_conditioning_image_embeds(self) -> bool: + if not self._is_i2v_like_flavour(): + return False + + if not self._wan_transformers_require_image_conditioning(): + return False + + pipeline = self.pipelines.get(PipelineTypes.IMG2VIDEO) + if pipeline is not None: + if getattr(pipeline, "image_encoder", None) is None: + if not self._wan_logged_missing_img_encoder: + logger.info( + "Wan flavour %s IMG2VIDEO pipeline missing image encoder; loading conditioning components separately.", + getattr(self.config, "model_flavour", ""), + ) + self._wan_logged_missing_img_encoder = True + elif self._wan_logged_missing_img_encoder: + self._wan_logged_missing_img_encoder = False + + return True + + def _current_flavour(self) -> str: + flavour = getattr(self.config, "model_flavour", None) + return str(flavour or "") + + def _flavour_in(self, collection) -> bool: + return self._current_flavour() in collection + + def requires_conditioning_validation_inputs(self) -> bool: + return self._flavour_in(self.I2V_FLAVOURS | self.FLF2V_FLAVOURS | self.TI2V_FLAVOURS) + + def prepare_batch_conditions(self, batch: dict, state: dict) -> dict: + original_pixels 
= batch.get("conditioning_pixel_values") + if isinstance(original_pixels, list) and len(original_pixels) > 0: + batch["_wan_conditioning_pixel_values_list"] = original_pixels + else: + batch["_wan_conditioning_pixel_values_list"] = None + + batch = super().prepare_batch_conditions(batch, state) + + pixel_list = batch.pop("_wan_conditioning_pixel_values_list", None) + if pixel_list: + batch["conditioning_pixel_values_multi"] = [ + tensor.to(device=self.accelerator.device) if hasattr(tensor, "to") else tensor for tensor in pixel_list + ] + else: + batch["conditioning_pixel_values_multi"] = None + return batch + + def _is_i2v_like_flavour(self) -> bool: + return self._flavour_in(self.I2V_FLAVOURS | self.FLF2V_FLAVOURS | self.TI2V_FLAVOURS) + + def _uses_last_frame_conditioning(self) -> bool: + return self._flavour_in(self.FLF2V_FLAVOURS) + + def _module_requires_image_conditioning(self, module: Optional[torch.nn.Module]) -> bool: + if module is None: + return False + config = getattr(module, "config", None) + if config is None: + try: + unwrapped = accelerator_unwrap_model(self.accelerator, module) + except Exception: # pragma: no cover - defensive guard + unwrapped = module + config = getattr(unwrapped, "config", None) + if config is None: + return False + image_dim = getattr(config, "image_dim", None) + return image_dim is not None and image_dim != 0 + + def _wan_transformers_require_image_conditioning(self) -> bool: + if getattr(self.config, "wan_disable_conditioning_image_embeds", False): + return False + + if getattr(self.config, "wan_force_conditioning_image_embeds", False): + return True + + model = getattr(self, "model", None) + if model is not None and self._module_requires_image_conditioning(model): + return True + + for cached in self._wan_cached_stage_modules.values(): + if self._module_requires_image_conditioning(cached): + return True + + pipeline = self.pipelines.get(PipelineTypes.IMG2VIDEO) + if pipeline is not None: + if self._module_requires_image_conditioning(getattr(pipeline, "transformer", None)): + return True + if self._module_requires_image_conditioning(getattr(pipeline, "transformer_2", None)): + return True + + flavour = self._current_flavour() + return flavour in self.I2V_CLIP_CONDITIONED_FLAVOURS + + def _extract_conditioning_frames(self, prepared_batch): + multi = prepared_batch.get("conditioning_pixel_values_multi") + first_frame = None + last_frame = None + if multi: + first_frame = multi[0] + if self._uses_last_frame_conditioning() and len(multi) > 1: + last_frame = multi[-1] + else: + candidate = prepared_batch.get("conditioning_pixel_values") + if torch.is_tensor(candidate): + first_frame = candidate + return first_frame, last_frame + + def _mask_to_force_keep(self, mask: torch.Tensor) -> Optional[torch.Tensor]: + transformer = self.unwrap_model(self.model) if getattr(self, "model", None) is not None else None + if transformer is None or not hasattr(transformer, "config"): + return None + patch_size = getattr(transformer.config, "patch_size", (1, 2, 2)) + t_step = max(int(patch_size[0]), 1) + h_step = max(int(patch_size[1]), 1) + w_step = max(int(patch_size[2]), 1) + mask_tokens = mask[:, :, ::t_step, ::h_step, ::w_step] + mask_tokens = mask_tokens.squeeze(1) + force_keep = mask_tokens < 0.5 + return force_keep.flatten(1) + + def _build_expand_timesteps(self, base_timesteps: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + transformer = self.unwrap_model(self.model) if getattr(self, "model", None) is not None else None + if transformer is None or not 
hasattr(transformer, "config"): + return base_timesteps + patch_size = getattr(transformer.config, "patch_size", (1, 2, 2)) + t_step = max(int(patch_size[0]), 1) + h_step = max(int(patch_size[1]), 1) + w_step = max(int(patch_size[2]), 1) + mask_tokens = mask[:, :, ::t_step, ::h_step, ::w_step] + mask_tokens = mask_tokens.squeeze(1) + base = base_timesteps.to(mask_tokens.device, dtype=mask_tokens.dtype).view(-1, 1, 1, 1) + expanded = (mask_tokens * base).flatten(1) + return expanded.to(device=base_timesteps.device, dtype=base_timesteps.dtype) + + def _wan_prepare_vae_encode_inputs(self, vae, samples: torch.Tensor) -> tuple[torch.Tensor, bool]: + if not torch.is_tensor(samples) or samples.ndim != 5: + return samples, False + + vae_config = getattr(vae, "config", None) + if vae_config is None: + return samples, False + + patch_size = getattr(vae_config, "patch_size", None) + if not isinstance(patch_size, int) or patch_size <= 1: + return samples, False + + in_channels = getattr(vae_config, "in_channels", None) + if not isinstance(in_channels, int): + return samples, False + + batch, channels, frames, height, width = samples.shape + if channels == in_channels: + return samples, False + + expected_channels = channels * (patch_size**2) + if expected_channels != in_channels: + return samples, False + + if height % patch_size != 0 or width % patch_size != 0: + logger.warning( + "Unable to patchify VAE inputs: shape (%s, %s) not divisible by patch size %s.", + height, + width, + patch_size, + ) + return samples, False + + reshaped = samples.contiguous().view( + batch, + channels, + frames, + height // patch_size, + patch_size, + width // patch_size, + patch_size, + ) + patched = reshaped.permute(0, 1, 4, 6, 2, 3, 5).contiguous() + patched = patched.view(batch, expected_channels, frames, height // patch_size, width // patch_size) + return patched, True + + def _wan_encode_without_internal_patchify(self, vae, samples: torch.Tensor, original_patch_size): + config = getattr(vae, "config", None) + if config is None or original_patch_size is None: + return vae.encode(samples) + try: + config.patch_size = None + return vae.encode(samples) + finally: + config.patch_size = original_patch_size + + def encode_with_vae(self, vae, samples): + patched_samples, disable_internal_patch = self._wan_prepare_vae_encode_inputs(vae, samples) + if disable_internal_patch: + original_patch_size = getattr(getattr(vae, "config", None), "patch_size", None) + lock = getattr(self, "_wan_vae_patch_lock", None) + if lock is not None: + with lock: + return self._wan_encode_without_internal_patchify(vae, patched_samples, original_patch_size) + return self._wan_encode_without_internal_patchify(vae, patched_samples, original_patch_size) + return super().encode_with_vae(vae, patched_samples) + + def _apply_i2v_conditioning_to_kwargs(self, prepared_batch, transformer_kwargs): + if not self._is_i2v_like_flavour(): + return + first_frame, last_frame = self._extract_conditioning_frames(prepared_batch) + if first_frame is None or transformer_kwargs.get("hidden_states") is None: + return + + latent_tensor = transformer_kwargs["hidden_states"] + latent_device = latent_tensor.device + latent_dtype = latent_tensor.dtype + + vae = self.get_vae() + if vae is None: + return + try: + vae_param = next(vae.parameters()) + if vae_param.device != latent_device: + vae.to(latent_device) + except StopIteration: + pass + + def _prepare_frame(frame: torch.Tensor) -> torch.Tensor: + if frame.device != latent_device or frame.dtype != vae.dtype: + return 
frame.to(device=latent_device, dtype=vae.dtype) + return frame + + first_frame_prepared = _prepare_frame(first_frame).detach() + last_frame_prepared = None + if self._uses_last_frame_conditioning() and last_frame is not None: + last_frame_prepared = _prepare_frame(last_frame).detach() + + expand_timesteps = bool(self._wan_expand_timesteps) + with torch.no_grad(): + if expand_timesteps: + conditioned_latent, mask = add_first_frame_conditioning_v22( + latent_tensor, + first_frame_prepared, + vae, + last_frame=last_frame_prepared, + ) + transformer_kwargs["hidden_states"] = conditioned_latent.to(dtype=latent_dtype) + base_timesteps = prepared_batch["timesteps"] + expanded_timesteps = self._build_expand_timesteps(base_timesteps, mask) + transformer_kwargs["timestep"] = expanded_timesteps + force_keep = self._mask_to_force_keep(mask) + if force_keep is not None: + existing = transformer_kwargs.get("force_keep_mask") + transformer_kwargs["force_keep_mask"] = force_keep if existing is None else (existing | force_keep) + else: + conditioned_latent = add_first_frame_conditioning( + latent_tensor, + first_frame_prepared, + vae, + ) + transformer_kwargs["hidden_states"] = conditioned_latent.to(dtype=latent_dtype) + + def _get_conditioning_image_embedder(self): + pipeline = self.pipelines.get(PipelineTypes.IMG2VIDEO) + if pipeline is None: + try: + pipeline = self.get_pipeline(PipelineTypes.IMG2VIDEO) + except Exception: + pipeline = None + + if pipeline is None: + return None + + image_encoder = getattr(pipeline, "image_encoder", None) + image_processor = getattr(pipeline, "image_processor", None) + if image_encoder is None or image_processor is None: + return None + + device = getattr(self.accelerator, "device", torch.device("cpu")) + weight_dtype = getattr(self.config, "weight_dtype", None) + return PipelineConditioningImageEmbedder( + pipeline=pipeline, + image_encoder=image_encoder, + image_processor=image_processor, + device=device, + weight_dtype=weight_dtype, + ) + + def setup_model_flavour(self): + super().setup_model_flavour() + flavour = getattr(self.config, "model_flavour", None) or "" + self._wan_expand_timesteps = flavour in self.EXPAND_TIMESTEP_FLAVOURS + setattr(self.config, "wan_expand_timesteps", self._wan_expand_timesteps) + stage_info = self._wan_stage_info() + if stage_info is None: + return + + if getattr(self.config, "pretrained_transformer_model_name_or_path", None) is None: + self.config.pretrained_transformer_model_name_or_path = self.config.pretrained_model_name_or_path + self.config.pretrained_transformer_subfolder = stage_info["stage_subfolder"] + + self.config.wan_trained_stage = stage_info["trained_stage"] + self.config.wan_stage_main_subfolder = stage_info["stage_subfolder"] + self.config.wan_stage_other_subfolder = stage_info["other_stage_subfolder"] + self.config.wan_boundary_ratio = stage_info["boundary_ratio"] + + self.config.flow_schedule_shift = stage_info["flow_shift"] + self.config.validation_num_inference_steps = stage_info["sample_steps"] + self.config.validation_guidance = stage_info["guidance"][stage_info["trained_stage"]] + + if not hasattr(self.config, "wan_validation_load_other_stage"): + self.config.wan_validation_load_other_stage = False + + def _wan_stage_info(self) -> Optional[Dict[str, object]]: + flavour = getattr(self.config, "model_flavour", None) + return self.WAN_STAGE_OVERRIDES.get(flavour) + + def _apply_time_embedding_override(self, transformer: Optional[WanTransformer3DModel]) -> None: + if transformer is None: + return + target = 
self.unwrap_model(transformer) + setter = getattr(target, "set_time_embedding_v2_1", None) + if callable(setter): + setter(bool(getattr(self.config, "wan_force_2_1_time_embedding", False))) + + def _patch_condition_embedder(self, transformer: Optional[WanTransformer3DModel]) -> None: + if transformer is None: + return + target = self.unwrap_model(transformer) + embedder = getattr(target, "condition_embedder", None) + if embedder is None: + return + if getattr(embedder, "_simpletuner_time_text_patch", False): + return + embedder.forward = partial(time_text_monkeypatch, embedder) + embedder._simpletuner_time_text_patch = True + + def post_model_load_setup(self): + super().post_model_load_setup() + self._apply_time_embedding_override(getattr(self, "model", None)) + self._patch_condition_embedder(getattr(self, "model", None)) + + def _should_load_other_stage(self) -> bool: + stage_info = self._wan_stage_info() + if stage_info is None: + return False + return bool(getattr(self.config, "wan_validation_load_other_stage", False)) + + def _get_or_load_wan_stage_module(self, subfolder: str) -> WanTransformer3DModel: + if subfolder in self._wan_cached_stage_modules: + return self._wan_cached_stage_modules[subfolder] + + logger.info("Loading Wan stage weights for validation from subfolder '%s'.", subfolder) + stage = self.MODEL_CLASS.from_pretrained( + self.config.pretrained_model_name_or_path, + subfolder=subfolder, + torch_dtype=self.config.weight_dtype, + use_safetensors=True, + ) + stage.requires_grad_(False) + stage.to(self.accelerator.device, dtype=self.config.weight_dtype) + stage.eval() + self._apply_time_embedding_override(stage) + self._patch_condition_embedder(stage) + self._wan_cached_stage_modules[subfolder] = stage + return stage + + def unload_model(self): + super().unload_model() + self._wan_cached_stage_modules.clear() + + def set_prepared_model(self, model, base_model: bool = False): + super().set_prepared_model(model, base_model) + if not base_model: + self._apply_time_embedding_override(self.model) + self._patch_condition_embedder(self.model) + + def get_group_offload_components(self, pipeline): + base_components = super().get_group_offload_components(pipeline) + transformers = {} + for name in ("transformer", "transformer_2"): + module = base_components.get(name) + if module is None: + module = getattr(pipeline, name, None) + if isinstance(module, torch.nn.Module): + transformers[name] = module + return transformers + + def get_pipeline(self, pipeline_type: str = PipelineTypes.TEXT2IMG, load_base_model: bool = True): + pipeline = super().get_pipeline(pipeline_type, load_base_model) + if hasattr(pipeline, "config"): + pipeline.config.expand_timesteps = bool(self._wan_expand_timesteps) + stage_info = self._wan_stage_info() + if stage_info is not None: + load_other = self._should_load_other_stage() + trained_stage = stage_info["trained_stage"] + other_subfolder = stage_info["other_stage_subfolder"] + + if trained_stage == "low": + if load_other: + pipeline.transformer_2 = pipeline.transformer + pipeline.transformer = self._get_or_load_wan_stage_module(other_subfolder) + else: + pipeline.transformer_2 = None + else: + if load_other: + pipeline.transformer_2 = self._get_or_load_wan_stage_module(other_subfolder) + else: + pipeline.transformer_2 = None + + if load_other: + pipeline.config.boundary_ratio = stage_info["boundary_ratio"] + else: + pipeline.config.boundary_ratio = None + + transformer_primary = getattr(pipeline, "transformer", None) + 
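# Re-apply the time-embedding override and the condition-embedder patch to every loaded stage before returning the pipeline. + 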
self._apply_time_embedding_override(transformer_primary) + self._patch_condition_embedder(transformer_primary) + if getattr(pipeline, "transformer_2", None) is not None: + self._apply_time_embedding_override(pipeline.transformer_2) + self._patch_condition_embedder(pipeline.transformer_2) + + if hasattr(pipeline, "config"): + pipeline.config.expand_timesteps = bool(self._wan_expand_timesteps) + + return pipeline + + class _ConditioningImageEmbedder: + def __init__(self, image_encoder: CLIPVisionModel, image_processor: CLIPImageProcessor, device, dtype): + self.image_encoder = image_encoder + self.image_processor = image_processor + self.device = device + self.dtype = dtype + self.image_encoder.eval() + self.image_encoder.to(device=self.device, dtype=self.dtype) + for param in self.image_encoder.parameters(): + param.requires_grad_(False) + + @torch.no_grad() + def encode(self, images): + processed = self.image_processor(images=images, return_tensors="pt") + pixel_values = processed["pixel_values"].to(device=self.device, dtype=self.dtype) + outputs = self.image_encoder(pixel_values=pixel_values, output_hidden_states=True) + hidden = outputs.hidden_states[-2] + return [hidden[i] for i in range(hidden.shape[0])] + + def _load_conditioning_clip_components(self, pipeline): + image_encoder = getattr(pipeline, "image_encoder", None) + image_processor = getattr(pipeline, "image_processor", None) + + if image_encoder is not None and image_processor is not None: + return image_encoder, image_processor + + repo_id = getattr(self.config, "image_encoder_pretrained_model_name_or_path", None) + processor_repo_id = getattr(self.config, "image_processor_pretrained_model_name_or_path", None) + + if repo_id is None: + repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" + if processor_repo_id is None: + processor_repo_id = repo_id + if processor_repo_id is None: + processor_repo_id = repo_id + + def build_candidates(user_value, defaults): + candidates = [] + if isinstance(user_value, (list, tuple, set)): + candidates.extend([v for v in user_value if v]) + elif user_value: + candidates.append(user_value) + candidates.extend(defaults) + seen = set() + ordered = [] + for entry in candidates: + if entry in seen: + continue + seen.add(entry) + ordered.append(entry) + return ordered + + encoder_subfolders = build_candidates( + getattr(self.config, "image_encoder_subfolder", None), + ["image_encoder", "vision_encoder", None], + ) + processor_subfolders = build_candidates( + getattr(self.config, "image_processor_subfolder", None), + ["image_processor", "feature_extractor", None], + ) + + encoder_errors = [] + for subfolder in encoder_subfolders: + try: + kwargs = {"use_safetensors": True} + if subfolder is not None: + kwargs["subfolder"] = subfolder + image_encoder = CLIPVisionModel.from_pretrained(repo_id, **kwargs) + break + except Exception as exc: # pragma: no cover - defensive + encoder_errors.append(f"{repo_id}/{subfolder or '.'}: {exc}") + + if image_encoder is None: + raise ValueError( + "Unable to load a CLIP vision encoder for conditioning image embeddings. " + "Set `image_encoder_pretrained_model_name_or_path` (and optionally `image_encoder_subfolder`) to a " + "compatible repository. 
Attempts failed with: " + "; ".join(encoder_errors) + ) + + processor_errors = [] + for subfolder in processor_subfolders: + try: + kwargs = {} + if subfolder is not None: + kwargs["subfolder"] = subfolder + image_processor = CLIPImageProcessor.from_pretrained(processor_repo_id, **kwargs) + break + except Exception as exc: # pragma: no cover - defensive + processor_errors.append(f"{processor_repo_id}/{subfolder or '.'}: {exc}") + + if image_processor is None: + raise ValueError( + "Unable to load a CLIP image processor for conditioning image embeddings. " + "Set `image_processor_pretrained_model_name_or_path` (and optionally `image_processor_subfolder`). " + "Attempts failed with: " + "; ".join(processor_errors) + ) + + if pipeline is not None: + pipeline.image_encoder = image_encoder + pipeline.image_processor = image_processor + + return image_encoder, image_processor + + def get_conditioning_image_embedder(self): + if self._conditioning_image_embedder is not None: + return self._conditioning_image_embedder + + pipeline = self.get_pipeline(PipelineTypes.IMG2VIDEO) + image_encoder, image_processor = self._load_conditioning_clip_components(pipeline) + + device = getattr(self.accelerator, "device", torch.device("cpu")) + dtype = getattr(self.config, "weight_dtype", torch.float32) + if isinstance(dtype, str): + dtype = getattr(torch, dtype, torch.float32) + + self._conditioning_image_embedder = self._ConditioningImageEmbedder( + image_encoder=image_encoder, + image_processor=image_processor, + device=device, + dtype=dtype, + ) + return self._conditioning_image_embedder + def tread_init(self): """ Initialize the TREAD model training method for Wan. @@ -97,7 +908,28 @@ def update_pipeline_call_kwargs(self, pipeline_kwargs): # Wan video should max out around 81 frames for efficiency. 
pipeline_kwargs["num_frames"] = min(81, self.config.validation_num_video_frames or 81) pipeline_kwargs["output_type"] = "pil" - # replace embeds with prompt + + input_image = pipeline_kwargs.get("image") + if isinstance(input_image, list): + if len(input_image) > 0: + pipeline_kwargs["image"] = input_image[0] + if self._uses_last_frame_conditioning() and len(input_image) > 1: + pipeline_kwargs["last_image"] = input_image[-1] + elif self._uses_last_frame_conditioning() and input_image is not None and "last_image" not in pipeline_kwargs: + pipeline_kwargs["last_image"] = input_image + + stage_info = self._wan_stage_info() + if stage_info is not None: + trained_stage = stage_info["trained_stage"] + pipeline_kwargs["num_inference_steps"] = stage_info["sample_steps"] + pipeline_kwargs["guidance_scale"] = stage_info["guidance"][trained_stage] + if self._should_load_other_stage(): + other_stage = "low" if trained_stage == "high" else "high" + pipeline_kwargs["guidance_scale_2"] = stage_info["guidance"][other_stage] + else: + pipeline_kwargs.pop("guidance_scale_2", None) + else: + pipeline_kwargs.pop("guidance_scale_2", None) return pipeline_kwargs @@ -172,6 +1004,13 @@ def model_predict(self, prepared_batch): "return_dict": False, } + if prepared_batch.get("conditioning_image_embeds") is not None: + wan_transformer_kwargs["encoder_hidden_states_image"] = prepared_batch["conditioning_image_embeds"].to( + self.config.weight_dtype + ) + + self._apply_i2v_conditioning_to_kwargs(prepared_batch, wan_transformer_kwargs) + # For masking with TREAD, avoid dropping any tokens that are in the mask if ( getattr(self.config, "tread_config", None) is not None @@ -204,7 +1043,10 @@ def model_predict(self, prepared_batch): # After transpose(1,2): (B, T'*H'*W', D) # So we flatten the mask with the same T->H->W order force_keep = mask_tok.squeeze(1).flatten(1) > 0.5 # (B, S_vid) - wan_transformer_kwargs["force_keep_mask"] = force_keep + existing_force_keep = wan_transformer_kwargs.get("force_keep_mask") + wan_transformer_kwargs["force_keep_mask"] = ( + force_keep if existing_force_keep is None else (existing_force_keep | force_keep) + ) model_pred = self.model(**wan_transformer_kwargs)[0] @@ -216,6 +1058,19 @@ def check_user_config(self): """ Checks self.config values against important issues. """ + stage_info = self._wan_stage_info() + if stage_info is not None: + trained_stage = stage_info["trained_stage"] + self.config.validation_guidance = stage_info["guidance"][trained_stage] + if hasattr(self.config, "validation_guidance_skip_layers"): + self.config.validation_guidance_skip_layers = None + if hasattr(self.config, "validation_guidance_skip_layers_start"): + self.config.validation_guidance_skip_layers_start = None + if hasattr(self.config, "validation_guidance_skip_layers_stop"): + self.config.validation_guidance_skip_layers_stop = None + if hasattr(self.config, "validation_guidance_skip_scale"): + self.config.validation_guidance_skip_scale = None + if self.config.base_model_precision == "fp8-quanto": raise ValueError( f"{self.NAME} does not support fp8-quanto. Please use fp8-torchao or int8 precision level instead." diff --git a/simpletuner/helpers/models/wan/pipeline.py b/simpletuner/helpers/models/wan/pipeline.py index 4624fd95f..9db8f325e 100644 --- a/simpletuner/helpers/models/wan/pipeline.py +++ b/simpletuner/helpers/models/wan/pipeline.py @@ -13,12 +13,13 @@ # limitations under the License. 
import html -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import ftfy import regex as re import torch from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput from diffusers.loaders import WanLoraLoaderMixin from diffusers.models import AutoencoderKLWan, WanTransformer3DModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline @@ -27,7 +28,8 @@ from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from transformers import AutoTokenizer, UMT5EncoderModel +from PIL import Image +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -51,8 +53,6 @@ >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) >>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) - >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) >>> pipe.to("cuda") >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." @@ -88,30 +88,25 @@ def prompt_clean(text): return text +def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" - Pipeline for text-to-video generation using Wan. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - Args: - tokenizer ([`T5Tokenizer`]): - Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), - specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - text_encoder ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - transformer ([`WanTransformer3DModel`]): - Conditional Transformer to denoise the input latents. - scheduler ([`UniPCMultistepScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + Pipeline for Wan text-to-video and image-to-video generation. 
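+ Dispatches to a dedicated text-to-video or image-to-video execution path at call time, based on the supplied image inputs and whether the loaded transformer expects image conditioning.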
""" - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"] def __init__( self, @@ -120,6 +115,11 @@ def __init__( transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, + image_processor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[CLIPVisionModel] = None, + transformer_2: Optional[WanTransformer3DModel] = None, + boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, ): super().__init__() @@ -127,13 +127,30 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + image_encoder=image_encoder, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler, + image_processor=image_processor, ) + self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps) - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor + + # ------------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------------- + def _requires_image_conditioning(self) -> bool: + """Return True if either transformer expects image conditioning.""" + for module in [getattr(self, "transformer", None), getattr(self, "transformer_2", None)]: + if module is None or not hasattr(module, "config"): + continue + if getattr(module.config, "image_dim", None) is not None: + return True + return False def _get_t5_prompt_embeds( self, @@ -143,8 +160,12 @@ def _get_t5_prompt_embeds( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): - device = device or self._execution_device - dtype = dtype or self.text_encoder.dtype + encoder_param = next(self.text_encoder.parameters(), None) + encoder_device = encoder_param.device if encoder_param is not None else torch.device("cpu") + encoder_dtype = encoder_param.dtype if encoder_param is not None else torch.float32 + + target_device = device or self._execution_device or encoder_device + target_dtype = dtype or encoder_dtype prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt_clean(u) for u in prompt] @@ -162,15 +183,16 @@ def _get_t5_prompt_embeds( text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = self.text_encoder( + text_input_ids.to(encoder_device), + mask.to(encoder_device), + ).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=target_dtype, device=target_device) prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - 
u.size(0), u.size(1))]) for u in prompt_embeds], - dim=0, + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 ) - # duplicate text embeddings for each generation per prompt, using mps friendly method _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) @@ -189,32 +211,6 @@ def encode_prompt( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos that should be generated per prompt. torch device to place the resulting embeddings on - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - """ device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt @@ -258,15 +254,31 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - def check_inputs( + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + if self.image_processor is None or self.image_encoder is None: + raise ValueError("Image processor and image encoder must be provided for image-to-video generation.") + + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # ------------------------------------------------------------------------- + # Input validation + # ------------------------------------------------------------------------- + def _check_t2v_inputs( self, prompt, negative_prompt, height, width, - prompt_embeds=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -299,17 +311,98 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - def prepare_latents( + def _check_i2v_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + image_embeds, + ): + if image is not None and 
image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + + self._check_t2v_inputs( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if self.config.boundary_ratio is not None and image_embeds is not None: + raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.") + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2: Optional[float] = None, + image_embeds: Optional[torch.Tensor] = None, + ): + if self._requires_image_conditioning() or image is not None or image_embeds is not None: + self._check_i2v_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + image_embeds, + ) + else: + self._check_t2v_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + # ------------------------------------------------------------------------- + # Latent preparation + # ------------------------------------------------------------------------- + def _prepare_t2v_latents( self, batch_size: int, - num_channels_latents: int = 16, - height: int = 480, - width: int = 832, - num_frames: int = 81, - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + dtype: Optional[torch.dtype], + device: Optional[torch.device], + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + latents: Optional[torch.Tensor], ) -> torch.Tensor: if latents is not None: return latents.to(device=device, dtype=dtype) @@ -331,6 +424,95 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents + def _prepare_i2v_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + num_frames: int, + dtype: Optional[torch.dtype], + device: Optional[torch.device], + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + latents: Optional[torch.Tensor], + last_image: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // 
self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + + if self.config.expand_timesteps: + video_condition = image + elif last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], + dim=2, + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + if self.config.expand_timesteps: + first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device) + first_frame_mask[:, :, 0] = 0 + return latents, latent_condition, first_frame_mask + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1), None + + # ------------------------------------------------------------------------- + # Execution + # ------------------------------------------------------------------------- @property def guidance_scale(self): return self._guidance_scale @@ -355,130 +537,47 @@ def interrupt(self): def attention_kwargs(self): return self._attention_kwargs - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( + def _call_text_to_video( self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - 
num_inference_steps: int = 50, - guidance_scale: float = 5.0, - skip_guidance_layers: List[int] = None, - skip_layer_guidance_start: float = 0.1, - skip_layer_guidance_stop: float = 0.3, - skip_layer_guidance_scale: float = 1.0, - num_videos_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - output_type: Optional[str] = "np", - return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[ - Union[ - Callable[[int, int, Dict], None], - PipelineCallback, - MultiPipelineCallbacks, - ] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + guidance_scale, + skip_guidance_layers, + skip_layer_guidance_start, + skip_layer_guidance_stop, + skip_layer_guidance_scale, + num_videos_per_prompt, + generator, + latents, + prompt_embeds, + negative_prompt_embeds, + output_type, + return_dict, + attention_kwargs, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + max_sequence_length, ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - height (`int`, defaults to `480`): - The height in pixels of the generated image. - width (`int`, defaults to `832`): - The width in pixels of the generated image. - num_frames (`int`, defaults to `81`): - The number of frames in the generated video. - num_inference_steps (`int`, defaults to `50`): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, defaults to `5.0`): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - skip_guidance_layers (`List[int]`, *optional*, defaults to `None`): - The indices of the layers to skip in the transformer. The indices are 0-indexed. - skip_layer_guidance_stop (`float`, *optional*, defaults to `0.1`): - The fraction of the total number of inference steps at which to start the skip connection. - skip_layer_guidance_start (`float`, *optional*, defaults to `0.3`): - The fraction of the total number of inference steps at which to end the skip connection. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. 
- prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): - A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of - each denoising step during the inference. with the following arguments: `callback_on_step_end(self: - DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a - list of all tensors as specified by `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. - autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. - - Examples: - - Returns: - [`~WanPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images and the second element is a list of `bool`s - indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. - """ - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, negative_prompt, - height, - width, - prompt_embeds, - negative_prompt_embeds, - callback_on_step_end_tensor_inputs, + image=None, + height=height, + width=width, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) - self._guidance_scale = guidance_scale - self._attention_kwargs = attention_kwargs - self._current_timestep = None - self._interrupt = False - device = self._execution_device - # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -486,11 +585,10 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # 3. 
Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, + do_classifier_free_guidance=guidance_scale > 1.0, num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, @@ -503,13 +601,11 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( + num_channels_latents = self.vae.config.z_dim + latents = self._prepare_t2v_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -521,16 +617,22 @@ def __call__( latents, ) - # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # expand t to match the batch size + if self.interrupt: + continue + timestep = t.expand(latents.shape[0]) + self._current_timestep = t - # Unconditional pass (possibly skip-layers) if self.do_classifier_free_guidance: fraction = i / float(num_inference_steps) skip_layer_indices = skip_guidance_layers @@ -548,7 +650,6 @@ def __call__( else: noise_pred_uncond = None - # Positive pass (no skip-layers) noise_pred_text = self.transformer( hidden_states=latents.to(transformer_dtype), timestep=timestep, @@ -557,15 +658,11 @@ def __call__( return_dict=False, )[0] - # Combine for CFG if self.do_classifier_free_guidance: - # noise_pred = uncond + w * (text - uncond) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) else: - # if no CFG, just use text pass noise_pred = noise_pred_text - # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if callback_on_step_end is not None: @@ -578,7 +675,6 @@ def __call__( prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -587,7 +683,233 @@ def __call__( self._current_timestep = None - if not output_type == "latent": + if output_type != "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) + + def _call_image_to_video( + self, + image: 
PipelineImageInput, + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + guidance_scale, + guidance_scale_2, + num_videos_per_prompt, + generator, + latents, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + last_image, + output_type, + return_dict, + attention_kwargs, + callback_on_step_end, + callback_on_step_end_tensor_inputs, + max_sequence_length, + ): + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + image_embeds, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=True, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_for_dtype = self.transformer or self.transformer_2 + transformer_dtype = transformer_for_dtype.dtype if transformer_for_dtype is not None else torch.float32 + + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if self.transformer is not None and getattr(self.transformer.config, "image_dim", None) is not None: + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_channels_latents = self.vae.config.z_dim + image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + + latents_outputs = self._prepare_i2v_latents( + image_tensor, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + last_image, + ) + if self.config.expand_timesteps: + latents, condition, first_frame_mask = latents_outputs + else: + latents, condition, first_frame_mask = latents_outputs[0], latents_outputs[1], None + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.config.boundary_ratio is not None and self.transformer is not None and 
self.transformer_2 is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if self.config.expand_timesteps: + latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(transformer_dtype) + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + use_low_stage = boundary_timestep is not None and t < boundary_timestep and self.transformer_2 is not None + current_model = None + current_guidance_scale = guidance_scale + + if use_low_stage: + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 if guidance_scale_2 is not None else guidance_scale + else: + current_model = self.transformer or self.transformer_2 + + if current_model is None: + raise ValueError("No transformer available to process the current timestep.") + + with current_model.cache_context("cond"): + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if self.config.expand_timesteps: + latents = (1 - first_frame_mask) * condition + first_frame_mask * latents + + if output_type != "latent": latents = latents.to(self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -603,10 +925,106 @@ def __call__( else: video = latents - # Offload all models self.maybe_free_model_hooks() if not return_dict: return (video,) return WanPipelineOutput(frames=video) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + image: Optional[PipelineImageInput] = None, + 
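# Supplying `image` or `image_embeds` switches this call onto the image-to-video branch below. + 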
height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, + skip_guidance_layers: Optional[List[int]] = None, + skip_layer_guidance_start: float = 0.1, + skip_layer_guidance_stop: float = 0.3, + skip_layer_guidance_scale: float = 1.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[ + Callable[[int, int, Dict], None], + PipelineCallback, + MultiPipelineCallbacks, + ] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + """ + Generate a video sequence from text prompts and optional conditioning inputs. + + Examples: + Placeholder for dynamically injected example. + """ + use_i2v = self._requires_image_conditioning() or image is not None or image_embeds is not None + + if use_i2v: + return self._call_image_to_video( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + guidance_scale_2=guidance_scale_2, + num_videos_per_prompt=num_videos_per_prompt, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + last_image=last_image, + output_type=output_type, + return_dict=return_dict, + attention_kwargs=attention_kwargs, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + skip_guidance_layers = skip_guidance_layers or [] + return self._call_text_to_video( + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + skip_guidance_layers=skip_guidance_layers, + skip_layer_guidance_start=skip_layer_guidance_start, + skip_layer_guidance_stop=skip_layer_guidance_stop, + skip_layer_guidance_scale=skip_layer_guidance_scale, + num_videos_per_prompt=num_videos_per_prompt, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + output_type=output_type, + return_dict=return_dict, + attention_kwargs=attention_kwargs, + callback_on_step_end=callback_on_step_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) diff --git a/simpletuner/helpers/models/wan/transformer.py b/simpletuner/helpers/models/wan/transformer.py index 23fbd6153..766fd2c9c 100644 --- a/simpletuner/helpers/models/wan/transformer.py +++ b/simpletuner/helpers/models/wan/transformer.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.configuration_utils import ConfigMixin from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from 
diffusers.models.attention import FeedForward from diffusers.models.attention_processor import Attention @@ -377,7 +377,6 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi ] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] - @register_to_config def __init__( self, patch_size: Tuple[int] = (1, 2, 2), @@ -397,9 +396,27 @@ def __init__( rope_max_seq_len: int = 1024, ) -> None: super().__init__() + effective_out_channels = out_channels or in_channels + self.register_to_config( + patch_size=patch_size, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=effective_out_channels, + text_dim=text_dim, + freq_dim=freq_dim, + ffn_dim=ffn_dim, + num_layers=num_layers, + cross_attn_norm=cross_attn_norm, + qk_norm=qk_norm, + eps=eps, + image_dim=image_dim, + added_kv_proj_dim=added_kv_proj_dim, + rope_max_seq_len=rope_max_seq_len, + ) inner_dim = num_attention_heads * attention_head_dim - out_channels = out_channels or in_channels + out_channels = effective_out_channels # 1. Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -437,6 +454,18 @@ def __init__( self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False + self.force_v2_1_time_embedding: bool = False + + def set_time_embedding_v2_1(self, force_2_1_time_embedding: bool) -> None: + """ + Force the Wan transformer to use 2.1-style time embeddings even when running Wan 2.2 checkpoints. + + Args: + force_2_1_time_embedding: Whether to override the default time embedding behaviour. + """ + self.force_v2_1_time_embedding = bool(force_2_1_time_embedding) + if self.force_v2_1_time_embedding: + logger.info("WanTransformer3DModel: Forcing Wan 2.1 style time embedding.") def set_router(self, router: TREADRouter, routes: List[Dict[str, Any]]): """Set the TREAD router and routing configuration.""" @@ -502,6 +531,11 @@ def forward( hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) + if self.force_v2_1_time_embedding and timestep.dim() > 1: + # Wan 2.1 uses a single timestep per batch entry. When forcing 2.1 behaviour with Wan 2.2 + # checkpoints we fall back to the first timestep value which matches the reference implementation. + timestep = timestep[..., 0].contiguous() + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) diff --git a/simpletuner/helpers/multiaspect/dataset.py b/simpletuner/helpers/multiaspect/dataset.py index 232b9bb6e..e2edc26ab 100644 --- a/simpletuner/helpers/multiaspect/dataset.py +++ b/simpletuner/helpers/multiaspect/dataset.py @@ -43,6 +43,12 @@ def __len__(self): return sum([len(dataset) for dataset in self.datasets]) def __getitem__(self, image_tuple: list[dict[str, Any] | TrainingSample]): + state_args = StateTracker.get_args() + model_family = "" + if state_args is not None: + model_family = getattr(state_args, "model_family", "") or "" + model_family = str(model_family) + output_data = { "training_samples": [], "conditioning_samples": [], @@ -67,7 +73,7 @@ def __getitem__(self, image_tuple: list[dict[str, Any] | TrainingSample]): f"Aspect ratios must be the same for all images in a batch. 
Expected: {first_aspect_ratio}, got: {calculated_aspect_ratio}" ) - if "deepfloyd" not in StateTracker.get_args().model_family and ( + if "deepfloyd" not in model_family and ( image_metadata["original_size"] is None or image_metadata["target_size"] is None ): raise Exception( diff --git a/simpletuner/helpers/multiaspect/image.py b/simpletuner/helpers/multiaspect/image.py index 35e692ae7..9d2bc6903 100644 --- a/simpletuner/helpers/multiaspect/image.py +++ b/simpletuner/helpers/multiaspect/image.py @@ -13,19 +13,26 @@ logger = logging.getLogger("MultiaspectImage") logger.setLevel(os.environ.get("SIMPLETUNER_IMAGE_PREP_LOG_LEVEL", "INFO")) +from numbers import Real + +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + class MultiaspectImage: @staticmethod def _coerce_positive_int(value, default: int = 1) -> int: """Return a positive integer from value or fallback to default.""" if isinstance(value, Real): - candidate = int(value) + result = int(value) else: try: - candidate = int(value) + result = int(value) except (TypeError, ValueError): return default - return candidate if candidate > 0 else default + return result if result > 0 else default @staticmethod def _get_alignment(default: int = 1) -> int: @@ -37,10 +44,16 @@ def _get_alignment(default: int = 1) -> int: def _get_rounding(default: int) -> int: args = StateTracker.get_args() rounding = getattr(args, "aspect_bucket_rounding", None) if args is not None else None - if rounding is None: + if isinstance(rounding, Real): + rounding = int(rounding) + else: + try: + rounding = int(rounding) + except (TypeError, ValueError): + rounding = None + if rounding is None or rounding < 0: return default - rounding = MultiaspectImage._coerce_positive_int(rounding, default) - return rounding if rounding >= 0 else default + return rounding @staticmethod def limit_canvas_size(width: int, height: int, max_size: int) -> dict: @@ -92,7 +105,7 @@ def _round_to_nearest_multiple(value, override_value: int = None): multiple = MultiaspectImage._get_alignment() else: multiple = MultiaspectImage._coerce_positive_int(multiple, default=1) - rounded = int(round(value / multiple) * multiple) + rounded = round(value / multiple) * multiple return max(rounded, multiple) # Ensure it's at least the value of 'multiple' @staticmethod @@ -298,9 +311,7 @@ def calculate_image_aspect_ratio(image, rounding: int = 2): Returns: float: The rounded aspect ratio of the image. """ - to_round = StateTracker.get_args().aspect_bucket_rounding - if to_round is None: - to_round = rounding + to_round = MultiaspectImage._get_rounding(rounding) if isinstance(image, Image.Image): # An actual image was passed in. 
width, height = image.size diff --git a/simpletuner/helpers/training/collate.py b/simpletuner/helpers/training/collate.py index 6885e5b70..1340f4843 100644 --- a/simpletuner/helpers/training/collate.py +++ b/simpletuner/helpers/training/collate.py @@ -405,6 +405,8 @@ def collate_fn(batch): batch_luminance = sum(batch_luminance) / len(batch_luminance) debug_log("Extract filepaths") filepaths = extract_filepaths(examples) + data_backend = StateTracker.get_data_backend(data_backend_id) + debug_log("Compute latents") model = StateTracker.get_model() batch_data = compute_latents(filepaths, data_backend_id, model) @@ -416,13 +418,28 @@ def collate_fn(batch): debug_log("Check latents") latent_batch = check_latent_shapes(latent_batch, filepaths, data_backend_id, examples) + conditioning_image_embeds = None + if model.requires_conditioning_image_embeds(): + cache = data_backend.get("conditioning_image_embed_cache") + if cache is None: + raise ValueError("Conditioning image embed cache is required but was not configured.") + embed_tensors = [] + for path in filepaths: + embed_tensor = cache.retrieve_from_cache(path) + if isinstance(embed_tensor, dict): + raise ValueError("Conditioning image embed cache returned an unexpected structure.") + if not torch.backends.mps.is_available(): + embed_tensor = embed_tensor.to("cpu").pin_memory() + embed_tensors.append(embed_tensor) + if embed_tensors: + conditioning_image_embeds = torch.stack(embed_tensors, dim=0) + training_filepaths = [] conditioning_type = None conditioning_pixel_values = None conditioning_latents = None # get multiple backend ids - data_backend = StateTracker.get_data_backend(data_backend_id) conditioning_backends = data_backend.get("conditioning_data", []) if len(conditioning_examples) > 0: # check the # of conditioning backends @@ -553,6 +570,7 @@ def collate_fn(batch): "batch_luminance": batch_luminance, "conditioning_pixel_values": conditioning_pixel_values, "conditioning_latents": conditioning_latents, + "conditioning_image_embeds": conditioning_image_embeds, "encoder_attention_mask": all_text_encoder_outputs.get("attention_masks"), "is_regularisation_data": is_regularisation_data, "is_i2v_data": is_i2v_data, diff --git a/simpletuner/helpers/training/state_tracker.py b/simpletuner/helpers/training/state_tracker.py index 9e4056ace..e9ab683b6 100644 --- a/simpletuner/helpers/training/state_tracker.py +++ b/simpletuner/helpers/training/state_tracker.py @@ -20,6 +20,7 @@ "all_image_files": "image", "all_vae_cache_files": "vae", "all_text_cache_files": "text", + "all_conditioning_image_embed_files": "conditioning_image_embeds", } @@ -42,6 +43,7 @@ class StateTracker: all_image_files = {} all_vae_cache_files = {} all_text_cache_files = {} + all_conditioning_image_embed_files = {} all_caption_files = None ## Backend entities for retrieval @@ -77,6 +79,7 @@ def delete_cache_files(cls, data_backend_id: str = None, preserve_data_backend_c "all_image_files", "all_vae_cache_files", "all_text_cache_files", + "all_conditioning_image_embed_files", ]: if filename_mapping[cache_name] in str(preserve_data_backend_cache): continue @@ -372,6 +375,36 @@ def get_vae_cache_files(cls: list, data_backend_id: str, retry_limit: int = 0): cls.all_vae_cache_files[data_backend_id] = cls._load_from_disk("all_vae_cache_files_{}".format(data_backend_id)) return cls.all_vae_cache_files[data_backend_id] or {} + @classmethod + def set_conditioning_image_embed_files(cls, raw_file_list: list, data_backend_id: str): + raw_file_list = raw_file_list or [] + if 
cls.all_conditioning_image_embed_files.get(data_backend_id) is not None: + cls.all_conditioning_image_embed_files[data_backend_id].clear() + else: + cls.all_conditioning_image_embed_files[data_backend_id] = {} + for subdirectory_list in raw_file_list: + if isinstance(subdirectory_list, (list, tuple)) and len(subdirectory_list) == 3: + _, _, files = subdirectory_list + else: + files = subdirectory_list if isinstance(subdirectory_list, list) else [subdirectory_list] + for embed_path in files: + cls.all_conditioning_image_embed_files[data_backend_id][embed_path] = False + cls._save_to_disk( + f"all_conditioning_image_embed_files_{data_backend_id}", + cls.all_conditioning_image_embed_files[data_backend_id], + ) + + @classmethod + def get_conditioning_image_embed_files(cls, data_backend_id: str, retry_limit: int = 0): + if ( + data_backend_id not in cls.all_conditioning_image_embed_files + or cls.all_conditioning_image_embed_files.get(data_backend_id) is None + ): + cls.all_conditioning_image_embed_files[data_backend_id] = cls._load_from_disk( + f"all_conditioning_image_embed_files_{data_backend_id}", retry_limit=retry_limit + ) + return cls.all_conditioning_image_embed_files[data_backend_id] or {} + @classmethod def set_text_cache_files(cls, raw_file_list: list, data_backend_id: str): if cls.all_text_cache_files[data_backend_id] is not None: diff --git a/simpletuner/helpers/training/trainer.py b/simpletuner/helpers/training/trainer.py index 24a5d7159..a3fdc1a20 100644 --- a/simpletuner/helpers/training/trainer.py +++ b/simpletuner/helpers/training/trainer.py @@ -18,8 +18,8 @@ from typing import Dict, List, Optional import huggingface_hub - import wandb + from simpletuner.helpers import log_format # noqa from simpletuner.helpers.caching.memory import reclaim_memory from simpletuner.helpers.configuration.cli_utils import mapping_to_cli_args @@ -238,9 +238,7 @@ def parse_arguments(self, args=None, disable_accelerator: bool = False, exit_on_ if isinstance(args_payload, dict): skip_config_fallback = bool(args_payload.pop("__skip_config_fallback__", False)) # Strip any internal metadata entries that shouldn't be forwarded to the CLI parser. 
- metadata_keys = [ - key for key in list(args_payload.keys()) if isinstance(key, str) and key.startswith("__") - ] + metadata_keys = [key for key in list(args_payload.keys()) if isinstance(key, str) and key.startswith("__")] for key in metadata_keys: args_payload.pop(key, None) @@ -474,6 +472,9 @@ def _coerce_flag(value: object) -> bool: self.accelerator = Accelerator(**accelerator_kwargs) else: raise + if self.accelerator: + os.environ["RANK"] = str(self.accelerator.process_index) + os.environ["WORLD_SIZE"] = str(self.accelerator.num_processes) self._setup_accelerator_barrier_guard() fsdp_active = False if self.accelerator and hasattr(self.accelerator, "state"): @@ -3686,6 +3687,9 @@ def _launch_with_accelerate() -> Optional[int]: cli_args: list[str] = [] if isinstance(config_payload, dict): train_cli_payload = dict(config_payload) + metadata_keys = [key for key in list(train_cli_payload.keys()) if isinstance(key, str) and key.startswith("__")] + for key in metadata_keys: + train_cli_payload.pop(key, None) for accel_key in { "accelerate_config", "--accelerate_config", diff --git a/simpletuner/helpers/training/validation.py b/simpletuner/helpers/training/validation.py index 8721c17d1..6ee216922 100644 --- a/simpletuner/helpers/training/validation.py +++ b/simpletuner/helpers/training/validation.py @@ -7,9 +7,9 @@ import diffusers import numpy as np import torch +import wandb from tqdm import tqdm -import wandb from simpletuner.helpers.models.common import ImageModelFoundation, ModelFoundation, VideoModelFoundation from simpletuner.helpers.training.wrappers import unwrap_model @@ -373,7 +373,7 @@ def prepare_validation_prompt_list(args, embed_cache, model): embed_cache.compute_embeddings_for_prompts([args.validation_prompt], is_validation=True, load_from_cache=False) # Compute negative embed for validation prompts, if any are set, so that it's stored before we unload the text encoder. if validation_prompts: - logger.info("Precomputing the negative prompt embed for validations.") + logger.info(f"Precomputing the negative prompt embed for validations. 
Prompts: {validation_prompts}") model.log_model_devices() validation_negative_prompt_text_encoder_output = embed_cache.compute_embeddings_for_prompts( [StateTracker.get_args().validation_negative_prompt], diff --git a/simpletuner/helpers/utils/offloading.py b/simpletuner/helpers/utils/offloading.py new file mode 100644 index 000000000..908c6c290 --- /dev/null +++ b/simpletuner/helpers/utils/offloading.py @@ -0,0 +1,105 @@ +import os +from typing import Dict, Iterable, Optional + +import torch + +try: + from diffusers.hooks import apply_group_offloading + from diffusers.hooks.group_offloading import _is_group_offload_enabled + + _DIFFUSERS_GROUP_OFFLOAD_AVAILABLE = True +except ImportError: # pragma: no cover - handled by runtime checks + apply_group_offloading = None # type: ignore[assignment] + _is_group_offload_enabled = None # type: ignore[assignment] + _DIFFUSERS_GROUP_OFFLOAD_AVAILABLE = False + + +def enable_group_offload_on_components( + components: Dict[str, torch.nn.Module], + *, + device: torch.device, + offload_type: str = "block_level", + number_blocks_per_group: Optional[int] = 1, + use_stream: bool = False, + record_stream: bool = False, + low_cpu_mem_usage: bool = False, + non_blocking: bool = False, + offload_to_disk_path: Optional[str] = None, + exclude: Optional[Iterable[str]] = None, + required_import_error_message: str = "Group offloading requires diffusers>=0.33.0", +) -> None: + """ + Apply diffusers group offloading to a set of pipeline components. + + Parameters + ---------- + components: + Dictionary of pipeline components (module name -> instance). + device: + Target device for on-loading modules (typically the accelerator device). + offload_type: + "block_level" (default) or "leaf_level". + number_blocks_per_group: + Number of blocks per group when using block-level offloading. + use_stream: + Whether to use CUDA streams for asynchronous transfers. + record_stream / low_cpu_mem_usage / non_blocking: + Additional flags routed to diffusers group offloading helpers. + offload_to_disk_path: + Optional directory to spill parameters to disk. + exclude: + Optional iterable of component names to skip (defaults to ["vae", "vqvae"]). + required_import_error_message: + Custom message if diffusers does not expose group offloading utilities. 
+ """ + + if not _DIFFUSERS_GROUP_OFFLOAD_AVAILABLE: + raise ImportError(required_import_error_message) + + onload_device = torch.device(device) + offload_device = torch.device("cpu") + + if offload_to_disk_path: + os.makedirs(offload_to_disk_path, exist_ok=True) + + excluded_names = set(exclude or []) + if "vae" not in excluded_names: + excluded_names.add("vae") + if "vqvae" not in excluded_names: + excluded_names.add("vqvae") + + for name, module in components.items(): + if name in excluded_names: + continue + + if module is None or not isinstance(module, torch.nn.Module): + continue + + if _is_group_offload_enabled(module): # type: ignore[operator] + continue + + kwargs = { + "offload_type": offload_type, + "use_stream": use_stream, + "record_stream": record_stream, + "low_cpu_mem_usage": low_cpu_mem_usage, + "non_blocking": non_blocking, + "offload_to_disk_path": offload_to_disk_path, + } + + if offload_type == "block_level" and number_blocks_per_group is not None: + kwargs["num_blocks_per_group"] = number_blocks_per_group + + if hasattr(module, "enable_group_offload"): + module.enable_group_offload( # type: ignore[call-arg] + onload_device=onload_device, + offload_device=offload_device, + **kwargs, + ) + else: + apply_group_offloading( # type: ignore[misc] + module=module, + onload_device=onload_device, + offload_device=offload_device, + **kwargs, + ) diff --git a/simpletuner/service_worker.py b/simpletuner/service_worker.py index 4105a020c..eb247f0c0 100644 --- a/simpletuner/service_worker.py +++ b/simpletuner/service_worker.py @@ -18,6 +18,7 @@ from simpletuner.simpletuner_sdk.server.routes.datasets import router as dataset_router from simpletuner.simpletuner_sdk.server.routes.publishing import router as publishing_router from simpletuner.simpletuner_sdk.server.routes.web import router as web_router +from simpletuner.simpletuner_sdk.server.utils.paths import get_config_directory, get_static_directory, get_template_directory from simpletuner.simpletuner_sdk.training_host import TrainingHost @@ -54,9 +55,13 @@ class CancelRequest(BaseModel): allow_headers=["*"], ) -# Mount static files -if os.path.exists("static"): - app.mount("/static", StaticFiles(directory="static"), name="static") +# Mount static files from the package location +static_dir = get_static_directory() +if static_dir.exists(): + app.mount("/static", StaticFiles(directory=str(static_dir)), name="static") + +# Ensure templates resolve to the packaged directory unless overridden +os.environ.setdefault("TEMPLATE_DIR", str(get_template_directory())) # Include web interface router (uses TabService with all tabs) app.include_router(web_router) @@ -108,11 +113,9 @@ def main(): """Main entry point for the server worker.""" import uvicorn - # Create necessary directories - os.makedirs("static/css", exist_ok=True) - os.makedirs("static/js", exist_ok=True) - os.makedirs("templates", exist_ok=True) - os.makedirs("configs", exist_ok=True) + # Ensure configuration directory exists (uses configured/default path) + config_dir = get_config_directory() + os.environ.setdefault("SIMPLETUNER_CONFIG_DIR", str(config_dir)) # Check for SSL configuration ssl_enabled = os.environ.get("SIMPLETUNER_SSL_ENABLED", "false").lower() == "true" diff --git a/simpletuner/simpletuner_sdk/process_keeper.py b/simpletuner/simpletuner_sdk/process_keeper.py index c90897e47..e28885d10 100644 --- a/simpletuner/simpletuner_sdk/process_keeper.py +++ b/simpletuner/simpletuner_sdk/process_keeper.py @@ -15,7 +15,7 @@ import time from datetime import datetime from 
pathlib import Path -from typing import Any, Dict, List, Optional, Mapping +from typing import Any, Dict, List, Mapping, Optional try: # Optional dependency; used for robust process tree termination import psutil # type: ignore @@ -47,23 +47,72 @@ def __init__(self, job_id: str): self.output_thread = None self._relayed_failure = False self._relayed_completion = False - # Create temp directory for IPC - self.ipc_dir = tempfile.mkdtemp(prefix=f"trainer_{job_id}_") + # IPC paths are initialized when the subprocess starts (needs config) + self.ipc_dir: Optional[str] = None + self.command_file: Optional[str] = None + self.event_file: Optional[str] = None + self.func_file: Optional[str] = None + + def _resolve_runtime_base(self, config: Optional[Any]) -> Path: + """Determine a writable base directory for IPC files.""" + candidates: List[Path] = [] + + def _coerce_path(value: Optional[Any]) -> Optional[Path]: + if not value: + return None + try: + return Path(str(value)).expanduser().resolve() + except Exception: + return None + + output_dir = None + if isinstance(config, Mapping): + output_dir = config.get("output_dir") + else: + output_dir = getattr(config, "output_dir", None) + output_path = _coerce_path(output_dir) + if output_path is not None: + candidates.append(output_path / ".simpletuner_runtime") + + env_runtime = _coerce_path(os.environ.get("SIMPLETUNER_RUNTIME_DIR")) + if env_runtime is not None: + candidates.append(env_runtime) + + # Always fall back to the system temporary directory + candidates.append(Path(tempfile.gettempdir())) + + for base in candidates: + try: + base.mkdir(parents=True, exist_ok=True) + ipc_path = Path(tempfile.mkdtemp(prefix=f"trainer_{self.job_id}_", dir=str(base))) + return ipc_path + except Exception as exc: # pragma: no cover - best effort, continue to fallback + logger.debug(f"Failed to create IPC dir in {base}: {exc}") + + # Final fallback - let mkdtemp choose location + return Path(tempfile.mkdtemp(prefix=f"trainer_{self.job_id}_")) + + def _initialize_ipc_paths(self, config: Optional[Any]) -> None: + if self.ipc_dir is not None: + return + + ipc_path = self._resolve_runtime_base(config) + self.ipc_dir = str(ipc_path) logger.info(f"IPC dir {self.ipc_dir}") self.command_file = os.path.join(self.ipc_dir, "commands.json") self.event_file = os.path.join(self.ipc_dir, "events.json") self.func_file = os.path.join(self.ipc_dir, "func.pkl") - # Initialize command file + # Initialize command and event files with open(self.command_file, "w") as f: json.dump([], f) - - # Initialize event file with open(self.event_file, "w") as f: json.dump([], f) def start(self, target_func, config: Dict[str, Any]): """Start the trainer subprocess.""" + self._initialize_ipc_paths(config) + # Get function module and name for import func_module = target_func.__module__ func_name = target_func.__name__ @@ -366,12 +415,12 @@ def _handle_event(self, event: Dict[str, Any]): # Update status from events if event.get("type") == "state": - state_data = event.get("data", {}) - if "status" in state_data: - # Don't override terminated status with aborting - if not (self.status == "terminated" and state_data["status"] == "aborting"): - self.status = state_data["status"] - process_registry[self.job_id]["status"] = self.status + state_data = event.get("data", {}) or {} + if "status" in state_data: + # Don't override terminated status with aborting + if not (self.status == "terminated" and state_data["status"] == "aborting"): + self.status = state_data["status"] + 
process_registry[self.job_id]["status"] = self.status event_type = str(event.get("type") or "").lower() event_data = event.get("data") or {} @@ -560,7 +609,9 @@ def _dispatch_callback_event(self, payload: Dict[str, Any]) -> None: except Exception: logger.debug("Failed to relay subprocess event to callback service", exc_info=True) - def _dispatch_training_status_event(self, *, status: str, data: Optional[Dict[str, Any]] = None, message: str = "") -> None: + def _dispatch_training_status_event( + self, *, status: str, data: Optional[Dict[str, Any]] = None, message: str = "" + ) -> None: normalized = status.strip().lower() status_payload = { "type": "training.status", @@ -636,7 +687,7 @@ def _cleanup_resources(self): try: import shutil - if os.path.exists(self.ipc_dir): + if self.ipc_dir and os.path.exists(self.ipc_dir): shutil.rmtree(self.ipc_dir) except Exception as e: logger.debug(f"Failed to clean up IPC dir: {e}") diff --git a/simpletuner/simpletuner_sdk/server/routes/events.py b/simpletuner/simpletuner_sdk/server/routes/events.py index 43e79c5b7..16e5eaf2e 100644 --- a/simpletuner/simpletuner_sdk/server/routes/events.py +++ b/simpletuner/simpletuner_sdk/server/routes/events.py @@ -119,7 +119,7 @@ async def handle_callback(request: Request): event = callback_service.handle_incoming(data) safe_raw = _truncate_long_strings(data) - logger.info("Received callback: %s", safe_raw) + logger.debug("Received callback: %s", safe_raw) if event: logger.debug("Normalised callback: %s", _truncate_long_strings(event.to_payload())) diff --git a/simpletuner/simpletuner_sdk/server/services/configs_service.py b/simpletuner/simpletuner_sdk/server/services/configs_service.py index 245fc222c..c4a92dcb4 100644 --- a/simpletuner/simpletuner_sdk/server/services/configs_service.py +++ b/simpletuner/simpletuner_sdk/server/services/configs_service.py @@ -1204,6 +1204,13 @@ def normalize_form_to_config( except json.JSONDecodeError: # Keep original string if parsing fails; validation will surface errors later pass + if config_key in {"--tread_config", "tread_config"} and isinstance(converted_value, str): + trimmed = converted_value.strip() + if trimmed.startswith("{") or trimmed.startswith("["): + try: + converted_value = json.loads(trimmed) + except json.JSONDecodeError: + pass config_dict[config_key] = converted_value diff --git a/simpletuner/simpletuner_sdk/server/services/field_registry/sections/model.py b/simpletuner/simpletuner_sdk/server/services/field_registry/sections/model.py index 469ace85b..925128a33 100644 --- a/simpletuner/simpletuner_sdk/server/services/field_registry/sections/model.py +++ b/simpletuner/simpletuner_sdk/server/services/field_registry/sections/model.py @@ -506,6 +506,24 @@ def _quant_label(value: str) -> str: ) ) + registry._add_field( + ConfigField( + name="wan_force_2_1_time_embedding", + arg_name="--wan_force_2_1_time_embedding", + ui_label="Force Wan 2.1 Time Embedding", + field_type=FieldType.CHECKBOX, + tab="model", + section="model_config", + subsection="wan_specific", + default_value=False, + dependencies=[FieldDependency(field="model_family", operator="equals", value="wan", action="show")], + help_text="Use Wan 2.1 style time embeddings even when running Wan 2.2 checkpoints.", + tooltip="Enable this if Wan 2.2 checkpoints report shape mismatches in the time embedding layers.", + importance=ImportanceLevel.ADVANCED, + order=30, + ) + ) + # Fused QKV Projections registry._add_field( ConfigField( diff --git 
a/simpletuner/simpletuner_sdk/server/services/field_registry/sections/training.py b/simpletuner/simpletuner_sdk/server/services/field_registry/sections/training.py index c66d9c913..77637cfca 100644 --- a/simpletuner/simpletuner_sdk/server/services/field_registry/sections/training.py +++ b/simpletuner/simpletuner_sdk/server/services/field_registry/sections/training.py @@ -221,6 +221,100 @@ def register_training_fields(registry: "FieldRegistry") -> None: ) ) + # Group Offloading + registry._add_field( + ConfigField( + name="enable_group_offload", + arg_name="--enable_group_offload", + ui_label="Enable Group Offloading", + field_type=FieldType.CHECKBOX, + tab="training", + section="memory_optimization", + default_value=False, + help_text="Offload groups of layers to CPU (or disk) between forward passes to reduce VRAM.", + tooltip="Useful when training large models on limited VRAM. May slow training slightly depending on hardware.", + importance=ImportanceLevel.ADVANCED, + order=3, + ) + ) + + registry._add_field( + ConfigField( + name="group_offload_type", + arg_name="--group_offload_type", + ui_label="Group Offload Granularity", + field_type=FieldType.SELECT, + tab="training", + section="memory_optimization", + default_value="block_level", + choices=[ + {"value": "block_level", "label": "Block level (balanced)"}, + {"value": "leaf_level", "label": "Layer level (max savings)"}, + ], + help_text="Choose how modules are grouped when offloading.", + tooltip="Block level transfers multiple layers together for better throughput. Leaf level maximises memory savings.", + importance=ImportanceLevel.ADVANCED, + order=4, + dependencies=[FieldDependency(field="enable_group_offload", operator="equals", value=True, action="show")], + ) + ) + + registry._add_field( + ConfigField( + name="group_offload_blocks_per_group", + arg_name="--group_offload_blocks_per_group", + ui_label="Blocks per Group", + field_type=FieldType.NUMBER, + tab="training", + section="memory_optimization", + default_value=1, + validation_rules=[ValidationRule(ValidationRuleType.MIN, value=1, message="Must be at least 1 block")], + help_text="Number of transformer blocks to bundle when using block-level offloading.", + tooltip="Higher values reduce CPU transfers but increase VRAM usage.", + importance=ImportanceLevel.ADVANCED, + order=5, + dependencies=[ + FieldDependency(field="enable_group_offload", operator="equals", value=True, action="show"), + FieldDependency(field="group_offload_type", operator="equals", value="block_level", action="enable"), + ], + ) + ) + + registry._add_field( + ConfigField( + name="group_offload_use_stream", + arg_name="--group_offload_use_stream", + ui_label="Use CUDA Streams for Offload", + field_type=FieldType.CHECKBOX, + tab="training", + section="memory_optimization", + default_value=False, + help_text="Overlap data transfers with compute using CUDA streams (only available on CUDA devices).", + tooltip="Recommended when training on GPUs with CUDA; automatically disabled on other backends.", + importance=ImportanceLevel.ADVANCED, + order=6, + dependencies=[FieldDependency(field="enable_group_offload", operator="equals", value=True, action="show")], + ) + ) + + registry._add_field( + ConfigField( + name="group_offload_to_disk_path", + arg_name="--group_offload_to_disk_path", + ui_label="Group Offload Disk Path", + field_type=FieldType.TEXT, + tab="training", + section="memory_optimization", + default_value="", + placeholder="/tmp/simpletuner-offload", + help_text="Optional directory to spill parameters 
when offloading (useful on memory-constrained hosts).", + tooltip="Leave empty to keep offloaded weights in RAM. Directory is created if it does not exist.", + importance=ImportanceLevel.ADVANCED, + order=7, + dependencies=[FieldDependency(field="enable_group_offload", operator="equals", value=True, action="show")], + ) + ) + # Train Text Encoder registry._add_field( ConfigField( diff --git a/simpletuner/simpletuner_sdk/server/services/models_service.py b/simpletuner/simpletuner_sdk/server/services/models_service.py index 6a2ed6f28..319ca815d 100644 --- a/simpletuner/simpletuner_sdk/server/services/models_service.py +++ b/simpletuner/simpletuner_sdk/server/services/models_service.py @@ -9,7 +9,7 @@ from fastapi import status -from simpletuner.helpers.models.common import ModelFoundation, PipelineTypes +from simpletuner.helpers.models.common import ModelFoundation, PipelineTypes, VideoModelFoundation from simpletuner.helpers.models.registry import ModelRegistry @@ -134,6 +134,9 @@ def get_model_details(self, model_family: str) -> Dict[str, Any]: "overrides_requires_conditioning_latents": self._is_method_overridden( model_cls, "requires_conditioning_latents" ), + "overrides_requires_conditioning_image_embeds": self._is_method_overridden( + model_cls, "requires_conditioning_image_embeds" + ), "overrides_requires_conditioning_validation_inputs": self._is_method_overridden( model_cls, "requires_conditioning_validation_inputs" ), @@ -146,6 +149,7 @@ def get_model_details(self, model_family: str) -> Dict[str, Any]: "has_controlnet_pipeline": any( pt in pipeline_types for pt in {PipelineTypes.CONTROLNET.value, PipelineTypes.CONTROL.value} ), + "is_video_model": issubclass(model_cls, VideoModelFoundation), } default_flavour = getattr(model_cls, "DEFAULT_MODEL_FLAVOUR", None) @@ -219,6 +223,7 @@ def _safe_call(attr: str, default): requires_dataset = bool(_safe_call("requires_conditioning_dataset", False)) requires_latents = bool(_safe_call("requires_conditioning_latents", False)) + requires_image_embeds = bool(_safe_call("requires_conditioning_image_embeds", False)) requires_validation_inputs = bool(_safe_call("requires_conditioning_validation_inputs", False)) requires_edit_captions = bool(_safe_call("requires_validation_edit_captions", False)) dataset_type = _safe_call("conditioning_validation_dataset_type", "conditioning") @@ -228,6 +233,7 @@ def _safe_call(attr: str, default): return { "requires_conditioning_dataset": requires_dataset, "requires_conditioning_latents": requires_latents, + "requires_conditioning_image_embeds": requires_image_embeds, "requires_conditioning_validation_inputs": requires_validation_inputs, "requires_validation_edit_captions": requires_edit_captions, "conditioning_dataset_type": dataset_type, diff --git a/simpletuner/simpletuner_sdk/server/services/system_status_service.py b/simpletuner/simpletuner_sdk/server/services/system_status_service.py index ffcd1d7af..5ffeab024 100644 --- a/simpletuner/simpletuner_sdk/server/services/system_status_service.py +++ b/simpletuner/simpletuner_sdk/server/services/system_status_service.py @@ -28,6 +28,11 @@ except Exception: # pragma: no cover - optional dependency in CPU-only environments torch = None # type: ignore +try: # nvidia-ml-py exposes the pynvml module + import pynvml # type: ignore +except Exception: # pragma: no cover - optional dependency + pynvml = None # type: ignore + class SystemStatusService: """Expose basic system statistics for display in the Web UI.""" @@ -132,8 +137,6 @@ def _get_gpu_utilisation(self) -> 
List[Dict[str, Any]]: memory_percent = fallback_entry.get("memory_percent") utilisation = nvidia_fallback[target_idx] if utilisation is None and backend == "rocm": - if rocm_fallback is None: - rocm_fallback = self._get_rocm_gpu_utilisation() if rocm_fallback: target_idx = None if isinstance(index, int) and 0 <= index < len(rocm_fallback): @@ -326,136 +329,94 @@ def _measure_directory(self, path: str) -> Tuple[Optional[int], Optional[int]]: } return size, count - def _get_rocm_gpu_utilisation(self) -> Optional[List[Optional[float]]]: - rocm_smi = shutil.which("rocm-smi") - if not rocm_smi: - for candidate in ("/opt/rocm/bin/rocm-smi", "/usr/bin/rocm-smi"): - if os.path.isfile(candidate) and os.access(candidate, os.X_OK): - rocm_smi = candidate - break - if not rocm_smi: + def _get_nvidia_gpu_stats(self) -> Optional[List[Dict[str, Optional[float]]]]: + if platform.system() == "Darwin": return None - commands = [ - [rocm_smi, "--showuse", "--json"], - [rocm_smi, "--showuse"], - ] - - for command in commands: - try: - completed = subprocess.run( - command, - check=True, - capture_output=True, - text=True, - timeout=2, - ) - except (FileNotFoundError, subprocess.SubprocessError) as exc: - logger.debug("Unable to query GPU utilisation via rocm-smi (%s): %s", " ".join(command[1:]), exc, exc_info=True) - continue + nvml_stats = self._get_nvml_gpu_stats() + if nvml_stats: + return nvml_stats - output = (completed.stdout or "").strip() - if not output: - continue + return self._get_nvidia_smi_stats() - if "--json" in command: - try: - data = json.loads(output) - except json.JSONDecodeError as exc: - logger.debug("Failed to parse rocm-smi JSON output: %s", exc, exc_info=True) - continue + def _get_nvml_gpu_stats(self) -> Optional[List[Dict[str, Optional[float]]]]: + if pynvml is None: + return None - if isinstance(data, dict): - values: Dict[int, Optional[float]] = {} - for key, entry in data.items(): - if not isinstance(entry, dict): - continue - gpu_idx = self._coerce_int(key.lstrip("card")) if isinstance(key, str) else self._coerce_int(key) - raw_value = ( - entry.get("GPU use (%)") - or entry.get("GPU use (%) (avg)") - or entry.get("GPU use (%) (average)") - or entry.get("GPU use (%) (peak)") - or entry.get("GPU use (%) (current)") - ) - values[gpu_idx] = self._coerce_percent(raw_value) - if values: - ordered = [values.get(idx, None) for idx in sorted(values.keys())] - return ordered or None - continue + initialised_here = False + try: + pynvml.nvmlInit() # type: ignore[attr-defined] + initialised_here = True + except Exception as exc: # pragma: no cover - NVML optional + try: + already_init_cls = getattr(pynvml, "NVMLError_AlreadyInitialized", None) # type: ignore[attr-defined] + if already_init_cls and isinstance(exc, already_init_cls): + initialised_here = False + else: + logger.debug("Unable to initialise NVML: %s", exc, exc_info=True) + return None + except Exception: + logger.debug("Unable to initialise NVML: %s", exc, exc_info=True) + return None - parsed_values = self._parse_rocm_smi_text(output) - if parsed_values: - return parsed_values + try: + try: + device_count = pynvml.nvmlDeviceGetCount() # type: ignore[attr-defined] + except Exception as exc: + logger.debug("Failed to query NVML device count: %s", exc, exc_info=True) + return None - return None + if device_count <= 0: + return None - def _parse_rocm_smi_text(self, output: str) -> Optional[List[Optional[float]]]: - pattern = re.compile(r"GPU\s*\[\s*(\d+)\s*\].*?GPU use.*?:\s*([0-9]+(?:\.[0-9]+)?)") - matches = 
pattern.findall(output) - values: Dict[int, Optional[float]] = {} - if matches: - for gpu, raw in matches: - idx = self._coerce_int(gpu) - if idx is None or idx in values: - continue - values[idx] = self._coerce_percent(raw) - else: - lines = output.splitlines() - header_index = None - headers = [] - for line in lines: - if "GPU use" in line and header_index is None: - headers = [token for token in line.strip().split() if token] - if headers: - try: - header_index = headers.index("use") - 1 if "use" in headers else None - except ValueError: - header_index = None - continue - if not line or not line.strip(): - continue - tokens = [token for token in line.strip().split() if token] - if len(tokens) < 2: - continue - gpu_token = tokens[0] - gpu_idx = self._coerce_int(gpu_token.strip("GPU[]")) - if gpu_idx is None or gpu_idx in values: + stats: List[Dict[str, Optional[float]]] = [] + for index in range(device_count): + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(index) # type: ignore[attr-defined] + except Exception as exc: + logger.debug("Failed to acquire NVML handle for device %s: %s", index, exc, exc_info=True) + stats.append({"utilization_percent": None, "memory_percent": None}) continue - percent_value = None - for token in tokens[1:]: - maybe_percent = self._coerce_percent(token) - if maybe_percent is not None: - percent_value = maybe_percent - break - values[gpu_idx] = percent_value - - if not values: - return None - ordered = [values.get(idx, None) for idx in sorted(values.keys())] - return ordered or None + utilisation_value: Optional[float] = None + memory_percent: Optional[float] = None - @staticmethod - def _coerce_percent(value: Any) -> Optional[float]: - if isinstance(value, (int, float)): - return round(float(value), 1) - if isinstance(value, str): - text = value.strip() - if not text: - return None - if text.endswith("%"): - text = text[:-1] - try: - return round(float(text), 1) - except ValueError: - return None - return None + try: + utilisation = pynvml.nvmlDeviceGetUtilizationRates(handle) # type: ignore[attr-defined] + gpu_util = getattr(utilisation, "gpu", None) + if gpu_util is not None: + utilisation_value = float(gpu_util) + except Exception as exc: + logger.debug("Failed to read NVML utilisation for device %s: %s", index, exc, exc_info=True) - def _get_nvidia_gpu_utilisation(self) -> Optional[List[Optional[float]]]: - if platform.system() == "Darwin": - return None + try: + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) # type: ignore[attr-defined] + total_raw = getattr(mem_info, "total", None) + used_raw = getattr(mem_info, "used", None) + if total_raw not in (None, 0): + total = float(total_raw) + used = float(used_raw or 0.0) + if total > 0: + memory_percent = (used / total) * 100.0 + except Exception as exc: + logger.debug("Failed to read NVML memory info for device %s: %s", index, exc, exc_info=True) + + stats.append( + { + "utilization_percent": round(utilisation_value, 1) if utilisation_value is not None else None, + "memory_percent": round(memory_percent, 1) if memory_percent is not None else None, + } + ) + + return stats or None + finally: + if initialised_here: + try: + pynvml.nvmlShutdown() # type: ignore[attr-defined] + except Exception as exc: + logger.debug("Failed to shutdown NVML cleanly: %s", exc, exc_info=True) + def _get_nvidia_smi_stats(self) -> Optional[List[Dict[str, Optional[float]]]]: try: completed = subprocess.run( [ diff --git a/simpletuner/simpletuner_sdk/server/services/training_service.py 
b/simpletuner/simpletuner_sdk/server/services/training_service.py index 9fecb0089..b2e60664c 100644 --- a/simpletuner/simpletuner_sdk/server/services/training_service.py +++ b/simpletuner/simpletuner_sdk/server/services/training_service.py @@ -3,12 +3,16 @@ from __future__ import annotations import copy +import json import logging import os import re +import shutil +import tempfile import uuid from dataclasses import dataclass, field from datetime import datetime +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from simpletuner.helpers.training.deepspeed_optimizers import DEFAULT_OPTIMIZER as DS_DEFAULT_OPTIMIZER @@ -23,12 +27,79 @@ from simpletuner.simpletuner_sdk.server.services.field_registry_wrapper import lazy_field_registry from simpletuner.simpletuner_sdk.server.services.hardware_service import detect_gpu_inventory from simpletuner.simpletuner_sdk.server.services.webui_state import WebUIDefaults, WebUIStateStore +from simpletuner.simpletuner_sdk.server.utils.paths import resolve_config_path from .webhook_defaults import DEFAULT_CALLBACK_URL, DEFAULT_WEBHOOK_CONFIG logger = logging.getLogger(__name__) +_PROMPT_LIBRARY_RUNTIME_ROOT = Path(tempfile.gettempdir()) / "simpletuner_prompt_libraries" + + +def _ensure_prompt_library_runtime_dir(job_id: str) -> Path: + """Return a clean runtime directory for prompt libraries for a given job.""" + + _PROMPT_LIBRARY_RUNTIME_ROOT.mkdir(parents=True, exist_ok=True) + job_dir = _PROMPT_LIBRARY_RUNTIME_ROOT / job_id + if job_dir.exists(): + shutil.rmtree(job_dir, ignore_errors=True) + job_dir.mkdir(parents=True, exist_ok=True) + return job_dir + + +def _normalise_prompt_library_path(value: Any) -> Optional[str]: + """Return a cleaned string path for the prompt library CLI argument.""" + + if isinstance(value, str): + trimmed = value.strip() + if trimmed and trimmed.lower() not in {"none", "null", "false"}: + return trimmed + return None + + +def _prepare_user_prompt_library( + runtime_payload: Dict[str, Any], + *, + job_id: str, + configs_dir: Optional[str], +) -> None: + """Copy or materialise the user prompt library into a job-scoped location.""" + + inline_library = runtime_payload.get("user_prompt_library") + library_is_inline = isinstance(inline_library, dict) + + cli_path = _normalise_prompt_library_path(runtime_payload.get("--user_prompt_library")) + alias_path = None + if isinstance(inline_library, str): + alias_path = _normalise_prompt_library_path(inline_library) + + if not library_is_inline and not cli_path and not alias_path: + # No prompt library configured + return + + if library_is_inline: + source_path = None + else: + candidate = cli_path or alias_path + resolved = resolve_config_path(candidate, config_dir=configs_dir) if candidate else None + if resolved is None or not resolved.exists(): + raise FileNotFoundError(f"User prompt library not found at '{candidate}'. 
Provide a valid JSON file.") + source_path = resolved + + job_dir = _ensure_prompt_library_runtime_dir(job_id) + if library_is_inline: + target_path = job_dir / "user_prompt_library.json" + with target_path.open("w", encoding="utf-8") as handle: + json.dump(inline_library, handle, indent=4) + else: + target_path = job_dir / source_path.name + shutil.copy2(source_path, target_path) + + runtime_payload["--user_prompt_library"] = str(target_path) + runtime_payload["user_prompt_library"] = str(target_path) + + @dataclass class TrainingConfigBundle: """Container describing the resolved configuration artefacts for a form post.""" @@ -580,9 +651,17 @@ def _get_with_alias(source: Dict[str, Any], arg: str) -> Any: arg_lookup = key if key.startswith("--") else f"--{clean_key}" is_required_field = _is_required_field(arg_lookup) + explicit_override = ( + arg_lookup in config_dict + or clean_key in config_dict + or arg_lookup.lstrip("-") in config_dict + or clean_key in form_dict + or arg_lookup in form_dict + ) + if save_options.get("preserve_defaults", False) and not is_required_field: default_value = all_defaults.get(arg_lookup, all_defaults.get(key)) - if value != default_value: + if value != default_value or explicit_override: save_config[clean_key] = value else: save_config[clean_key] = value @@ -746,6 +825,22 @@ def _read_required(key: str) -> Any: if not value: errors.append(message) + inline_prompt_library = _read_required("user_prompt_library") + if not isinstance(inline_prompt_library, dict): + prompt_library_path = _read_required("--user_prompt_library") + if isinstance(prompt_library_path, str): + candidate = prompt_library_path.strip() + if candidate and candidate.lower() not in {"none", "null", "false"}: + resolved = resolve_config_path( + candidate, + config_dir=getattr(store, "config_dir", None), + check_cwd_first=True, + ) + if resolved is None or not resolved.exists(): + errors.append( + f"User prompt library not found at '{candidate}'. " "Please provide a valid JSON file path." + ) + return TrainingValidationResult( errors=errors, warnings=warnings, @@ -757,16 +852,24 @@ def _read_required(key: str) -> Any: def start_training_job(runtime_config: Dict[str, Any]) -> str: """Submit a training job via the process keeper and return the job identifier.""" + job_id = str(uuid.uuid4())[:8] + runtime_payload = dict(runtime_config) runtime_payload.setdefault("--webhook_config", copy.deepcopy(DEFAULT_WEBHOOK_CONFIG)) + # Resolve the prompt library into a job-scoped path if one was configured. 
+ try: + _, defaults = get_webui_state() + except Exception: + defaults = WebUIDefaults() + configs_dir = getattr(defaults, "configs_dir", None) + _prepare_user_prompt_library(runtime_payload, job_id=job_id, configs_dir=configs_dir) + APIState.set_state("training_config", runtime_payload) APIState.set_state("training_status", "starting") APIState.set_state("training_progress", None) APIState.set_state("training_startup_stages", {}) - job_id = str(uuid.uuid4())[:8] - job_config = dict(runtime_payload) job_config["__job_id__"] = job_id # Ensure the trainer surfaces configuration parsing errors instead of silently diff --git a/simpletuner/simpletuner_sdk/server/services/webui_state.py b/simpletuner/simpletuner_sdk/server/services/webui_state.py index c7c1c88b5..c6c3fab1e 100644 --- a/simpletuner/simpletuner_sdk/server/services/webui_state.py +++ b/simpletuner/simpletuner_sdk/server/services/webui_state.py @@ -184,8 +184,30 @@ def _resolve_base_dir(self) -> Path: if override: return Path(override).expanduser() - base_candidate = os.environ.get(_XDG_HOME_ENV) or os.environ.get(_XDG_CONFIG_HOME_ENV) or str(Path.home()) - return Path(base_candidate).expanduser() / ".simpletuner" / "webui" + base_candidate = os.environ.get(_XDG_HOME_ENV) or os.environ.get(_XDG_CONFIG_HOME_ENV) + if base_candidate: + root = Path(base_candidate).expanduser() + root.mkdir(parents=True, exist_ok=True) + return root / "webui" + + candidate_roots = [] + if Path("/workspace").exists(): + candidate_roots.append(Path("/workspace/simpletuner")) + if Path("/notebooks").exists(): + candidate_roots.append(Path("/notebooks/simpletuner")) + candidate_roots.append(Path.home() / ".simpletuner") + + for root in candidate_roots: + webui_dir = root / "webui" + if webui_dir.exists(): + return webui_dir + + # Nothing pre-existing; create the first preferred root and return its webui directory + preferred_root = candidate_roots[0] + preferred_root.mkdir(parents=True, exist_ok=True) + webui_dir = preferred_root / "webui" + webui_dir.mkdir(parents=True, exist_ok=True) + return webui_dir def _category_path(self, category: str) -> Path: safe_name = category.replace("/", "_") @@ -323,11 +345,12 @@ def save_defaults(self, defaults: WebUIDefaults) -> WebUIDefaults: return defaults def _fallback_paths(self) -> Dict[str, str]: - home_dir = Path.home() + root_dir = self.base_dir.parent + root_dir.mkdir(parents=True, exist_ok=True) return { - "configs_dir": str(home_dir / ".simpletuner" / "configs"), - "output_dir": str(home_dir / ".simpletuner" / "output"), - "datasets_dir": str(home_dir / ".simpletuner" / "datasets"), + "configs_dir": str(root_dir / "configs"), + "output_dir": str(root_dir / "output"), + "datasets_dir": str(root_dir / "datasets"), } def resolve_defaults(self, defaults: WebUIDefaults) -> Dict[str, Any]: diff --git a/simpletuner/simpletuner_sdk/server/utils/paths.py b/simpletuner/simpletuner_sdk/server/utils/paths.py index 63459a5d1..3ec317b09 100644 --- a/simpletuner/simpletuner_sdk/server/utils/paths.py +++ b/simpletuner/simpletuner_sdk/server/utils/paths.py @@ -35,7 +35,31 @@ def get_config_directory() -> Path: Returns: Path to the config directory relative to SimpleTuner root """ - return get_simpletuner_root() / "config" + env_override = os.environ.get("SIMPLETUNER_CONFIG_DIR") + if env_override: + return Path(env_override).expanduser() + + candidate_roots = [] + if Path("/workspace").exists(): + candidate_roots.append(Path("/workspace/simpletuner")) + if Path("/notebooks").exists(): + 
candidate_roots.append(Path("/notebooks/simpletuner")) + candidate_roots.append(Path.home() / ".simpletuner") + + for root in candidate_roots: + candidate = root / "config" + if candidate.exists(): + return candidate + + if candidate_roots: + preferred = candidate_roots[0] / "config" + preferred.mkdir(parents=True, exist_ok=True) + return preferred + + # Fall back to project/package config directory + default_dir = get_simpletuner_root() / "config" + default_dir.mkdir(parents=True, exist_ok=True) + return default_dir def get_template_directory() -> Path: diff --git a/simpletuner/static/js/dataset-wizard.js b/simpletuner/static/js/dataset-wizard.js index 6a4af2b94..6bfe79266 100644 --- a/simpletuner/static/js/dataset-wizard.js +++ b/simpletuner/static/js/dataset-wizard.js @@ -59,6 +59,7 @@ }, conditioningGenerators: CONDITIONING_GENERATOR_TYPES, newAspectBucket: null, + videoWarningAcknowledged: false, // Separate cache dataset configs textEmbedsDataset: { @@ -347,6 +348,24 @@ }; }, + selectDatasetType(type) { + if (!type) { + return; + } + if (type === 'video' && !this.isVideoModel && !this.videoWarningAcknowledged) { + window.showToast('Current model is not marked as video-capable. Proceed only if you expect to train with video datasets.', 'warning'); + this.videoWarningAcknowledged = true; + } + this.currentDataset.dataset_type = type; + if (this.selectedBlueprint && !this.selectedBlueprint.datasetTypes.includes(type)) { + this.selectedBlueprint = null; + this.selectedBackend = null; + } + if (type !== 'conditioning') { + this.conditioningConfigured = false; + } + }, + selectBackend(backendType) { this.selectedBackend = backendType; this.currentDataset.type = backendType; diff --git a/simpletuner/static/js/modules/trainer-main.js b/simpletuner/static/js/modules/trainer-main.js index 59977813b..7fd24a836 100644 --- a/simpletuner/static/js/modules/trainer-main.js +++ b/simpletuner/static/js/modules/trainer-main.js @@ -438,9 +438,12 @@ class TrainerMain { return { reset: true }; } const percentValue = Number(progress.percent || progress.percentage || 0); - const clampedPercent = Number.isFinite(percentValue) ? Math.max(0, Math.min(100, percentValue)) : 0; + const clampedPercent = Number.isFinite(percentValue) + ? 
Math.max(0, Math.min(100, percentValue)) + : 0; + const roundedPercent = Math.round(clampedPercent * 100) / 100; return { - percent: clampedPercent, + percent: roundedPercent, step: progress.step || progress.current_step || 0, total_steps: progress.total_steps || progress.total || 0, epoch: progress.epoch || 0, diff --git a/simpletuner/static/js/sse-manager.js b/simpletuner/static/js/sse-manager.js index d7a6d0922..a285ac486 100644 --- a/simpletuner/static/js/sse-manager.js +++ b/simpletuner/static/js/sse-manager.js @@ -188,6 +188,8 @@ if (!Number.isFinite(percent)) { percent = 0; } + var clampedPercent = Math.max(0, Math.min(100, percent)); + var roundedPercent = Math.round(clampedPercent * 100) / 100; var epoch = toNumber(extras.epoch || extras.current_epoch); if (epoch === null && payload.current_epoch !== undefined) { @@ -214,7 +216,7 @@ return { type: 'training.progress', job_id: jobId, - percentage: Number(percent || 0), + percentage: roundedPercent, current_step: currentStep || 0, total_steps: totalSteps || 0, epoch: epoch || 0, diff --git a/simpletuner/static/js/training-wizard.js b/simpletuner/static/js/training-wizard.js index 28459ed99..aa0bfc426 100644 --- a/simpletuner/static/js/training-wizard.js +++ b/simpletuner/static/js/training-wizard.js @@ -55,6 +55,7 @@ function trainingWizardComponent() { model_family: null, model_flavour: null, model_type: 'lora', // Default to LoRA + full_training_strategy: 'deepspeed', training_length_mode: 'epochs', num_train_epochs: 1, max_train_steps: 0, @@ -91,10 +92,24 @@ function trainingWizardComponent() { deepspeed_offload_path: '', deepspeed_zero3_init: false, deepspeed_config: null, + enable_group_offload: false, + group_offload_type: 'block_level', + group_offload_blocks_per_group: 1, + group_offload_use_stream: false, + group_offload_to_disk_path: '', + fsdp_enable: false, + fsdp_version: 2, + fsdp_reshard_after_forward: true, + fsdp_state_dict_type: 'SHARDED_STATE_DICT', + fsdp_cpu_ram_efficient_loading: false, + fsdp_auto_wrap_policy: 'TRANSFORMER_BASED_WRAP', + fsdp_transformer_layer_cls_to_wrap: '', + context_parallel_size: 1, createNewDataset: false // Track whether user chose to create new dataset }, uiOnlyAnswerKeys: [ 'createNewDataset', + 'full_training_strategy', 'deepspeed_preset', 'deepspeed_offload_param', 'deepspeed_offload_optimizer', @@ -468,21 +483,7 @@ function trainingWizardComponent() { this.answers.model_card_private = modelCardPrivate === true || modelCardPrivate === 'true' || modelCardPrivate === '1'; } - const rawDeepSpeedConfig = config.deepspeed_config ?? config['--deepspeed_config']; - if (rawDeepSpeedConfig !== undefined) { - this.inferDeepSpeedFromConfig(rawDeepSpeedConfig); - } else { - this.inferDeepSpeedFromConfig(null); - } - - const offloadPathValue = config.offload_param_path ?? 
config['--offload_param_path']; - if (offloadPathValue !== undefined && offloadPathValue !== null && String(offloadPathValue).trim() !== '') { - const normalizedPath = String(offloadPathValue).trim(); - this.answers.offload_param_path = normalizedPath; - if (!this.answers.deepspeed_offload_path) { - this.answers.deepspeed_offload_path = normalizedPath; - } - } + this.applyAccelerationFromConfig(config); console.log('[TRAINING WIZARD] Loaded current config:', this.answers); } @@ -751,6 +752,9 @@ function trainingWizardComponent() { this.answers.deepspeed_config = null; this.answers.offload_param_path = null; this.deepspeedBaseConfig = null; + this.resetDeepSpeedState(); + this.clearGroupOffloadState(); + this.clearFsdpState(); } else { this.answers.base_model_precision = 'no_change'; this.answers.quantize_via = 'accelerator'; @@ -758,6 +762,10 @@ function trainingWizardComponent() { this.answers[`text_encoder_${i}_precision`] = 'no_change'; } await this.loadQuantizationOptions(); + const previousStrategy = this.answers.full_training_strategy && this.answers.full_training_strategy !== 'none' + ? this.answers.full_training_strategy + : 'deepspeed'; + this.selectFullTrainingStrategy(previousStrategy); } } @@ -1237,6 +1245,301 @@ function trainingWizardComponent() { console.log('[TRAINING WIZARD] Answers applied to all trainer store locations'); }, + coerceBoolean(value) { + if (typeof value === 'boolean') { + return value; + } + if (typeof value === 'number') { + return value !== 0; + } + if (typeof value === 'string') { + const normalized = value.trim().toLowerCase(); + if (!normalized) { + return null; + } + if (['true', '1', 'yes', 'on'].includes(normalized)) { + return true; + } + if (['false', '0', 'no', 'off'].includes(normalized)) { + return false; + } + } + return null; + }, + + coerceNumber(value) { + if (value === undefined || value === null || value === '') { + return null; + } + const parsed = Number(value); + return Number.isFinite(parsed) ? parsed : null; + }, + + coerceString(value) { + if (value === undefined || value === null) { + return null; + } + return String(value); + }, + + selectFullTrainingStrategy(strategy) { + const validStrategies = ['none', 'group_offload', 'deepspeed', 'fsdp2']; + const nextStrategy = validStrategies.includes(strategy) ? 
strategy : 'none'; + + if (this.answers.full_training_strategy === nextStrategy) { + if (nextStrategy === 'group_offload') { + this.ensureGroupOffloadDefaults(); + } else if (nextStrategy === 'deepspeed') { + this.ensureDeepSpeedDefaults(); + } else if (nextStrategy === 'fsdp2') { + this.ensureFsdpDefaults(); + } + this.updateDeepSpeedConfig(); + return; + } + + this.answers.full_training_strategy = nextStrategy; + + switch (nextStrategy) { + case 'group_offload': + this.resetDeepSpeedState(); + this.ensureGroupOffloadDefaults(); + this.clearFsdpState(); + break; + case 'deepspeed': + this.clearGroupOffloadState(); + this.ensureDeepSpeedDefaults(); + this.clearFsdpState(); + break; + case 'fsdp2': + this.resetDeepSpeedState(); + this.clearGroupOffloadState(); + this.ensureFsdpDefaults(); + break; + default: + this.resetDeepSpeedState(); + this.clearGroupOffloadState(); + this.clearFsdpState(); + break; + } + + this.updateDeepSpeedConfig(); + }, + + ensureDeepSpeedDefaults() { + if (typeof this.answers.deepspeed_preset !== 'string') { + this.answers.deepspeed_preset = 'disabled'; + } + if (typeof this.answers.deepspeed_offload_param !== 'string') { + this.answers.deepspeed_offload_param = 'none'; + } + if (typeof this.answers.deepspeed_offload_optimizer !== 'string') { + this.answers.deepspeed_offload_optimizer = 'none'; + } + if (typeof this.answers.deepspeed_offload_path !== 'string') { + this.answers.deepspeed_offload_path = ''; + } + if (typeof this.answers.deepspeed_zero3_init !== 'boolean') { + this.answers.deepspeed_zero3_init = false; + } + }, + + resetDeepSpeedState() { + this.answers.deepspeed_preset = 'disabled'; + this.answers.deepspeed_offload_param = 'none'; + this.answers.deepspeed_offload_optimizer = 'none'; + this.answers.deepspeed_offload_path = ''; + this.answers.deepspeed_zero3_init = false; + this.answers.deepspeed_config = null; + this.answers.offload_param_path = null; + this.deepspeedBaseConfig = null; + this.syncDeepSpeedBuilderField(); + }, + + ensureGroupOffloadDefaults() { + this.answers.enable_group_offload = true; + if (!this.answers.group_offload_type) { + this.answers.group_offload_type = 'block_level'; + } + if (!Number.isFinite(this.answers.group_offload_blocks_per_group) || this.answers.group_offload_blocks_per_group <= 0) { + this.answers.group_offload_blocks_per_group = 1; + } + if (typeof this.answers.group_offload_use_stream !== 'boolean') { + this.answers.group_offload_use_stream = false; + } + if (typeof this.answers.group_offload_to_disk_path !== 'string') { + this.answers.group_offload_to_disk_path = ''; + } + }, + + clearGroupOffloadState() { + this.answers.enable_group_offload = false; + this.answers.group_offload_type = null; + this.answers.group_offload_blocks_per_group = null; + this.answers.group_offload_use_stream = false; + this.answers.group_offload_to_disk_path = ''; + }, + + ensureFsdpDefaults() { + this.answers.fsdp_enable = true; + if (!Number.isFinite(this.answers.fsdp_version)) { + this.answers.fsdp_version = 2; + } + if (typeof this.answers.fsdp_reshard_after_forward !== 'boolean') { + this.answers.fsdp_reshard_after_forward = true; + } + if (!this.answers.fsdp_state_dict_type) { + this.answers.fsdp_state_dict_type = 'SHARDED_STATE_DICT'; + } + if (typeof this.answers.fsdp_cpu_ram_efficient_loading !== 'boolean') { + this.answers.fsdp_cpu_ram_efficient_loading = false; + } + if (!this.answers.fsdp_auto_wrap_policy) { + this.answers.fsdp_auto_wrap_policy = 'TRANSFORMER_BASED_WRAP'; + } + if (typeof 
this.answers.fsdp_transformer_layer_cls_to_wrap !== 'string') { + this.answers.fsdp_transformer_layer_cls_to_wrap = ''; + } + if (!Number.isFinite(this.answers.context_parallel_size) || this.answers.context_parallel_size <= 0) { + this.answers.context_parallel_size = 1; + } + }, + + clearFsdpState() { + this.answers.fsdp_enable = false; + this.answers.fsdp_version = null; + this.answers.fsdp_reshard_after_forward = false; + this.answers.fsdp_state_dict_type = null; + this.answers.fsdp_cpu_ram_efficient_loading = false; + this.answers.fsdp_auto_wrap_policy = null; + this.answers.fsdp_transformer_layer_cls_to_wrap = ''; + this.answers.context_parallel_size = null; + }, + + applyAccelerationFromConfig(config) { + const rawDeepSpeedConfig = config.deepspeed_config ?? config['--deepspeed_config']; + const hasDeepSpeed = + rawDeepSpeedConfig !== undefined && + rawDeepSpeedConfig !== null && + (typeof rawDeepSpeedConfig === 'object' || + (typeof rawDeepSpeedConfig === 'string' && rawDeepSpeedConfig.trim().length > 0)); + + const groupOffloadEnabled = this.coerceBoolean( + config.enable_group_offload ?? config['--enable_group_offload'] + ) === true; + const fsdpEnabled = this.coerceBoolean(config.fsdp_enable ?? config['--fsdp_enable']) === true; + + if (this.answers.model_type !== 'full') { + this.resetDeepSpeedState(); + this.clearGroupOffloadState(); + this.clearFsdpState(); + this.answers.full_training_strategy = 'none'; + return; + } + + if (fsdpEnabled) { + this.selectFullTrainingStrategy('fsdp2'); + this.answers.fsdp_enable = true; + + const fsdpVersion = this.coerceNumber(config.fsdp_version ?? config['--fsdp_version']); + if (fsdpVersion) { + this.answers.fsdp_version = fsdpVersion; + } + + const reshard = this.coerceBoolean( + config.fsdp_reshard_after_forward ?? config['--fsdp_reshard_after_forward'] + ); + if (reshard !== null) { + this.answers.fsdp_reshard_after_forward = reshard; + } + + const stateDict = this.coerceString(config.fsdp_state_dict_type ?? config['--fsdp_state_dict_type']); + if (stateDict) { + this.answers.fsdp_state_dict_type = stateDict; + } + + const cpuEfficient = this.coerceBoolean( + config.fsdp_cpu_ram_efficient_loading ?? config['--fsdp_cpu_ram_efficient_loading'] + ); + if (cpuEfficient !== null) { + this.answers.fsdp_cpu_ram_efficient_loading = cpuEfficient; + } + + const autoWrap = this.coerceString( + config.fsdp_auto_wrap_policy ?? config['--fsdp_auto_wrap_policy'] + ); + if (autoWrap) { + this.answers.fsdp_auto_wrap_policy = autoWrap; + } + + const layerClasses = this.coerceString( + config.fsdp_transformer_layer_cls_to_wrap ?? config['--fsdp_transformer_layer_cls_to_wrap'] + ); + if (layerClasses !== null) { + this.answers.fsdp_transformer_layer_cls_to_wrap = layerClasses; + } + + const contextParallel = this.coerceNumber( + config.context_parallel_size ?? config['--context_parallel_size'] + ); + if (contextParallel) { + this.answers.context_parallel_size = contextParallel; + } + + return; + } + + if (groupOffloadEnabled) { + this.selectFullTrainingStrategy('group_offload'); + this.answers.enable_group_offload = true; + + const offloadType = this.coerceString( + config.group_offload_type ?? config['--group_offload_type'] + ); + if (offloadType) { + this.answers.group_offload_type = offloadType; + } + + const blocksPerGroup = this.coerceNumber( + config.group_offload_blocks_per_group ?? 
config['--group_offload_blocks_per_group'] + ); + if (blocksPerGroup) { + this.answers.group_offload_blocks_per_group = blocksPerGroup; + } + + const useStream = this.coerceBoolean( + config.group_offload_use_stream ?? config['--group_offload_use_stream'] + ); + if (useStream !== null) { + this.answers.group_offload_use_stream = useStream; + } + + const diskPath = this.coerceString( + config.group_offload_to_disk_path ?? config['--group_offload_to_disk_path'] + ); + if (diskPath !== null) { + this.answers.group_offload_to_disk_path = diskPath; + } + + return; + } + + if (hasDeepSpeed) { + this.selectFullTrainingStrategy('deepspeed'); + this.inferDeepSpeedFromConfig(rawDeepSpeedConfig); + + const offloadPathValue = config.offload_param_path ?? config['--offload_param_path']; + if (offloadPathValue !== undefined && offloadPathValue !== null && String(offloadPathValue).trim() !== '') { + const normalizedPath = String(offloadPathValue).trim(); + this.answers.offload_param_path = normalizedPath; + this.answers.deepspeed_offload_path = normalizedPath; + } + return; + } + + this.selectFullTrainingStrategy('none'); + }, + // Field navigation using existing search mechanism wizardNavigateToField(fieldName) { console.log(`[TRAINING WIZARD] Navigating to field: ${fieldName}`); @@ -1600,6 +1903,9 @@ function trainingWizardComponent() { this.inferDeepSpeedFromConfig(null); return; } + if (this.answers.full_training_strategy !== 'deepspeed') { + this.selectFullTrainingStrategy('deepspeed'); + } this.inferDeepSpeedFromConfig(rawValue); } finally { this._handlingDeepSpeedBuilderUpdate = false; @@ -1608,6 +1914,9 @@ function trainingWizardComponent() { }, selectDeepSpeedPreset(preset) { + if (this.answers.full_training_strategy !== 'deepspeed') { + this.selectFullTrainingStrategy('deepspeed'); + } this.answers.deepspeed_preset = preset; if (preset === 'disabled') { @@ -1646,7 +1955,7 @@ function trainingWizardComponent() { }, updateDeepSpeedConfig() { - if (this.answers.model_type !== 'full') { + if (this.answers.model_type !== 'full' || this.answers.full_training_strategy !== 'deepspeed') { this.answers.deepspeed_config = null; this.answers.offload_param_path = null; this.syncDeepSpeedBuilderField(); diff --git a/simpletuner/templates/components/training_events_sse.html b/simpletuner/templates/components/training_events_sse.html index c3886b77f..d64b6e867 100644 --- a/simpletuner/templates/components/training_events_sse.html +++ b/simpletuner/templates/components/training_events_sse.html @@ -577,7 +577,8 @@
diff --git a/simpletuner/templates/components/training_events_sse.html b/simpletuner/templates/components/training_events_sse.html
index c3886b77f..d64b6e867 100644
--- a/simpletuner/templates/components/training_events_sse.html
+++ b/simpletuner/templates/components/training_events_sse.html
@@ -577,7 +577,8 @@
            Training Progress
            if (!Number.isFinite(num)) { return 0; }
-            return Math.max(0, Math.min(100, num));
+            const clamped = Math.max(0, Math.min(100, num));
+            return Math.round(clamped * 100) / 100;
        }
        function toNumber(value) {
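The progress helper now rounds the clamped percentage to two decimals so SSE updates no longer render long floating-point tails. A sketch of the patched helper in isolation — the enclosing function name is not visible in the hunk, so `clampPercent` and the `Number()` coercion are stand-ins:

```js
// "clampPercent" and the Number() coercion are assumptions; only the body shown
// in the hunk above is taken from the template.
function clampPercent(value) {
  const num = Number(value);
  if (!Number.isFinite(num)) { return 0; }
  const clamped = Math.max(0, Math.min(100, num));
  return Math.round(clamped * 100) / 100;
}

clampPercent(33.33333); // 33.33
clampPercent(120);      // 100
```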
diff --git a/simpletuner/templates/partials/dataset_wizard_modal.html b/simpletuner/templates/partials/dataset_wizard_modal.html
index c39059a72..86ceb6d2c 100644
--- a/simpletuner/templates/partials/dataset_wizard_modal.html
+++ b/simpletuner/templates/partials/dataset_wizard_modal.html
@@ -87,7 +87,7 @@
                Dataset Type
+                @click="selectDatasetType('image')">
@@ -101,15 +101,14 @@
                Image Dataset
                    'selected': currentDataset.dataset_type === 'video', 'opacity-50': !isVideoModel }"
-                @click="isVideoModel && (currentDataset.dataset_type = 'video')"
-                :style="!isVideoModel ? 'cursor: not-allowed;' : ''">
+                @click="selectDatasetType('video')">
                Video Dataset

                For training with video sequences
-                Not available for this model
+                Model not marked video-capable; continue with caution.
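Both dataset-type cards now route through a single `selectDatasetType` handler instead of assigning `currentDataset.dataset_type` inline, and picking video for a model that is not flagged video-capable proceeds with a caution rather than being blocked. The handler itself is not part of this hunk; a plausible shape, offered purely as an assumption:

```js
// Hypothetical sketch of the shared handler the cards now invoke; the real
// implementation lives elsewhere in the dataset wizard component.
selectDatasetType(type) {
  if (type === 'video' && !this.isVideoModel) {
    // Previously the click was swallowed; now the selection proceeds with a caution.
    console.warn('Model not marked video-capable; continue with caution.');
  }
  this.currentDataset.dataset_type = type;
},
```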

diff --git a/simpletuner/templates/partials/training_wizard_modal.html b/simpletuner/templates/partials/training_wizard_modal.html
index 7f6b15d75..dae0810b0 100644
--- a/simpletuner/templates/partials/training_wizard_modal.html
+++ b/simpletuner/templates/partials/training_wizard_modal.html
@@ -189,104 +189,264 @@
                Quantization
-                DeepSpeed Configuration
+                Acceleration Strategy
-                Configure Hugging Face Accelerate's DeepSpeed integration. Choose the ZeRO stage and optional offload targets. Advanced tweaks remain available from the Hardware > Accelerate tab after finishing the wizard.
+                Choose how SimpleTuner should manage large-model training. You can enable grouped CPU offloading, Hugging Face Accelerate + DeepSpeed, or PyTorch FSDP2 sharding. Hardware > Accelerate retains all advanced switches for later tweaks.
-                NVMe requires a fast disk path; CPU keeps tensors in host memory.
+                Keep the default single-process Accelerate launch. You can still enable CPU offload or sharding later from Hardware > Accelerate if your run requires it.
+                Group Offload Options
+                Diffusers’ group offload moves configured module groups to CPU (or disk) between forward passes, freeing VRAM on small GPUs. SimpleTuner wires the correct CLI flags automatically.
+                Block level keeps multiple layers together for higher throughput; leaf level maximises memory savings.
+                Only used with block-level grouping. Higher values reduce transfers but require more VRAM.
-                Optimizer offload reduces GPU memory at the cost of host or disk bandwidth.
-                Matches --offload_param_path. Leave blank to require a path before training starts.
+                Leave blank to keep tensors in host RAM. Provide a fast NVMe directory when memory is extremely tight.
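The Group Offload card writes the same keys that the group-offload branch of `applyAccelerationFromConfig()` imports above. A sketch of a saved wizard config with this strategy selected — the values are illustrative, not defaults taken from the diff:

```js
// Illustrative values only; the key names match what the group-offload branch reads.
const groupOffloadExample = {
  enable_group_offload: true,
  group_offload_type: 'block_level',    // 'leaf_level' trades throughput for maximum savings
  group_offload_blocks_per_group: 2,    // only consulted with block-level grouping
  group_offload_use_stream: true,       // overlap host/device transfers with compute
  group_offload_to_disk_path: '',       // blank keeps offloaded tensors in host RAM
};
```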
+                DeepSpeed Configuration
+                Configure 🤗 Accelerate’s DeepSpeed integration. Choose the ZeRO stage and optional offload targets. Advanced edits remain available from Hardware > Accelerate after closing the wizard.
-                The wizard keeps this JSON in sync. Advanced edits can be made directly from the Hardware tab after closing the wizard.
+                NVMe requires a fast disk path; CPU keeps tensors in host memory.
+                Optimizer offload reduces GPU memory at the cost of host or disk bandwidth.
+                Matches --offload_param_path. Leave blank to require a path before training starts.
+                The wizard keeps this JSON in sync. Advanced edits can be made directly from the Hardware tab after closing the wizard.
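When the DeepSpeed strategy is active, the wizard keeps `deepspeed_config` (and, for NVMe targets, `offload_param_path`) in sync with the JSON shown in the builder. A sketch of a config fragment the importer above would classify as DeepSpeed — the nested ZeRO keys follow DeepSpeed's conventional layout and are an assumption, since this diff only reads the top-level fields:

```js
// Only deepspeed_config and offload_param_path are read by the wizard code in
// this diff; the zero_optimization structure is DeepSpeed's usual layout, shown
// here as an assumption.
const deepSpeedExample = {
  deepspeed_config: {
    zero_optimization: {
      stage: 3,
      offload_param: { device: 'nvme' },     // 'cpu' keeps tensors in host memory instead
      offload_optimizer: { device: 'cpu' },
    },
  },
  offload_param_path: '/mnt/nvme/ds-offload', // required before training when NVMe is chosen
};
```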
+                FSDP2 Configuration
+                FSDP2 shards model parameters, optimizer state, and activations across GPUs. SimpleTuner enables Accelerate’s DTensor-backed implementation and exposes the most common toggles here.
+                Sharded checkpoints save memory by keeping tensors distributed when writing to disk.
+                Set > 1 to shard attention / context across GPUs. Requires models that support context parallelism.
+                Transformer-based wrapping covers most diffusion transformers. Size-based lets you wrap by parameter count.
+                Override Accelerate’s detected layer classes when validation errors request a specific module.
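The FSDP2 card maps onto the fields imported by the FSDP branch of `applyAccelerationFromConfig()`. A sketch of a matching config — the enum-style string values mirror Accelerate's usual FSDP choices and are assumptions rather than values defined in this diff:

```js
// Key names match the FSDP2 branch above; the string values are assumptions
// modelled on Accelerate's FSDP options.
const fsdpExample = {
  fsdp_enable: true,
  fsdp_version: 2,
  fsdp_reshard_after_forward: true,
  fsdp_state_dict_type: 'SHARDED_STATE_DICT',       // sharded checkpoints stay distributed on disk
  fsdp_cpu_ram_efficient_loading: true,
  fsdp_auto_wrap_policy: 'TRANSFORMER_BASED_WRAP',  // size-based wrapping is the alternative
  fsdp_transformer_layer_cls_to_wrap: '',           // empty keeps Accelerate's detected classes
  context_parallel_size: 1,                         // > 1 shards attention/context across GPUs
};
```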
diff --git a/simpletuner/templates/trainer_dataloader_section.html b/simpletuner/templates/trainer_dataloader_section.html
index 6ee81151d..786cf0915 100644
--- a/simpletuner/templates/trainer_dataloader_section.html
+++ b/simpletuner/templates/trainer_dataloader_section.html
@@ -1,5 +1,6 @@
Add Dataset Manually 🖼️ Add Image Dataset