diff --git a/setup.py b/setup.py index e6f7ab788..efaffaa76 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ import shutil import subprocess import sys +from pathlib import Path from setuptools import find_packages, setup @@ -155,6 +156,25 @@ def get_platform_dependencies(): return deps +def _collect_package_files(*directories: str): + """Collect package data files relative to the simpletuner package.""" + collected = [] + package_root = Path("simpletuner") + for directory in directories: + root = Path(directory) + if not root.exists(): + continue + for path in root.rglob("*"): + if path.is_file(): + try: + relative = path.relative_to(package_root) + except ValueError: + # Skip files outside package root + continue + collected.append(str(relative)) + return collected + + # Base dependencies (minimal, works on all platforms) base_deps = [ "diffusers>=0.35.1", @@ -236,6 +256,15 @@ def get_platform_dependencies(): author="bghira", # license handled by pyproject.toml packages=find_packages(), + include_package_data=True, + package_data={ + "simpletuner": _collect_package_files( + "simpletuner/templates", + "simpletuner/static", + "simpletuner/config", + "simpletuner/documentation", + ), + }, python_requires=">=3.11,<3.14", install_requires=base_deps + platform_deps_for_install, extras_require=extras_require, diff --git a/simpletuner/helpers/models/qwen_image/model.py b/simpletuner/helpers/models/qwen_image/model.py index 51ebd0b02..a77af23d6 100644 --- a/simpletuner/helpers/models/qwen_image/model.py +++ b/simpletuner/helpers/models/qwen_image/model.py @@ -214,12 +214,13 @@ def model_predict(self, prepared_batch): return_dict=False, )[0] - # unpack noise prediction - noise_pred = pipeline_class._unpack_latents(noise_pred, pixel_height, pixel_width, self.vae_scale_factor) + # unpack noise prediction if the transformer returned packed tokens + if noise_pred.dim() == 3: + noise_pred = pipeline_class._unpack_latents(noise_pred, pixel_height, pixel_width, self.vae_scale_factor) - # remove extra dimension from _unpack_latents - if noise_pred.dim() == 5: - noise_pred = noise_pred.squeeze(2) # Remove the frame dimension + # remove extra dimension from _unpack_latents + if noise_pred.dim() == 5: + noise_pred = noise_pred.squeeze(2) # Remove the frame dimension return {"model_prediction": noise_pred} diff --git a/simpletuner/helpers/training/quantisation/quanto_workarounds.py b/simpletuner/helpers/training/quantisation/quanto_workarounds.py index ca60b1daa..0c98cb8ef 100644 --- a/simpletuner/helpers/training/quantisation/quanto_workarounds.py +++ b/simpletuner/helpers/training/quantisation/quanto_workarounds.py @@ -4,37 +4,38 @@ if torch.cuda.is_available(): # the marlin fp8 kernel needs some help with dtype casting for some reason # see: https://github.com/huggingface/optimum-quanto/pull/296#issuecomment-2380719201 - from optimum.quanto.library.extensions.cuda import ext as quanto_ext + if torch.device("cuda").type == "cuda" and torch.version.cuda: + from optimum.quanto.library.extensions.cuda import ext as quanto_ext - # Save the original operator - original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin + # Save the original operator + original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin - def fp8_marlin_gemm_wrapper( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, - ) -> torch.Tensor: - # Ensure 'a' has the correct dtype - a = a.to(b_scales.dtype) 
- # Call the original operator - return original_gemm_f16f8_marlin( - a, - b_q_weight, - b_scales, - workspace, - num_bits, - size_m, - size_n, - size_k, - ) + def fp8_marlin_gemm_wrapper( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, + ) -> torch.Tensor: + # Ensure 'a' has the correct dtype + a = a.to(b_scales.dtype) + # Call the original operator + return original_gemm_f16f8_marlin( + a, + b_q_weight, + b_scales, + workspace, + num_bits, + size_m, + size_n, + size_k, + ) - # Monkey-patch the operator - torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper + # Monkey-patch the operator + torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper class TinyGemmQBitsLinearFunction(optimum.quanto.tensor.function.QuantizedLinearFunction): @staticmethod diff --git a/simpletuner/simpletuner_sdk/server/services/system_status_service.py b/simpletuner/simpletuner_sdk/server/services/system_status_service.py index 4c40f5b32..ffcd1d7af 100644 --- a/simpletuner/simpletuner_sdk/server/services/system_status_service.py +++ b/simpletuner/simpletuner_sdk/server/services/system_status_service.py @@ -77,16 +77,19 @@ def _get_gpu_utilisation(self) -> List[Dict[str, Any]]: devices = (inventory or {}).get("devices") or [] results: List[Dict[str, Any]] = [] mac_utilisation: Optional[List[Optional[float]]] = None - nvidia_fallback: Optional[List[Optional[float]]] = None + mac_memory: Optional[List[Optional[float]]] = None + nvidia_fallback: Optional[List[Dict[str, Optional[float]]]] = None rocm_fallback: Optional[List[Optional[float]]] = None if backend == "mps": mac_utilisation = self._get_macos_gpu_utilisation() + mac_memory = self._get_mps_memory_percent() for position, device in enumerate(devices): index = device.get("index") name = device.get("name") or f"GPU {index if index is not None else '?'}" utilisation: Optional[float] = None + memory_percent: Optional[float] = None if backend == "cuda" and index is not None and torch is not None and hasattr(torch.cuda, "utilization"): try: @@ -95,6 +98,7 @@ def _get_gpu_utilisation(self) -> List[Dict[str, Any]]: except Exception: logger.debug("Failed to read CUDA utilisation for device %s", index, exc_info=True) utilisation = None + memory_percent = self._get_cuda_memory_percent(index) if utilisation is None and backend == "mps" and mac_utilisation: target_idx: Optional[int] = None if isinstance(index, int) and 0 <= index < len(mac_utilisation): @@ -103,16 +107,29 @@ def _get_gpu_utilisation(self) -> List[Dict[str, Any]]: target_idx = position if target_idx is not None: utilisation = mac_utilisation[target_idx] + if mac_memory: + mem_idx: Optional[int] = None + if isinstance(index, int) and 0 <= index < len(mac_memory): + mem_idx = index + elif 0 <= position < len(mac_memory): + mem_idx = position + if mem_idx is not None: + memory_percent = mac_memory[mem_idx] if utilisation is None and backend == "cuda": if nvidia_fallback is None: - nvidia_fallback = self._get_nvidia_gpu_utilisation() + nvidia_fallback = self._get_nvidia_gpu_stats() if nvidia_fallback: - target_idx = None + target_idx: Optional[int] = None if isinstance(index, int) and 0 <= index < len(nvidia_fallback): target_idx = index elif 0 <= position < len(nvidia_fallback): target_idx = position if target_idx is not None: + fallback_entry = nvidia_fallback[target_idx] + if utilisation is None: + utilisation = fallback_entry.get("utilization_percent") + if memory_percent is None: + 
memory_percent = fallback_entry.get("memory_percent")
-                        utilisation = nvidia_fallback[target_idx]
             if utilisation is None and backend == "rocm":
                 if rocm_fallback is None:
@@ -132,6 +149,7 @@ def _get_gpu_utilisation(self) -> List[Dict[str, Any]]:
                     "name": name,
                     "backend": backend,
                     "utilization_percent": round(utilisation, 1) if utilisation is not None else None,
+                    "memory_percent": round(memory_percent, 1) if memory_percent is not None else None,
                 }
             )

@@ -442,7 +460,7 @@ def _get_nvidia_gpu_utilisation(self) -> Optional[List[Optional[float]]]:
             completed = subprocess.run(
                 [
                     "nvidia-smi",
-                    "--query-gpu=utilization.gpu",
+                    "--query-gpu=utilization.gpu,memory.used,memory.total",
                     "--format=csv,noheader,nounits",
                 ],
                 check=True,
@@ -458,19 +476,82 @@ def _get_nvidia_gpu_utilisation(self) -> Optional[List[Optional[float]]]:
         if not lines:
             return None

-        utilisation_values: List[Optional[float]] = []
+        stats: List[Dict[str, Optional[float]]] = []
         for line in lines:
             text = line.strip()
             if not text:
-                utilisation_values.append(None)
+                stats.append({"utilization_percent": None, "memory_percent": None})
+                continue
+            parts = [part.strip() for part in text.split(",")]
+            if len(parts) < 3:
+                logger.debug("Discarding unexpected nvidia-smi output: %s", text)
+                stats.append({"utilization_percent": None, "memory_percent": None})
                 continue
+            util_raw, mem_used_raw, mem_total_raw = parts[:3]
+            util_val: Optional[float]
+            mem_percent: Optional[float]
             try:
-                utilisation_values.append(round(float(text), 1))
+                util_val = round(float(util_raw), 1)
             except ValueError:
-                logger.debug("Discarding unexpected nvidia-smi output: %s", text)
-                utilisation_values.append(None)
+                util_val = None
+            try:
+                mem_used = float(mem_used_raw)
+                mem_total = float(mem_total_raw)
+                if mem_total > 0:
+                    mem_percent = round((mem_used / mem_total) * 100.0, 1)
+                else:
+                    mem_percent = None
+            except ValueError:
+                mem_percent = None
+            stats.append({"utilization_percent": util_val, "memory_percent": mem_percent})

-        return utilisation_values or None
+        return stats or None
+
+    def _get_cuda_memory_percent(self, index: int) -> Optional[float]:
+        if torch is None or not torch.cuda.is_available():  # type: ignore[attr-defined]
+            return None
+        if not hasattr(torch.cuda, "mem_get_info"):
+            return None
+        try:
+            try:
+                free_bytes, total_bytes = torch.cuda.mem_get_info(index)  # type: ignore[misc]
+            except TypeError:
+                with torch.cuda.device(index):
+                    free_bytes, total_bytes = torch.cuda.mem_get_info()  # type: ignore[call-arg]
+        except Exception:
+            logger.debug("Failed to read CUDA memory info for device %s", index, exc_info=True)
+            return None
+        if not total_bytes:
+            return None
+        used = total_bytes - free_bytes
+        if used < 0:
+            used = 0
+        try:
+            percent = (used / total_bytes) * 100.0
+        except Exception:
+            return None
+        return round(float(percent), 1)
+
+    def _get_mps_memory_percent(self) -> Optional[List[Optional[float]]]:
+        if torch is None:
+            return None
+        backend = getattr(torch.backends, "mps", None)
+        if backend is None or not backend.is_available():
+            return None
+        driver_alloc = getattr(torch.mps, "driver_allocated_memory", None)
+        driver_total = getattr(torch.mps, "driver_total_memory", None)
+        if not callable(driver_alloc) or not callable(driver_total):
+            return None
+        try:
+            allocated = float(driver_alloc())
+            total = float(driver_total())
+        except Exception:
+            logger.debug("Unable to query MPS memory statistics", exc_info=True)
+            return None
+        if total <= 0:
+            return None
+        percent = round((allocated / total) * 100.0, 1)
+        return [percent]


 __all__ = 
["SystemStatusService"] diff --git a/simpletuner/templates/trainer_htmx.html b/simpletuner/templates/trainer_htmx.html index 6e5ba422e..deaa999f9 100644 --- a/simpletuner/templates/trainer_htmx.html +++ b/simpletuner/templates/trainer_htmx.html @@ -2049,11 +2049,16 @@ typeof gpu?.utilization_percent === 'number' && Number.isFinite(gpu.utilization_percent) ? Number(gpu.utilization_percent) : null; + const memoryPercent = + typeof gpu?.memory_percent === 'number' && Number.isFinite(gpu.memory_percent) + ? Number(gpu.memory_percent) + : null; return { index: idx, name, backend: gpu?.backend || data.backend || null, utilization_percent: utilization, + memory_percent: memoryPercent, }; }) : []; @@ -2129,16 +2134,38 @@ const gpus = Array.isArray(this.systemStatus.gpus) ? this.systemStatus.gpus : []; gpus.forEach((gpu, index) => { const util = gpu.utilization_percent; + const mem = gpu.memory_percent; const trimmedName = typeof gpu.name === 'string' ? gpu.name.trim() : ''; const displayLabel = trimmedName || (typeof gpu.index === 'number' ? `GPU${gpu.index}` : `GPU${index}`); + const valueParts = []; + if (util !== null) { + valueParts.push(`${Math.round(util)}%`); + } + if (mem !== null) { + valueParts.push(`${Math.round(mem)}% mem`); + } + const formattedDisplay = valueParts.length > 0 ? valueParts.join(' · ') : '—'; + const tooltipParts = []; + if (trimmedName) { + tooltipParts.push(trimmedName); + } else { + tooltipParts.push(displayLabel); + } + if (util !== null) { + tooltipParts.push(`Util: ${util.toFixed(1)}%`); + } + if (mem !== null) { + tooltipParts.push(`Mem: ${mem.toFixed(1)}%`); + } items.push({ key: `gpu-${gpu.index ?? index}`, label: displayLabel, value: util, - formatted: util !== null ? `${Math.round(util)}%` : '—', - tooltip: trimmedName || displayLabel, + memory: mem, + formatted: formattedDisplay, + tooltip: tooltipParts.join(' • '), }); }); const offloadInfo = this.systemStatus.deepspeed_offload; diff --git a/tests/test_transformer_integration.py b/tests/test_transformer_integration.py index eb0003f7c..11a760d9f 100644 --- a/tests/test_transformer_integration.py +++ b/tests/test_transformer_integration.py @@ -73,6 +73,7 @@ def test_all_test_files_discovered(self): "hidream", "auraflow", "chroma", + "chroma_controlnet", "cosmos", "sd3", "pixart", @@ -95,7 +96,7 @@ def test_all_test_files_discovered(self): self.assertEqual( len(self.test_files), len(expected_transformers), - f"Expected {len(expected_transformers)} test files, found {len(self.test_files)}", + f"Expected {len(expected_transformers)} test files, found {len(self.test_files)}: {self.test_files}", ) def test_base_test_class_inheritance(self): diff --git a/tests/test_transformers/test_chroma_transformer.py b/tests/test_transformers/test_chroma_transformer.py new file mode 100644 index 000000000..ae8c45521 --- /dev/null +++ b/tests/test_transformers/test_chroma_transformer.py @@ -0,0 +1,129 @@ +""" +Focused unit tests for ChromaTransformer2DModel. + +These tests follow the shared transformer test conventions around: +- Base class inheritance from TransformerBaseTest +- Usage of helper utilities (TensorGenerator, MockDiffusersConfig, etc.) 
+- Standard method naming patterns for instantiation, forward pass, and typo prevention +""" + +import os +import sys +import unittest +from typing import Dict, Optional + +import torch + +# Make shared transformer test utilities importable +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "utils")) + +from transformer_base_test import TransformerBaseTest # noqa: E402 +from transformer_test_helpers import ( # noqa: E402 + MockDiffusersConfig, + ShapeValidator, + TensorGenerator, + TypoTestUtils, +) + +from simpletuner.helpers.models.chroma.transformer import ChromaTransformer2DModel + + +def _build_minimal_config(overrides: Optional[Dict] = None) -> Dict: + """Return a tiny configuration that keeps tests fast.""" + config = { + "patch_size": 1, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 4, + "num_attention_heads": 2, + "joint_attention_dim": 8, + "axes_dims_rope": (2, 2, 2), + "approximator_num_channels": 8, + "approximator_hidden_dim": 16, + "approximator_layers": 1, + } + if overrides: + config.update(overrides) + return config + + +class TestChromaTransformer2DModel(TransformerBaseTest): + """Test suite covering core behaviour of ChromaTransformer2DModel.""" + + def setUp(self): + super().setUp() + self.tensor_gen = TensorGenerator() + self.shape_validator = ShapeValidator() + self.typo_utils = TypoTestUtils() + self.default_config = _build_minimal_config() + + def _build_transformer(self, overrides: Optional[Dict] = None) -> ChromaTransformer2DModel: + return ChromaTransformer2DModel(**_build_minimal_config(overrides)) + + @torch.no_grad() + def test_basic_instantiation(self): + """Ensure the transformer can be constructed and exposes expected config.""" + transformer = self._build_transformer() + + self.assertEqual(transformer.inner_dim, self.default_config["num_attention_heads"] * self.default_config["attention_head_dim"]) + self.assertEqual(transformer.config.patch_size, self.default_config["patch_size"]) + self.assertEqual(transformer.config.num_layers, self.default_config["num_layers"]) + + mock_config = MockDiffusersConfig( + num_attention_heads=transformer.config.num_attention_heads, + attention_head_dim=transformer.config.attention_head_dim, + patch_size=transformer.config.patch_size, + num_layers=transformer.config.num_layers, + ) + self.assertEqual(mock_config.num_layers, self.default_config["num_layers"]) + + self.typo_utils.test_method_name_existence(transformer, ["forward"]) + + @torch.no_grad() + def test_forward_pass_structure(self): + """Run a minimal forward pass and validate shape semantics.""" + transformer = self._build_transformer() + + batch_size = 2 + hidden_states = self.tensor_gen.create_hidden_states(batch_size=batch_size, seq_len=8, hidden_dim=transformer.config.in_channels) + encoder_hidden_states = self.tensor_gen.create_encoder_hidden_states( + batch_size=batch_size, seq_len=3, hidden_dim=self.default_config["joint_attention_dim"] + ) + timestep = self.tensor_gen.create_timestep(batch_size=batch_size) + txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3) + img_ids = torch.zeros(hidden_states.shape[1], 3) + + outputs = transformer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + img_ids=img_ids, + txt_ids=txt_ids, + return_dict=False, + ) + + self.assertIsInstance(outputs, tuple) + self.assertEqual(len(outputs), 1) + sample = outputs[0] + + expected_hidden = hidden_states.shape[1] + expected_channels = 
transformer.config.patch_size * transformer.config.patch_size * transformer.config.out_channels + self.assertEqual(sample.shape, (batch_size, expected_hidden, expected_channels)) + self.shape_validator.validate_transformer_output(sample, batch_size, expected_hidden, expected_channels) + self.assert_no_nan_or_inf(sample) + + def test_typo_prevention_for_constructor(self): + """Ensure common constructor typos raise helpful errors.""" + with self.assertRaises(TypeError): + ChromaTransformer2DModel(num_attention_head=2) # Missing trailing 's' + + invalid_kwargs = self.default_config.copy() + invalid_kwargs["joint_attn_dim"] = invalid_kwargs.pop("joint_attention_dim") + with self.assertRaises(TypeError): + ChromaTransformer2DModel(**invalid_kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/transformer_base_test.py b/tests/utils/transformer_base_test.py index 18380452a..2d13a4bb5 100644 --- a/tests/utils/transformer_base_test.py +++ b/tests/utils/transformer_base_test.py @@ -134,9 +134,33 @@ def assert_tensor_dtype(self, tensor: torch.Tensor, expected_dtype: torch.dtype, self.assertEqual(tensor.dtype, expected_dtype, f"Expected dtype {expected_dtype}, got {tensor.dtype}. {msg}") def assert_tensor_device(self, tensor: torch.Tensor, expected_device: str, msg: str = ""): - """Assert tensor is on expected device.""" - self.assertEqual( - str(tensor.device), expected_device, f"Expected device {expected_device}, got {str(tensor.device)}. {msg}" + """Assert tensor is on expected device. Accepts CUDA/ROCm `cuda` vs `cuda:0` equivalence.""" + + def _parse_device(device_str: str) -> Tuple[str, Optional[int]]: + if ":" not in device_str: + return device_str, None + backend, index_str = device_str.split(":", 1) + try: + return backend, int(index_str) + except ValueError: + return backend, None + + actual_device = str(tensor.device) + expected_backend, expected_index = _parse_device(expected_device) + actual_backend, actual_index = _parse_device(actual_device) + + devices_match = False + if expected_backend == actual_backend: + if expected_backend == "cuda": + exp_idx = 0 if expected_index is None else expected_index + act_idx = 0 if actual_index is None else actual_index + devices_match = exp_idx == act_idx + else: + devices_match = expected_index == actual_index + + self.assertTrue( + devices_match, + f"Expected device {expected_device}, got {actual_device}. {msg}", ) def assert_no_nan_or_inf(self, tensor: torch.Tensor, msg: str = ""):
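Note on the nvidia-smi fallback reworked above: the new memory percentage is derived from the `memory.used` and `memory.total` columns of the CSV output requested with `--query-gpu=utilization.gpu,memory.used,memory.total`. The short sketch below illustrates that parsing approach in isolation; the `query_nvidia_gpu_stats` name and the standalone scaffolding are illustrative assumptions, not code from the patch.

# Illustrative sketch only -- mirrors the parsing strategy of the reworked
# nvidia-smi fallback: query utilisation and memory as CSV, then convert
# memory usage into a percentage per GPU.
import subprocess
from typing import Dict, List, Optional


def query_nvidia_gpu_stats() -> Optional[List[Dict[str, Optional[float]]]]:
    """Return per-GPU utilisation/memory percentages, or None if nvidia-smi is unavailable."""
    try:
        completed = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=utilization.gpu,memory.used,memory.total",
                "--format=csv,noheader,nounits",
            ],
            check=True,
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        return None

    stats: List[Dict[str, Optional[float]]] = []
    for line in completed.stdout.splitlines():
        parts = [part.strip() for part in line.split(",")]
        if len(parts) < 3:
            # Unexpected row format: record placeholders rather than failing.
            stats.append({"utilization_percent": None, "memory_percent": None})
            continue
        util_raw, mem_used_raw, mem_total_raw = parts[:3]
        try:
            util = round(float(util_raw), 1)
        except ValueError:
            util = None
        try:
            used, total = float(mem_used_raw), float(mem_total_raw)
            mem = round(used / total * 100.0, 1) if total > 0 else None
        except ValueError:
            mem = None
        stats.append({"utilization_percent": util, "memory_percent": mem})
    return stats or None


if __name__ == "__main__":
    print(query_nvidia_gpu_stats())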