|
1 | 1 | # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team |
2 | 2 | """Default implementation of model loading in InvokeAI.""" |
3 | 3 |
|
| 4 | +import re |
4 | 5 | from logging import Logger |
5 | 6 | from pathlib import Path |
6 | 7 | from typing import Optional |
|
21 | 22 | ) |
22 | 23 | from invokeai.backend.util.devices import TorchDevice |
23 | 24 |
|
# Layer classes whose parameters are stored in FP8 by the plain-nn.Module fallback path.
# Mirrors diffusers' `_GO_LC_SUPPORTED_PYTORCH_LAYERS` so that the fallback makes the same
# precision/quality trade-offs as the ModelMixin path. Norm layers are deliberately absent
# from this tuple. `torch.nn.Embedding` is present, but embedding modules whose dotted path
# contains `pos_embed` or `patch_embed` are still filtered out by `_FP8_DEFAULT_SKIP_PATTERNS`.
_FP8_SUPPORTED_PYTORCH_LAYERS: tuple[type[torch.nn.Module], ...] = (
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.Conv2d,
    torch.nn.Conv3d,
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d,
    torch.nn.Embedding,
)
| 40 | + |
# Regular expressions tested (via `re.search`) against each module's dotted `named_modules()`
# path to pick out precision-sensitive layers that must never be cast to FP8. Mirrors
# diffusers' `DEFAULT_SKIP_MODULES_PATTERN` — without these, FLUX RMSNorm.scale and similar
# tiny learned scalars get crushed to FP8 and inference quality degrades.
#   - "pos_embed" / "patch_embed" / "norm" are unanchored, so they match anywhere in the path.
#   - r"^proj_in$" / r"^proj_out$" are anchored at both ends, so they match only a module whose
#     entire path is `proj_in` / `proj_out` (a direct child of the root model); a nested module
#     such as `blocks.0.proj_out` does not match and is still cast.
_FP8_DEFAULT_SKIP_PATTERNS: tuple[str, ...] = (
    "pos_embed",
    "patch_embed",
    "norm",
    r"^proj_in$",
    r"^proj_out$",
)
| 53 | + |
24 | 54 |
|
25 | 55 | # TO DO: The loader is not thread safe! |
26 | 56 | class ModelLoader(ModelLoaderBase): |
@@ -124,6 +154,151 @@ def get_size_fs( |
124 | 154 | variant=config.repo_variant if isinstance(config, Diffusers_Config_Base) else None, |
125 | 155 | ) |
126 | 156 |
|
| 157 | + def _should_use_fp8(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> bool: |
| 158 | + """Check if FP8 layerwise casting should be applied to a model.""" |
| 159 | + # FP8 storage only works on CUDA |
| 160 | + if self._torch_device.type != "cuda": |
| 161 | + return False |
| 162 | + |
| 163 | + # Z-Image has dtype mismatch issues with diffusers' layerwise casting |
| 164 | + # (skipped modules produce bf16, hooked modules expect fp16). |
| 165 | + from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType |
| 166 | + |
| 167 | + if hasattr(config, "base") and config.base == BaseModelType.ZImage: |
| 168 | + return False |
| 169 | + |
| 170 | + # VAEs are excluded — fp8 storage causes noticeable quality degradation in decode. |
| 171 | + if hasattr(config, "type") and config.type == ModelType.VAE: |
| 172 | + return False |
| 173 | + |
| 174 | + # LoRAs (including ControlLoRA) are excluded — they are not run as a standalone forward pass, |
| 175 | + # they are patched into a base model, so the layerwise-casting hooks would never fire. The |
| 176 | + # toggle is also hidden in the UI for ControlLoRA; this guard handles legacy persisted values. |
| 177 | + if hasattr(config, "type") and config.type in (ModelType.LoRA, ModelType.ControlLoRa): |
| 178 | + return False |
| 179 | + |
| 180 | + # Don't apply FP8 to text encoders, tokenizers, schedulers, VAEs, etc. |
| 181 | + _excluded_submodel_types = { |
| 182 | + SubModelType.TextEncoder, |
| 183 | + SubModelType.TextEncoder2, |
| 184 | + SubModelType.TextEncoder3, |
| 185 | + SubModelType.Tokenizer, |
| 186 | + SubModelType.Tokenizer2, |
| 187 | + SubModelType.Tokenizer3, |
| 188 | + SubModelType.Scheduler, |
| 189 | + SubModelType.SafetyChecker, |
| 190 | + SubModelType.VAE, |
| 191 | + SubModelType.VAEDecoder, |
| 192 | + SubModelType.VAEEncoder, |
| 193 | + } |
| 194 | + if submodel_type in _excluded_submodel_types: |
| 195 | + return False |
| 196 | + |
| 197 | + # Check default_settings.fp8_storage (Main models, ControlNet) |
| 198 | + if hasattr(config, "default_settings") and config.default_settings is not None: |
| 199 | + if hasattr(config.default_settings, "fp8_storage") and config.default_settings.fp8_storage is True: |
| 200 | + return True |
| 201 | + |
| 202 | + return False |
| 203 | + |
| 204 | + def _apply_fp8_layerwise_casting( |
| 205 | + self, model: AnyModel, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None |
| 206 | + ) -> AnyModel: |
| 207 | + """Apply FP8 layerwise casting to a model if enabled in its config.""" |
| 208 | + if not self._should_use_fp8(config, submodel_type): |
| 209 | + return model |
| 210 | + |
| 211 | + storage_dtype = torch.float8_e4m3fn |
| 212 | + compute_dtype = self._torch_dtype |
| 213 | + |
| 214 | + # Detect the model's current dtype to use as compute dtype, since models |
| 215 | + # (e.g. Flux) may require a specific dtype (bf16) that differs from the global torch dtype (fp16). |
| 216 | + if isinstance(model, torch.nn.Module): |
| 217 | + first_param = next(model.parameters(), None) |
| 218 | + if first_param is not None: |
| 219 | + compute_dtype = first_param.dtype |
| 220 | + |
| 221 | + from diffusers.models.modeling_utils import ModelMixin |
| 222 | + |
| 223 | + if isinstance(model, ModelMixin): |
| 224 | + model.enable_layerwise_casting( |
| 225 | + storage_dtype=storage_dtype, |
| 226 | + compute_dtype=compute_dtype, |
| 227 | + ) |
| 228 | + elif isinstance(model, torch.nn.Module): |
| 229 | + self._apply_fp8_to_nn_module(model, storage_dtype=storage_dtype, compute_dtype=compute_dtype) |
| 230 | + else: |
| 231 | + return model |
| 232 | + |
| 233 | + param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters()) |
| 234 | + self._logger.info( |
| 235 | + f"FP8 layerwise casting enabled for {config.name} " |
| 236 | + f"(storage=float8_e4m3fn, compute={compute_dtype}, " |
| 237 | + f"param_size={param_bytes / (1024**2):.0f}MB)" |
| 238 | + ) |
| 239 | + return model |
| 240 | + |
@staticmethod
def _apply_fp8_to_nn_module(model: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
    """Apply FP8 layerwise casting to a plain nn.Module.

    Mirrors diffusers' `apply_layerwise_casting` semantics: only the layer classes in
    `_FP8_SUPPORTED_PYTORCH_LAYERS` are cast, and modules whose dotted path matches any of
    `_FP8_DEFAULT_SKIP_PATTERNS` (norm, pos_embed, patch_embed, proj_in/out) are skipped.
    Without the skip list, precision-sensitive tiny learned scalars (e.g. FLUX RMSNorm.scale)
    get crushed to FP8 and quality degrades noticeably.

    Layers that own any non-floating-point parameter (for example a `Linear` subclass holding
    pre-quantized integer weights) are left completely untouched. `torch.nn.Module.to(dtype)`
    never casts such tensors, whereas a raw `tensor.to(float8)` would reinterpret them as
    floats and destroy the stored values.

    Args:
        model: Root module; it is walked with `named_modules()` and modified in place.
        storage_dtype: Dtype the selected layers' parameters are kept in between forward calls.
        compute_dtype: Dtype those parameters are cast to for the duration of each forward call.
    """
    # Compile the skip patterns once instead of re-resolving them for every module.
    skip_regexes = [re.compile(pattern) for pattern in _FP8_DEFAULT_SKIP_PATTERNS]

    for module_name, module in model.named_modules():
        if not isinstance(module, _FP8_SUPPORTED_PYTORCH_LAYERS):
            continue
        if any(regex.search(module_name) for regex in skip_regexes):
            continue
        params = list(module.parameters(recurse=False))
        if not params:
            continue
        # The forward hooks recast every direct parameter of the layer, so a layer with even
        # one non-floating-point parameter cannot be wrapped safely; skip it entirely.
        if not all(param.is_floating_point() for param in params):
            continue

        for param in params:
            param.data = param.data.to(storage_dtype)

        ModelLoader._wrap_forward_with_fp8_cast(module, storage_dtype, compute_dtype)
| 264 | + |
| 265 | + @staticmethod |
| 266 | + def _wrap_forward_with_fp8_cast( |
| 267 | + module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype |
| 268 | + ) -> None: |
| 269 | + """Register pre/post forward hooks that cast params to compute dtype on entry and back |
| 270 | + to storage dtype on exit. |
| 271 | +
|
| 272 | + We use hooks (rather than overriding `module.forward`) for two reasons: |
| 273 | +
|
| 274 | + 1. **Correct dispatch after `apply_custom_layers_to_model`.** `ModelCache.put()` calls |
| 275 | + `apply_custom_layers_to_model`, which creates a NEW `CustomLinear` instance and |
| 276 | + shares the original `Linear.__dict__` (see `wrap_custom_layer`). Anything stored in |
| 277 | + that dict — including an instance-level `forward` attribute — gets carried over to |
| 278 | + the new object. An overridden `forward` would close over the OLD instance, so calls |
| 279 | + to the new `CustomLinear` would silently route to `Linear.forward(old_instance, ...)` |
| 280 | + and bypass the LoRA-patch-aware branch in `CustomLinear.forward`. Hooks, by contrast, |
| 281 | + live in `_forward_hooks` / `_forward_pre_hooks` and are dispatched by |
| 282 | + `nn.Module.__call__` with the *actual* called instance — so they run on the new |
| 283 | + `CustomLinear` and the class's `forward` is still resolved normally. |
| 284 | +
|
| 285 | + 2. **Exception safety.** `register_forward_hook(..., always_call=True)` fires the |
| 286 | + post-hook even when `forward` raises. The plain pre-hook/post-hook pair without |
| 287 | + `always_call` would leave params in compute dtype on exception, defeating FP8 |
| 288 | + storage savings and making cache size accounting stale. |
| 289 | + """ |
| 290 | + |
| 291 | + def pre_hook(mod: torch.nn.Module, _args: object) -> None: |
| 292 | + for p in mod.parameters(recurse=False): |
| 293 | + p.data = p.data.to(compute_dtype) |
| 294 | + |
| 295 | + def post_hook(mod: torch.nn.Module, _args: object, _output: object) -> None: |
| 296 | + for p in mod.parameters(recurse=False): |
| 297 | + p.data = p.data.to(storage_dtype) |
| 298 | + |
| 299 | + module.register_forward_pre_hook(pre_hook) |
| 300 | + module.register_forward_hook(post_hook, always_call=True) |
| 301 | + |
127 | 302 | # This needs to be implemented in the subclass |
128 | 303 | def _load_model( |
129 | 304 | self, |
|
0 commit comments