Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions invokeai/backend/model_manager/model_on_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict:

path = self.resolve_weight_file(path)

if path in self._state_dict_cache:
return self._state_dict_cache[path]

with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
Expand Down
54 changes: 42 additions & 12 deletions invokeai/backend/quantization/gguf/loaders.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,52 @@
import gc
from pathlib import Path

import gguf
import torch

from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES
from invokeai.backend.util.logging import InvokeAILogger

logger = InvokeAILogger.get_logger()


class WrappedGGUFReader:
"""Wrapper around GGUFReader that adds a close() method."""

def __init__(self, path: Path):
self.reader = gguf.GGUFReader(path)

def __enter__(self):
return self.reader

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False

def close(self):
"""Explicitly close the memory-mapped file."""
if hasattr(self.reader, "data") and hasattr(self.reader.data, "_mmap"):
try:
self.reader.data._mmap.close()
except (AttributeError, OSError, ValueError) as e:
logger.warning(f"Wasn't able to close GGUF memory map: {e}")
del self.reader
gc.collect()


def gguf_sd_loader(path: Path, compute_dtype: torch.dtype) -> dict[str, GGMLTensor]:
reader = gguf.GGUFReader(path)

sd: dict[str, GGMLTensor] = {}
for tensor in reader.tensors:
torch_tensor = torch.from_numpy(tensor.data)
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
torch_tensor = torch_tensor.view(*shape)
sd[tensor.name] = GGMLTensor(
torch_tensor, ggml_quantization_type=tensor.tensor_type, tensor_shape=shape, compute_dtype=compute_dtype
)
return sd
with WrappedGGUFReader(path) as reader:
sd: dict[str, GGMLTensor] = {}
for tensor in reader.tensors:
torch_tensor = torch.from_numpy(tensor.data)
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
torch_tensor = torch_tensor.view(*shape)
sd[tensor.name] = GGMLTensor(
torch_tensor,
ggml_quantization_type=tensor.tensor_type,
tensor_shape=shape,
compute_dtype=compute_dtype,
)
return sd