Commit 6f42ad0

feat: add per-model FP8 layerwise casting for VRAM reduction (#8945)
* feat: add per-model FP8 layerwise casting for VRAM reduction

  Add fp8_storage option to model default settings that enables diffusers' enable_layerwise_casting() to store weights in FP8 (float8_e4m3fn) while casting to fp16/bf16 during inference. This reduces VRAM usage by ~50% per model with minimal quality loss.

  Supported: SD1/SD2/SDXL/SD3, Flux, Flux2, CogView4, Z-Image, VAE (diffusers-based), ControlNet, T2IAdapter.
  Not applicable: Text Encoders, LoRA, GGUF, BnB, custom classes

* feat: add FP8 storage option to Model Manager UI

  Add per-model FP8 storage toggle in Model Manager default settings for both main models and control adapter models. When enabled, model weights are stored in FP8 format in VRAM (~50% savings) and cast layer-by-layer to compute precision during inference via diffusers' enable_layerwise_casting().

  Backend: add fp8_storage field to MainModelDefaultSettings and ControlAdapterDefaultSettings, apply FP8 layerwise casting in all relevant model loaders (SD, SDXL, FLUX, CogView4, Z-Image, ControlNet, T2IAdapter, VAE). Gracefully skips non-ModelMixin models (custom checkpoint loaders, GGUF, BnB).

  Frontend: add FP8 Storage switch to model default settings panels with InformationalPopover, translation keys, and proper form handling.

* ruff format

* fix: enable FP8 layerwise casting for checkpoint Flux models

  FluxCheckpointModel and Flux2CheckpointModel were missing the _apply_fp8_layerwise_casting call. Additionally, the FP8 casting only worked for diffusers ModelMixin models. Add manual layerwise casting via forward hooks for plain nn.Module (custom Flux class).

  Also simplify FP8 UI toggle from dual-slider to single switch, matching the CPU-only toggle pattern per review feedback on #8945.

* fix: exclude Z-Image from FP8 due to diffusers layerwise casting bug

  Z-Image's transformer has dtype mismatches with diffusers' enable_layerwise_casting: skipped modules (t_embedder, cap_embedder) stay in bf16 while hooked modules cast to fp16, causing crashes in attention layers. Also hide the FP8 toggle in the UI for Z-Image models.

* fix: detect model dtype for FP8 compute instead of using global dtype

  Models like Flux are loaded in bf16 but the global torch dtype is fp16, causing dtype mismatches during FP8 layerwise casting. Detect the model's actual parameter dtype and use it as compute_dtype for both diffusers ModelMixin and plain nn.Module models.

* Remove call for _should_use_fp8 in z-image

* Merge branch 'main' + exclude VAEs from FP8 layerwise casting

  Resolve merge conflict in vae.py by keeping upstream's Anima/QwenImage VAE loader paths and dropping the FP8 call from the AutoencoderKL checkpoint path.

  Exclude VAEs from FP8 layerwise casting in _should_use_fp8 (both standalone ModelType.VAE and the VAE/VAEDecoder/VAEEncoder submodel types of Main models). FP8 storage causes noticeable quality degradation on VAE decode.

* fix(fp8): invalidate cache on settings change, exception-safe nn.Module fallback, hide ControlLoRA toggle

  - Add ModelCache.drop_model() and call it from update_model_record when fp8_storage or cpu_only change. These settings are baked into the loaded nn.Module at load time, so toggling them was silently a no-op until the cache entry was evicted by other means.
  - Replace the pre-hook/post-hook pair in _apply_fp8_to_nn_module with a forward wrapper using try/finally. register_forward_hook only fires on successful forward, so an exception left params in compute dtype and defeated the FP8 storage savings.
  - Hide the FP8 toggle in the UI for ControlLoRA and exclude LoRA/ControlLoRA in _should_use_fp8. LoRAs are patched into base models rather than run as a standalone forward pass, so layerwise-casting hooks would never fire.
  - Add tests for drop_model, the exception-safe FP8 wrapper, the ControlLoRA/LoRA exclusion, and the _load_settings_changed predicate.

* fix(fp8): honor class swap for LoRA patches, evict stale locked entries, skip precision-sensitive layers

  - _wrap_forward_with_fp8_cast now dispatches via type(module).forward at call time instead of capturing the bound method. ModelCache.put() swaps nn.Linear.__class__ to CustomLinear (sharing __dict__), which would otherwise leave our instance forward shadowing CustomLinear.forward and silently bypass LoRA/ControlLoRA patch dispatch on FP8 checkpoints.
  - drop_model() now marks locked entries is_stale instead of skipping them silently; unlock() evicts stale entries once the last lock releases. Without this, a setting toggled during an in-flight generation survived on the locked entry and the next generation reused the pre-change module.
  - _apply_fp8_to_nn_module mirrors diffusers' apply_layerwise_casting: only the supported layer classes (Linear/Conv*/Embedding) get cast, and module paths matching norm/pos_embed/patch_embed/proj_in/proj_out are skipped. FLUX RMSNorm.scale and similar precision-sensitive scalars are no longer crushed to FP8.
  - drop_model() and the unlock-stale path now update stats.cleared and fire on_cache_models_cleared callbacks, matching _make_room_internal so the UI stats panel and observers don't miss invalidations.
  - Add 14 tests: class-swap dispatch, norm/pos_embed/proj_in_out skip, unsupported-type skip, stale-marking, multi-lock release, stats and callback firing for both paths, no-op silence.

* fix(fp8): switch nn.Module FP8 wrapper to hooks so CustomLinear dispatch survives apply_custom_layers_to_model

  Previous fix was wrong. `apply_custom_layers_to_model` does not do `module.__class__ = CustomLinear`; `wrap_custom_layer` constructs a NEW CustomLinear via __new__ and shares the original Linear's __dict__, then setattr installs the new object on the parent. The new object has type() == CustomLinear, but our wrapped forward closed over the original Linear instance, so `type(module).forward(module, ...)` resolved to Linear.forward on the captured old object and silently bypassed CustomLinear.forward, breaking LoRA/ControlLoRA patch dispatch for FP8 checkpoint models. Reproduced on a fresh worktree.

  Replace the instance-forward override with register_forward_pre_hook + register_forward_hook(always_call=True). Hooks are dispatched by nn.Module._call_impl with the actual called instance, so they fire on the new CustomLinear and self.forward resolves normally via class lookup, reaching CustomLinear.forward and its patch-aware branch. always_call=True keeps the exception-safety guarantee (post-hook fires even when forward raises).

  Replace the simulated __class__-swap test with one that runs real apply_custom_layers_to_model, attaches a sentinel _patches_and_weights, and asserts the patch-aware branch in CustomLinear.forward is reached. Verified the test fails under the old instance-forward implementation with the reviewer-described symptom and passes under the hook fix.

* Add docs for fp8

---------

Co-authored-by: Jonathan <[email protected]>
1 parent 0f937ce commit 6f42ad0
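
The last fix in the commit message turns on a plain-Python detail: an instance-level `forward` lives in the object's `__dict__`, so it travels with that dict when a new object shares it. The sketch below reproduces the effect without torch. `Linear`, `CustomLinear`, `override_forward`, and `share_dict_with_new_instance` are hypothetical stand-ins written for illustration, not InvokeAI's actual classes or helpers.

```python
class Linear:
    """Stand-in for torch.nn.Linear (hypothetical, no torch needed)."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def forward(self, x: float) -> float:
        return x * self.weight


class CustomLinear(Linear):
    """Stand-in for the patch-aware subclass."""

    def forward(self, x: float):
        # Tag the result so we can see which forward actually ran.
        return ("patched", super().forward(x))


def override_forward(module: Linear) -> None:
    """Instance-level override; `wrapped` closes over `module` (the OLD object)."""

    def wrapped(x: float):
        return type(module).forward(module, x)

    module.forward = wrapped  # stored in module.__dict__


def share_dict_with_new_instance(old: Linear) -> CustomLinear:
    """Mimics the wrap described above: new object, same __dict__ as the old one."""
    new = CustomLinear.__new__(CustomLinear)
    new.__dict__ = old.__dict__
    return new


# Case 1: the override is installed before the wrap, so it shadows CustomLinear.forward.
old = Linear(weight=3.0)
override_forward(old)
shadowed = share_dict_with_new_instance(old)
shadowed_result = shadowed.forward(2.0)

# Case 2: no instance-level override, so class lookup reaches CustomLinear.forward.
clean = share_dict_with_new_instance(Linear(weight=3.0))
clean_result = clean.forward(2.0)
```

In case 1 the call never reaches the subclass, which is the silent bypass the commit message describes. Hooks avoid this because they do not replace `forward` at all.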

25 files changed

Lines changed: 1121 additions & 5 deletions

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+---
+title: FP8 Storage
+sidebar:
+  order: 3
+---
+
+import { Steps } from '@astrojs/starlight/components';
+
+FP8 Storage cuts a model's VRAM footprint roughly in half by keeping weights on the GPU in 8-bit floating-point format (`float8_e4m3fn`). During inference, each layer's weights are cast on the fly up to the compute precision (FP16/BF16) and cast back to FP8 after the forward pass, so quality is largely preserved.
+
+It pairs well with [Low-VRAM mode](/configuration/low-vram-mode/): low-VRAM mode streams layers between RAM and VRAM, while FP8 Storage shrinks the layers themselves.
+
+## Requirements
+
+- **NVIDIA GPU on Windows or Linux.** FP8 Storage uses CUDA tensor types and is silently disabled on CPU and MPS.
+- **CUDA 12.x and a recent PyTorch.** The `float8_e4m3fn` dtype was added in PyTorch 2.1; InvokeAI's bundled versions satisfy this.
+
+There is no hardware requirement for FP8 *compute*: InvokeAI casts back to FP16/BF16 for math. FP8 Storage therefore works on GPUs that do not natively support FP8 matmul (e.g. RTX 30-series), at a small per-step throughput cost.
+
+## Enabling FP8 Storage
+
+FP8 Storage is a **per-model setting**, configured from the Model Manager:
+
+<Steps>
+1. Open the **Model Manager**.
+2. Select a model (Main, ControlNet, or T2I-Adapter).
+3. Under **Default Settings**, toggle **FP8 Storage (Save VRAM)**.
+4. Click **Save**.
+</Steps>
+
+The setting takes effect on the next load. If the model is already in the cache, InvokeAI evicts the cached copy automatically so the new setting applies. If a generation is currently using the model, the eviction is deferred until that generation finishes.
+
+:::tip[When to enable]
+Enable FP8 Storage on large models that don't fit comfortably in VRAM, such as FLUX dev/Klein, large SDXL checkpoints, and ControlNet-XL adapters. For smaller SD1 / SD2 models, the savings are negligible and not worth the small precision trade-off.
+:::
+
+## What FP8 Storage applies to
+
+FP8 Storage is **only** applied to model types and layers where the precision trade-off is acceptable:
+
+| Model type                    | FP8 applied?                            |
+| ----------------------------- | --------------------------------------- |
+| Main models (SD1, SD2, SDXL)  | Yes                                     |
+| FLUX.1 / FLUX.2 Klein         | Yes                                     |
+| ControlNet, T2I-Adapter       | Yes                                     |
+| VAE                           | No (visible decode-quality regression)  |
+| Text encoders, tokenizers     | No (small models, no benefit)           |
+| Z-Image (any variant)         | No (dtype mismatch with skipped layers) |
+| LoRA, ControlLoRA             | No (patched into base, not run alone)   |
+
+Within a supported model, **norm layers, position/patch embeddings, and `proj_in`/`proj_out` are skipped** so precision-sensitive tiny learned scalars (e.g. FLUX `RMSNorm.scale`) aren't crushed to FP8. This mirrors the diffusers default skip list.
+
+## Quality trade-offs
+
+FP8 Storage is **near-lossless** for most workloads because:
+
+- Norm layers and position/patch embeddings (the precision-sensitive layers) are skipped.
+- The actual matmul still happens in FP16/BF16; FP8 is only the on-GPU storage format.
+
+That said, artifacts have been reported in two areas:
+
+- **VAEs**: for this reason they are never cast (the toggle has no effect on VAE submodels).
+- **Heavy LoRA stacks**: patching is unaffected, but very precision-sensitive LoRAs may show slight drift. Compare side by side if your workflow depends on subtle LoRA behavior.
+
+If you see unexpected quality regressions, disable FP8 Storage on the affected model and re-run.
+
+## Combining with Low-VRAM mode and quantized models
+
+- **FP8 + partial loading**: fully supported. FP8 Storage shrinks the layers; partial loading streams them between RAM and VRAM as needed. Use both on tight VRAM budgets.
+- **FP8 + GGUF / NF4 / int8 quantized checkpoints**: these formats already have their own storage precision. FP8 Storage is not applied on top; the toggle is silently a no-op for quantized formats, since the loader returns a different module type.
+
+## Troubleshooting
+
+### "I toggled FP8 Storage but VRAM usage didn't change"
+
+The cache eviction is immediate for idle models, but **deferred until the next unlock** if the model is mid-generation. Wait for the current generation to finish, then start a new one; the next load will use the new setting.
+
+If VRAM still hasn't dropped:
+
+- Check the InvokeAI log for `FP8 layerwise casting enabled for <model name>`. If the line isn't there, the model is probably on the exclusion list (VAE, text encoder, Z-Image, LoRA; see the table above).
+- Confirm you are on CUDA. FP8 Storage is silently disabled on CPU and MPS.
+
+### Quality regression on a specific model
+
+Disable FP8 Storage for that model in Model Manager and reload. If quality is restored, the model has FP8-sensitive layers that fall outside the default skip list. Please open an issue with the model name and a side-by-side comparison.
+
+### "RuntimeError: ... float8_e4m3fn ..."
+
+You're on a PyTorch version that predates FP8 support. Reinstall InvokeAI using the official launcher; the bundled torch version supports FP8.

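The "roughly in half" figure in the docs follows from element sizes: FP16/BF16 use 2 bytes per weight and `float8_e4m3fn` uses 1, while skipped layers stay at compute precision. A back-of-the-envelope sketch follows. The parameter counts are made-up illustration values, not measurements of any real model, and the estimate ignores activations and the transient per-layer upcast during a forward pass.

```python
# Bytes per element for the dtypes discussed in the docs.
BYTES_PER_ELEMENT = {"fp16": 2, "bf16": 2, "fp8": 1}


def weight_bytes(num_params: int, dtype: str) -> int:
    """Memory needed to hold `num_params` weights in `dtype`."""
    return num_params * BYTES_PER_ELEMENT[dtype]


def fp8_storage_bytes(num_params: int, skipped_params: int, compute_dtype: str = "bf16") -> int:
    """Weights cast to FP8 plus the skipped weights that stay in compute dtype."""
    cast_params = num_params - skipped_params
    return weight_bytes(cast_params, "fp8") + weight_bytes(skipped_params, compute_dtype)


# Hypothetical model: 12 billion parameters, 1% of them in skipped layers.
NUM_PARAMS = 12_000_000_000
SKIPPED_PARAMS = 120_000_000

baseline = weight_bytes(NUM_PARAMS, "bf16")
with_fp8 = fp8_storage_bytes(NUM_PARAMS, SKIPPED_PARAMS)
ratio = with_fp8 / baseline

print(f"bf16: {baseline / 1e9:.2f} GB, fp8 storage: {with_fp8 / 1e9:.2f} GB, ratio: {ratio:.3f}")
```

The larger the share of skipped layers, the further the ratio drifts above one half, which is one reason the savings are described as approximate.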
invokeai/app/api/routers/model_manager.py

Lines changed: 30 additions & 0 deletions
@@ -437,7 +437,17 @@ async def update_model_record(
     logger = ApiDependencies.invoker.services.logger
     record_store = ApiDependencies.invoker.services.model_manager.store
     try:
+        previous_config = record_store.get_model(key)
         config = record_store.update_model(key, changes=changes, allow_class_change=True)
+        # Settings that change how the model loads (e.g. fp8_storage, cpu_only) are baked into the cached
+        # nn.Module at load time, so toggling them on a cached model is otherwise silently a no-op until
+        # the entry is evicted. Drop any unlocked cached entries for this model so the next load rebuilds.
+        if _load_settings_changed(previous_config, config):
+            dropped = ApiDependencies.invoker.services.model_manager.load.ram_cache.drop_model(key)
+            if dropped:
+                logger.info(
+                    f"Dropped {dropped} cached entr{'y' if dropped == 1 else 'ies'} for model {key} after settings change."
+                )
         config = prepare_model_config_for_response(config, ApiDependencies)
         logger.info(f"Updated model: {key}")
     except UnknownModelException as e:
@@ -448,6 +458,26 @@ async def update_model_record(
     return config


+_LOAD_AFFECTING_SETTINGS: tuple[str, ...] = ("fp8_storage", "cpu_only")
+
+
+def _load_settings_changed(previous: AnyModelConfig, updated: AnyModelConfig) -> bool:
+    """Return True if any setting that influences how the model is loaded changed.
+
+    Such settings are read by the loader during `_load_model` and baked into the resulting
+    nn.Module, so a cached entry built under the old value must be evicted for the change
+    to take effect.
+    """
+    if getattr(previous, "cpu_only", None) != getattr(updated, "cpu_only", None):
+        return True
+    previous_settings = getattr(previous, "default_settings", None)
+    updated_settings = getattr(updated, "default_settings", None)
+    for field in _LOAD_AFFECTING_SETTINGS:
+        if getattr(previous_settings, field, None) != getattr(updated_settings, field, None):
+            return True
+    return False
+
+
 @model_manager_router.get(
     "/i/{key}/image",
     operation_id="get_model_image",

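To see how the `_load_settings_changed` predicate behaves on a few inputs, here is a self-contained sketch that copies its logic and runs it against stand-in config objects. `FakeSettings` and `FakeConfig` are hypothetical dataclasses for illustration only; the real configs are pydantic models with many more fields.

```python
from dataclasses import dataclass
from typing import Optional

_LOAD_AFFECTING_SETTINGS = ("fp8_storage", "cpu_only")


@dataclass
class FakeSettings:
    """Hypothetical stand-in for a *DefaultSettings model."""

    fp8_storage: Optional[bool] = None
    cpu_only: Optional[bool] = None
    width: Optional[int] = None


@dataclass
class FakeConfig:
    """Hypothetical stand-in for a model config record."""

    default_settings: Optional[FakeSettings] = None


def load_settings_changed(previous: object, updated: object) -> bool:
    """Same comparison as the diff: only load-affecting fields count."""
    if getattr(previous, "cpu_only", None) != getattr(updated, "cpu_only", None):
        return True
    previous_settings = getattr(previous, "default_settings", None)
    updated_settings = getattr(updated, "default_settings", None)
    for field in _LOAD_AFFECTING_SETTINGS:
        if getattr(previous_settings, field, None) != getattr(updated_settings, field, None):
            return True
    return False


# Toggling FP8 on requires a reload.
fp8_toggled = load_settings_changed(FakeConfig(FakeSettings()), FakeConfig(FakeSettings(fp8_storage=True)))
# Changing an unrelated default (width) does not.
width_changed = load_settings_changed(FakeConfig(FakeSettings(width=512)), FakeConfig(FakeSettings(width=1024)))
# Missing settings compare as None on both sides.
settings_added = load_settings_changed(FakeConfig(None), FakeConfig(FakeSettings()))
# None -> False counts as a change even though both leave FP8 disabled.
none_to_false = load_settings_changed(FakeConfig(FakeSettings()), FakeConfig(FakeSettings(fp8_storage=False)))
```

The last case shows the predicate is conservative: it compares raw values, so it may evict a cache entry that would have loaded identically. That costs one extra reload and never leaves a stale module in place.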
invokeai/backend/model_manager/configs/controlnet.py

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,10 @@
 class ControlAdapterDefaultSettings(BaseModel):
     # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
     preprocessor: str | None
+    fp8_storage: bool | None = Field(
+        default=None,
+        description="Store weights in FP8 to reduce VRAM usage (~50% savings). Weights are cast to compute dtype during inference.",
+    )
     model_config = ConfigDict(extra="forbid")

     @classmethod

invokeai/backend/model_manager/configs/main.py

Lines changed: 4 additions & 0 deletions
@@ -52,6 +52,10 @@ class MainModelDefaultSettings(BaseModel):
     height: int | None = Field(default=None, multiple_of=8, ge=64, description="Default height for this model")
     guidance: float | None = Field(default=None, ge=1, description="Default Guidance for this model")
     cpu_only: bool | None = Field(default=None, description="Whether this model should run on CPU only")
+    fp8_storage: bool | None = Field(
+        default=None,
+        description="Store weights in FP8 to reduce VRAM usage (~50% savings). Weights are cast to compute dtype during inference.",
+    )

     model_config = ConfigDict(extra="forbid")


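Both settings classes declare `fp8_storage` as `bool | None` with a default of `None`, and the loader only enables FP8 when the value `is True`. A minimal sketch of that tri-state read follows. `SettingsStub` and `fp8_requested` are hypothetical names for illustration; the real check lives inside `_should_use_fp8` alongside the device and model-type exclusions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SettingsStub:
    """Hypothetical stand-in for MainModelDefaultSettings / ControlAdapterDefaultSettings."""

    fp8_storage: Optional[bool] = None


def fp8_requested(default_settings: Optional[object]) -> bool:
    """True only when settings exist and fp8_storage is explicitly True."""
    if default_settings is None:
        return False
    return getattr(default_settings, "fp8_storage", None) is True


unset = fp8_requested(SettingsStub())                          # None: never configured
disabled = fp8_requested(SettingsStub(fp8_storage=False))      # explicitly off
enabled = fp8_requested(SettingsStub(fp8_storage=True))        # explicitly on
no_settings = fp8_requested(None)                              # model has no default settings
no_field = fp8_requested(object())                             # settings type without the field
```

Treating `None` the same as `False` means records saved before this field existed keep their previous behavior without a migration.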
invokeai/backend/model_manager/load/load_default.py

Lines changed: 175 additions & 0 deletions
@@ -1,6 +1,7 @@
 # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
 """Default implementation of model loading in InvokeAI."""

+import re
 from logging import Logger
 from pathlib import Path
 from typing import Optional
@@ -21,6 +22,35 @@
 )
 from invokeai.backend.util.devices import TorchDevice

+# Layer classes that benefit from FP8 storage. Mirrors diffusers'
+# `_GO_LC_SUPPORTED_PYTORCH_LAYERS` so the plain-nn.Module fallback path makes the same
+# precision/quality trade-offs as the ModelMixin path. Notably excludes norm and embedding
+# wrapper modules; those are handled by their direct param types (Embedding is included
+# but pos_embed/patch_embed are filtered by `_FP8_DEFAULT_SKIP_PATTERNS`).
+_FP8_SUPPORTED_PYTORCH_LAYERS: tuple[type[torch.nn.Module], ...] = (
+    torch.nn.Linear,
+    torch.nn.Conv1d,
+    torch.nn.Conv2d,
+    torch.nn.Conv3d,
+    torch.nn.ConvTranspose1d,
+    torch.nn.ConvTranspose2d,
+    torch.nn.ConvTranspose3d,
+    torch.nn.Embedding,
+)
+
+# Module-path regexes (matched against `named_modules()` dotted paths) for precision-sensitive
+# layers that should never be cast to FP8. Mirrors diffusers' `DEFAULT_SKIP_MODULES_PATTERN`.
+# Without these, FLUX RMSNorm.scale and similar tiny learned scalars get crushed to FP8 and
+# inference quality degrades. Includes anything named `norm`, position/patch embeddings, and
+# the in/out projection of transformer blocks.
+_FP8_DEFAULT_SKIP_PATTERNS: tuple[str, ...] = (
+    "pos_embed",
+    "patch_embed",
+    "norm",
+    r"^proj_in$",
+    r"^proj_out$",
+)
+

 # TO DO: The loader is not thread safe!
 class ModelLoader(ModelLoaderBase):
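
The skip patterns are applied with `re.search` against each module's dotted path, so the unanchored entries match anywhere in the path while `^proj_in$` and `^proj_out$` match only a top-level module with exactly that name. The sketch below checks the patterns against a few module paths; the paths are invented for illustration and are not taken from any specific model.

```python
import re

# Same patterns as `_FP8_DEFAULT_SKIP_PATTERNS` in the diff.
SKIP_PATTERNS = ("pos_embed", "patch_embed", "norm", r"^proj_in$", r"^proj_out$")


def is_skipped(module_path: str) -> bool:
    """True if any skip pattern matches the dotted module path."""
    return any(re.search(pattern, module_path) for pattern in SKIP_PATTERNS)


# Hypothetical module paths, chosen to exercise each kind of pattern.
EXAMPLE_PATHS = [
    "double_blocks.0.img_attn.norm.query_norm",  # contains "norm"
    "pos_embed.proj",                             # contains "pos_embed"
    "proj_out",                                   # exact top-level name
    "blocks.3.proj_out",                          # nested, so the anchored pattern misses
    "blocks.3.attn.to_q",                         # ordinary attention projection
]

decisions = {path: is_skipped(path) for path in EXAMPLE_PATHS}

for path, skipped in decisions.items():
    print(f"{path:45s} {'skip' if skipped else 'cast to FP8'}")
```

One consequence of the anchoring is that a nested `proj_out` inside a block is still eligible for FP8, while any path containing the substring `norm` is excluded regardless of depth.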
@@ -124,6 +154,151 @@ def get_size_fs(
             variant=config.repo_variant if isinstance(config, Diffusers_Config_Base) else None,
         )

+    def _should_use_fp8(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> bool:
+        """Check if FP8 layerwise casting should be applied to a model."""
+        # FP8 storage only works on CUDA
+        if self._torch_device.type != "cuda":
+            return False
+
+        # Z-Image has dtype mismatch issues with diffusers' layerwise casting
+        # (skipped modules produce bf16, hooked modules expect fp16).
+        from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
+
+        if hasattr(config, "base") and config.base == BaseModelType.ZImage:
+            return False
+
+        # VAEs are excluded: fp8 storage causes noticeable quality degradation in decode.
+        if hasattr(config, "type") and config.type == ModelType.VAE:
+            return False
+
+        # LoRAs (including ControlLoRA) are excluded: they are not run as a standalone forward pass,
+        # they are patched into a base model, so the layerwise-casting hooks would never fire. The
+        # toggle is also hidden in the UI for ControlLoRA; this guard handles legacy persisted values.
+        if hasattr(config, "type") and config.type in (ModelType.LoRA, ModelType.ControlLoRa):
+            return False
+
+        # Don't apply FP8 to text encoders, tokenizers, schedulers, VAEs, etc.
+        _excluded_submodel_types = {
+            SubModelType.TextEncoder,
+            SubModelType.TextEncoder2,
+            SubModelType.TextEncoder3,
+            SubModelType.Tokenizer,
+            SubModelType.Tokenizer2,
+            SubModelType.Tokenizer3,
+            SubModelType.Scheduler,
+            SubModelType.SafetyChecker,
+            SubModelType.VAE,
+            SubModelType.VAEDecoder,
+            SubModelType.VAEEncoder,
+        }
+        if submodel_type in _excluded_submodel_types:
+            return False
+
+        # Check default_settings.fp8_storage (Main models, ControlNet)
+        if hasattr(config, "default_settings") and config.default_settings is not None:
+            if hasattr(config.default_settings, "fp8_storage") and config.default_settings.fp8_storage is True:
+                return True
+
+        return False
+
+    def _apply_fp8_layerwise_casting(
+        self, model: AnyModel, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
+    ) -> AnyModel:
+        """Apply FP8 layerwise casting to a model if enabled in its config."""
+        if not self._should_use_fp8(config, submodel_type):
+            return model
+
+        storage_dtype = torch.float8_e4m3fn
+        compute_dtype = self._torch_dtype
+
+        # Detect the model's current dtype to use as compute dtype, since models
+        # (e.g. Flux) may require a specific dtype (bf16) that differs from the global torch dtype (fp16).
+        if isinstance(model, torch.nn.Module):
+            first_param = next(model.parameters(), None)
+            if first_param is not None:
+                compute_dtype = first_param.dtype
+
+        from diffusers.models.modeling_utils import ModelMixin
+
+        if isinstance(model, ModelMixin):
+            model.enable_layerwise_casting(
+                storage_dtype=storage_dtype,
+                compute_dtype=compute_dtype,
+            )
+        elif isinstance(model, torch.nn.Module):
+            self._apply_fp8_to_nn_module(model, storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+        else:
+            return model
+
+        param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
+        self._logger.info(
+            f"FP8 layerwise casting enabled for {config.name} "
+            f"(storage=float8_e4m3fn, compute={compute_dtype}, "
+            f"param_size={param_bytes / (1024**2):.0f}MB)"
+        )
+        return model
+
+    @staticmethod
+    def _apply_fp8_to_nn_module(model: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype) -> None:
+        """Apply FP8 layerwise casting to a plain nn.Module.
+
+        Mirrors diffusers' `apply_layerwise_casting` semantics: only the layer classes in
+        `_FP8_SUPPORTED_PYTORCH_LAYERS` are cast, and modules whose dotted path matches any of
+        `_FP8_DEFAULT_SKIP_PATTERNS` (norm, pos_embed, patch_embed, proj_in/out) are skipped.
+        Without the skip list, precision-sensitive tiny learned scalars (e.g. FLUX RMSNorm.scale)
+        get crushed to FP8 and quality degrades noticeably.
+        """
+        for module_name, module in model.named_modules():
+            if not isinstance(module, _FP8_SUPPORTED_PYTORCH_LAYERS):
+                continue
+            if any(re.search(pattern, module_name) for pattern in _FP8_DEFAULT_SKIP_PATTERNS):
+                continue
+            params = list(module.parameters(recurse=False))
+            if not params:
+                continue
+
+            for param in params:
+                param.data = param.data.to(storage_dtype)
+
+            ModelLoader._wrap_forward_with_fp8_cast(module, storage_dtype, compute_dtype)
+
+    @staticmethod
+    def _wrap_forward_with_fp8_cast(
+        module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype
+    ) -> None:
+        """Register pre/post forward hooks that cast params to compute dtype on entry and back
+        to storage dtype on exit.
+
+        We use hooks (rather than overriding `module.forward`) for two reasons:
+
+        1. **Correct dispatch after `apply_custom_layers_to_model`.** `ModelCache.put()` calls
+           `apply_custom_layers_to_model`, which creates a NEW `CustomLinear` instance and
+           shares the original `Linear.__dict__` (see `wrap_custom_layer`). Anything stored in
+           that dict, including an instance-level `forward` attribute, gets carried over to
+           the new object. An overridden `forward` would close over the OLD instance, so calls
+           to the new `CustomLinear` would silently route to `Linear.forward(old_instance, ...)`
+           and bypass the LoRA-patch-aware branch in `CustomLinear.forward`. Hooks, by contrast,
+           live in `_forward_hooks` / `_forward_pre_hooks` and are dispatched by
+           `nn.Module.__call__` with the *actual* called instance, so they run on the new
+           `CustomLinear` and the class's `forward` is still resolved normally.
+
+        2. **Exception safety.** `register_forward_hook(..., always_call=True)` fires the
+           post-hook even when `forward` raises. The plain pre-hook/post-hook pair without
+           `always_call` would leave params in compute dtype on exception, defeating FP8
+           storage savings and making cache size accounting stale.
+        """
+
+        def pre_hook(mod: torch.nn.Module, _args: object) -> None:
+            for p in mod.parameters(recurse=False):
+                p.data = p.data.to(compute_dtype)
+
+        def post_hook(mod: torch.nn.Module, _args: object, _output: object) -> None:
+            for p in mod.parameters(recurse=False):
+                p.data = p.data.to(storage_dtype)
+
+        module.register_forward_pre_hook(pre_hook)
+        module.register_forward_hook(post_hook, always_call=True)
+
     # This needs to be implemented in the subclass
     def _load_model(
         self,

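The exception-safety argument in the `_wrap_forward_with_fp8_cast` docstring can be illustrated without torch. The sketch below uses `TinyModule`, a hypothetical stand-in that imitates only the relevant part of `nn.Module` call dispatch: pre-hooks, then `forward`, then post-hooks, with `always_call` hooks also running when `forward` raises. Dtypes are plain strings here, so this shows control flow only, not real tensor casting.

```python
from typing import Callable, List, Tuple


class TinyModule:
    """Hypothetical stand-in for nn.Module's hook dispatch (illustration only)."""

    def __init__(self) -> None:
        self.dtype = "fp8"  # storage dtype at rest
        self._pre_hooks: List[Callable] = []
        self._post_hooks: List[Tuple[Callable, bool]] = []

    def register_forward_pre_hook(self, hook: Callable) -> None:
        self._pre_hooks.append(hook)

    def register_forward_hook(self, hook: Callable, always_call: bool = False) -> None:
        self._post_hooks.append((hook, always_call))

    def forward(self, x: float) -> float:
        if x < 0:
            raise ValueError("negative input")
        return x * 2

    def __call__(self, x: float) -> float:
        for hook in self._pre_hooks:
            hook(self)
        try:
            out = self.forward(x)
        except Exception:
            # Only hooks registered with always_call=True run on failure.
            for hook, always_call in self._post_hooks:
                if always_call:
                    hook(self)
            raise
        for hook, _ in self._post_hooks:
            hook(self)
        return out


def attach_casting_hooks(module: TinyModule, always_call: bool) -> None:
    """Cast up before forward, cast back down afterwards."""

    def pre_hook(mod: TinyModule) -> None:
        mod.dtype = "bf16"

    def post_hook(mod: TinyModule) -> None:
        mod.dtype = "fp8"

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook, always_call=always_call)


def dtype_after_failed_forward(always_call: bool) -> str:
    module = TinyModule()
    attach_casting_hooks(module, always_call)
    try:
        module(-1.0)  # forward raises
    except ValueError:
        pass
    return module.dtype


leaked = dtype_after_failed_forward(always_call=False)
restored = dtype_after_failed_forward(always_call=True)
```

Without `always_call`, a failed forward leaves the weights at compute precision, which is the VRAM leak the commit fixes. With it, the module returns to storage precision either way.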
invokeai/backend/model_manager/load/model_cache/cache_record.py

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,11 @@ class CacheRecord:
     # Model in memory.
     cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
     _locks: int = 0
+    # Set by ModelCache.drop_model() when the entry was locked at invalidation time.
+    # ModelCache.unlock() evicts the entry as soon as the last lock releases so a setting
+    # change (e.g. fp8_storage toggled during an in-flight generation) takes effect on the
+    # next load instead of silently being ignored.
+    is_stale: bool = False

     def lock(self) -> None:
         """Lock this record."""

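The `is_stale` flag implements deferred eviction: a locked entry cannot be dropped immediately, so it is marked and removed when its last lock releases. The `ModelCache` side of this is not shown on this page, so the following is only a sketch of the described behavior under assumptions. `Record`, `TinyCache`, and the return value of `drop_model` for locked entries are all hypothetical.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class Record:
    """Hypothetical stand-in for CacheRecord."""

    key: str
    locks: int = 0
    is_stale: bool = False


class TinyCache:
    """Hypothetical stand-in for ModelCache, reduced to lock/unlock/drop."""

    def __init__(self) -> None:
        self._records: Dict[str, Record] = {}

    def put(self, key: str) -> None:
        self._records[key] = Record(key)

    def lock(self, key: str) -> None:
        self._records[key].locks += 1

    def unlock(self, key: str) -> None:
        record = self._records[key]
        record.locks -= 1
        # Evict once the last lock releases if the entry was invalidated meanwhile.
        if record.locks == 0 and record.is_stale:
            del self._records[key]

    def drop_model(self, key: str) -> int:
        """Drop an idle entry now; mark a locked one stale. Returns entries dropped now."""
        record = self._records.get(key)
        if record is None:
            return 0
        if record.locks > 0:
            record.is_stale = True
            return 0  # assumption: deferred drops are not counted here
        del self._records[key]
        return 1

    def __contains__(self, key: str) -> bool:
        return key in self._records


cache = TinyCache()
cache.put("flux")
cache.lock("flux")                      # a generation is using the model
dropped_now = cache.drop_model("flux")  # setting toggled mid-generation
present_while_locked = "flux" in cache
cache.unlock("flux")                    # generation finishes
present_after_unlock = "flux" in cache
```

With two concurrent locks the entry would survive the first `unlock` and be evicted on the second, which matches the multi-lock release case mentioned in the commit message.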