14 changes: 14 additions & 0 deletions tests/test_bucket_table.py
@@ -0,0 +1,14 @@
# tests/test_bucket_table.py
from diffusers_helper.bucket_tools import bucket_options
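# Note (assumption for readers): bucket_options is expected to map a base resolution
# (e.g. 640) to a list of (height, width) pairs; the exact table lives in
# diffusers_helper/bucket_tools.py.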


def test_bucket_options_nonempty():
"""bucket_options should have integer keys and non-empty lists of (h, w) tuples."""
assert isinstance(bucket_options, dict)
assert all(isinstance(k, int) for k in bucket_options)
assert any(bucket_options.values()) # at least one non-empty list
for key, lst in bucket_options.items():
for h, w in lst:
assert isinstance(h, int)
assert isinstance(w, int)
assert h > 0 and w > 0
15 changes: 15 additions & 0 deletions tests/test_bucket_tools.py
@@ -0,0 +1,15 @@
from diffusers_helper.bucket_tools import find_nearest_bucket, bucket_options


def test_returns_known_bucket_and_shape():
"""If input matches one of the bucket options exactly, it should return it."""
h, w = 480, 832
bh, bw = find_nearest_bucket(h, w, resolution=640)
assert (bh, bw) in bucket_options[640]
assert (bh, bw) == (480, 832)


def test_picks_nearest_bucket():
    """If input does not match exactly, snap to the nearest listed bucket (closest aspect ratio)."""
    bh, bw = find_nearest_bucket(500, 800, resolution=640)
    assert (bh, bw) == (512, 768)
31 changes: 31 additions & 0 deletions tests/test_clip_vision.py
@@ -0,0 +1,31 @@
import numpy as np
from diffusers_helper.clip_vision import hf_clip_vision_encode
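# The dummies below assume hf_clip_vision_encode roughly does:
#   batch = feature_extractor.preprocess(images=image, return_tensors="pt")
#   batch = batch.to(device=image_encoder.device, dtype=image_encoder.dtype)
#   return image_encoder(**batch)
# (a sketch of the expected call sequence, not a claim about the exact implementation).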


class _DummyBatch(dict):
    def to(self, **kwargs):
        return self


class _DummyExtractor:
    def preprocess(self, images, return_tensors="pt"):
        assert return_tensors == "pt"
        assert isinstance(images, np.ndarray)
        return _DummyBatch({"pixel_values": "ok"})


class _DummyEncoder:
device = "cpu"
dtype = "float32"

def __call__(self, **kwargs):
# Return a dummy output to simulate a successful encode
return {"last_hidden_state": "dummy"}


def test_hf_clip_vision_encode_smoke():
"""Smoke test: hf_clip_vision_encode should call extractor and encoder successfully."""
img = np.zeros((8, 8, 3), dtype=np.uint8)
out = hf_clip_vision_encode(img, _DummyExtractor(), _DummyEncoder())
assert isinstance(out, dict)
assert "last_hidden_state" in out
40 changes: 40 additions & 0 deletions tests/test_hf_login.py
@@ -0,0 +1,40 @@
# tests/test_hf_login.py
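# Assumption encoded by this test: hf_login.login(token) keeps calling
# huggingface_hub.login(token) in a retry loop, sleeping between failed attempts,
# and prints "HF login ok." once a call succeeds; importing the module only
# auto-logs-in when the HF_TOKEN environment variable is set.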
import sys
import types


def test_login_retries_then_succeeds(monkeypatch, capsys):
"""login() should retry on failure and eventually succeed (no autologin on import)."""
calls = {"n": 0}

def fake_login(_token):
calls["n"] += 1
if calls["n"] < 3:
raise RuntimeError("temporary HF error")
# success on third attempt

# Ensure no autologin during import
monkeypatch.delenv("HF_TOKEN", raising=False)

# Make sure we import a fresh module
if "diffusers_helper.hf_login" in importlib.sys.modules:
del importlib.sys.modules["diffusers_helper.hf_login"]

# Provide fake huggingface_hub before import
monkeypatch.setitem(
importlib.import_module("sys").modules,
"huggingface_hub",
types.SimpleNamespace(login=fake_login),
)

# Avoid real sleeps
monkeypatch.setattr("time.sleep", lambda *_a, **_k: None)

# Import module (no autologin because HF_TOKEN is unset)
from diffusers_helper import hf_login

# Now call the function explicitly (this will do the retries)
hf_login.login("abc")
captured = capsys.readouterr().out
assert calls["n"] == 3
assert "HF login ok." in captured
95 changes: 95 additions & 0 deletions tests/test_memory_dynamic_swap.py
@@ -0,0 +1,95 @@
import importlib
import sys
import types
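# diffusers_helper.memory imports torch (and appears to touch torch.cuda) at import time,
# so the helper below stubs out just enough of the torch API to import the module on a
# machine without a GPU, or even without torch installed.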


def _import_memory_with_fake_torch():
"""Import diffusers_helper.memory with a minimal fake torch API."""
fake_torch = types.ModuleType("torch")

# torch.device -> echo string/device id
fake_torch.device = lambda x: str(x)

# Minimal Parameter
class Parameter:
def __init__(self, data=None, requires_grad=False):
self.data = data
self.requires_grad = requires_grad

def to(self, **kwargs):
return self.data

# Minimal Module
class Module:
def __init__(self):
self._parameters = {}

def modules(self):
return []

# torch.nn namespace
class NN:
pass

NN.Parameter = Parameter
NN.Module = Module
fake_torch.nn = NN()

# torch.cuda stub
class _CUDA:
@staticmethod
def current_device():
return 0

@staticmethod
def memory_stats(device):
return {"active_bytes.all.current": 0, "reserved_bytes.all.current": 0}

@staticmethod
def mem_get_info(device):
return (0, 0)

@staticmethod
def empty_cache():
return None

fake_torch.cuda = _CUDA()

sys.modules["torch"] = fake_torch

if "diffusers_helper.memory" in sys.modules:
del sys.modules["diffusers_helper.memory"]
return importlib.import_module("diffusers_helper.memory")
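# What the test below assumes about DynamicSwapInstaller: install_model() swaps each
# submodule's class for a dynamically created subclass whose __getattr__ returns
# parameters cast to the requested dtype/device on access, stashing the original class
# under "forge_backup_original_class"; uninstall_model() restores the original class and
# removes that backup entry.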


def test_dynamic_swap_install_and_uninstall():
"""Installing should wrap __getattr__ to cast parameters; uninstall should restore the class."""
memory = _import_memory_with_fake_torch()

# A tiny fake "module" with _parameters and a .to() method
class Layer:
def __init__(self):
self._parameters = {
"weight": memory.torch.nn.Parameter(data="W", requires_grad=True)
}

def to(self, **kwargs):
self.last_to = kwargs
return self

# modules() is used by DynamicSwapInstaller.install_model
def modules(self):
yield self

layer = Layer()
original_cls = layer.__class__

# Install swapper and access a parameter to trigger __getattr__
memory.DynamicSwapInstaller.install_model(layer, dtype="float16")
assert layer.__class__ is not original_cls # class is replaced

wrapped = layer.weight # __getattr__ should create a new Parameter with casted data
assert isinstance(wrapped, memory.torch.nn.Parameter)
assert wrapped.requires_grad is True
# After uninstall, class should be restored
memory.DynamicSwapInstaller.uninstall_model(layer)
assert layer.__class__ is original_cls
assert "forge_backup_original_class" not in layer.__dict__