29 changes: 29 additions & 0 deletions setup.py
@@ -7,6 +7,7 @@
import shutil
import subprocess
import sys
from pathlib import Path

from setuptools import find_packages, setup

@@ -155,6 +156,25 @@ def get_platform_dependencies():
return deps


def _collect_package_files(*directories: str):
"""Collect package data files relative to the simpletuner package."""
collected = []
package_root = Path("simpletuner")
for directory in directories:
root = Path(directory)
if not root.exists():
continue
for path in root.rglob("*"):
if path.is_file():
try:
relative = path.relative_to(package_root)
except ValueError:
# Skip files outside package root
continue
collected.append(str(relative))
return collected


# Base dependencies (minimal, works on all platforms)
base_deps = [
"diffusers>=0.35.1",
@@ -236,6 +256,15 @@ def get_platform_dependencies():
author="bghira",
# license handled by pyproject.toml
packages=find_packages(),
include_package_data=True,
package_data={
"simpletuner": _collect_package_files(
"simpletuner/templates",
"simpletuner/static",
"simpletuner/config",
"simpletuner/documentation",
),
},
python_requires=">=3.11,<3.14",
install_requires=base_deps + platform_deps_for_install,
extras_require=extras_require,
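
The `_collect_package_files` helper above returns paths relative to the `simpletuner` package root, which is the form `package_data` expects, so the templates, static assets, config, and documentation trees are shipped inside the wheel. A minimal standalone sketch of the same collection logic, plus one assumed way to reach the shipped files at runtime through `importlib.resources` (the directory and file names are illustrative, and the runtime lookup presumes the package is installed):

```python
# Sketch only: mirrors the collection logic from setup.py and shows one way the
# shipped files could be located after installation. Names are illustrative.
from importlib import resources
from pathlib import Path


def collect_package_files(*directories: str) -> list[str]:
    """Return file paths relative to the simpletuner package root."""
    collected: list[str] = []
    package_root = Path("simpletuner")
    for directory in directories:
        root = Path(directory)
        if not root.exists():  # quietly skip optional directories
            continue
        for path in root.rglob("*"):
            if not path.is_file():
                continue
            try:
                collected.append(str(path.relative_to(package_root)))
            except ValueError:
                continue  # file lives outside simpletuner/, so it is not packaged
    return collected


if __name__ == "__main__":
    # Run from the repository root this prints entries such as
    # 'templates/trainer_htmx.html' (illustrative).
    print(collect_package_files("simpletuner/templates", "simpletuner/static"))

    # After installation the same files are importable package data.
    try:
        template = resources.files("simpletuner") / "templates" / "trainer_htmx.html"
        print(template.is_file())
    except ModuleNotFoundError:
        pass  # simpletuner is not installed in this environment
```
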
11 changes: 6 additions & 5 deletions simpletuner/helpers/models/qwen_image/model.py
@@ -214,12 +214,13 @@ def model_predict(self, prepared_batch):
return_dict=False,
)[0]

-        # unpack noise prediction
-        noise_pred = pipeline_class._unpack_latents(noise_pred, pixel_height, pixel_width, self.vae_scale_factor)
+        # unpack noise prediction if the transformer returned packed tokens
+        if noise_pred.dim() == 3:
+            noise_pred = pipeline_class._unpack_latents(noise_pred, pixel_height, pixel_width, self.vae_scale_factor)

        # remove extra dimension from _unpack_latents
        if noise_pred.dim() == 5:
            noise_pred = noise_pred.squeeze(2)  # Remove the frame dimension

return {"model_prediction": noise_pred}

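
The guard above only unpacks when the transformer hands back packed 3-D tokens, and the existing squeeze still covers the 5-D output that `_unpack_latents` produces. A shape-only sketch of those dimension checks follows; the tensor sizes and the stand-in reshape are illustrative and make no claim about Qwen-Image's actual latent layout.

```python
# Shape-only sketch of the dim() guards; the reshape is a stand-in for
# pipeline_class._unpack_latents and is not the real unpacking maths.
import torch


def unpack_if_needed(noise_pred: torch.Tensor) -> torch.Tensor:
    if noise_pred.dim() == 3:
        # packed tokens (batch, sequence, channels) -> fake-unpack to 5-D
        b, s, c = noise_pred.shape
        h = w = int(s ** 0.5)
        noise_pred = noise_pred.reshape(b, c, 1, h, w)
    if noise_pred.dim() == 5:
        noise_pred = noise_pred.squeeze(2)  # drop the singleton frame dimension
    return noise_pred


packed = torch.randn(2, 16, 64)               # packed sequence output
already_spatial = torch.randn(2, 64, 32, 32)  # already unpacked, left untouched
print(unpack_if_needed(packed).shape)           # torch.Size([2, 64, 4, 4])
print(unpack_if_needed(already_spatial).shape)  # torch.Size([2, 64, 32, 32])
```
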
57 changes: 29 additions & 28 deletions simpletuner/helpers/training/quantisation/quanto_workarounds.py
@@ -4,37 +4,38 @@
if torch.cuda.is_available():
# the marlin fp8 kernel needs some help with dtype casting for some reason
# see: https://github.com/huggingface/optimum-quanto/pull/296#issuecomment-2380719201
-    from optimum.quanto.library.extensions.cuda import ext as quanto_ext
+    if torch.device("cuda").type == "cuda" and torch.version.cuda:
+        from optimum.quanto.library.extensions.cuda import ext as quanto_ext

-    # Save the original operator
-    original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin
+        # Save the original operator
+        original_gemm_f16f8_marlin = torch.ops.quanto.gemm_f16f8_marlin

-    def fp8_marlin_gemm_wrapper(
-        a: torch.Tensor,
-        b_q_weight: torch.Tensor,
-        b_scales: torch.Tensor,
-        workspace: torch.Tensor,
-        num_bits: int,
-        size_m: int,
-        size_n: int,
-        size_k: int,
-    ) -> torch.Tensor:
-        # Ensure 'a' has the correct dtype
-        a = a.to(b_scales.dtype)
-        # Call the original operator
-        return original_gemm_f16f8_marlin(
-            a,
-            b_q_weight,
-            b_scales,
-            workspace,
-            num_bits,
-            size_m,
-            size_n,
-            size_k,
-        )
+        def fp8_marlin_gemm_wrapper(
+            a: torch.Tensor,
+            b_q_weight: torch.Tensor,
+            b_scales: torch.Tensor,
+            workspace: torch.Tensor,
+            num_bits: int,
+            size_m: int,
+            size_n: int,
+            size_k: int,
+        ) -> torch.Tensor:
+            # Ensure 'a' has the correct dtype
+            a = a.to(b_scales.dtype)
+            # Call the original operator
+            return original_gemm_f16f8_marlin(
+                a,
+                b_q_weight,
+                b_scales,
+                workspace,
+                num_bits,
+                size_m,
+                size_n,
+                size_k,
+            )

-    # Monkey-patch the operator
-    torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper
+        # Monkey-patch the operator
+        torch.ops.quanto.gemm_f16f8_marlin = fp8_marlin_gemm_wrapper

class TinyGemmQBitsLinearFunction(optimum.quanto.tensor.function.QuantizedLinearFunction):
@staticmethod
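
The net effect of this hunk is that the CUDA extension import and the marlin monkey-patch now run only on CUDA builds of torch; `torch.version.cuda` is unset on CPU-only builds (and, as far as I can tell, on ROCm builds even when `torch.cuda.is_available()` is true), so those builds skip the workaround entirely. A generic sketch of the same guard-and-wrap pattern, with a toy matmul standing in for `torch.ops.quanto.gemm_f16f8_marlin`:

```python
# Sketch of the guard-and-wrap pattern only; original_gemm is a toy stand-in
# for the real quanto marlin kernel, not its actual signature.
import torch

if torch.cuda.is_available() and torch.version.cuda:  # skip CPU-only / non-CUDA builds
    original_gemm = torch.matmul  # pretend this is the fused fp8 kernel

    def gemm_with_cast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Cast the activation to the weight dtype before dispatching, mirroring
        # the dtype workaround applied to the marlin fp8 kernel above.
        return original_gemm(a.to(b.dtype), b)

    # "Monkey-patch": later callers pick up the wrapper instead of the original.
    gemm = gemm_with_cast
    a = torch.ones(2, 3, device="cuda", dtype=torch.float32)
    b = torch.ones(3, 4, device="cuda", dtype=torch.float16)
    print(gemm(a, b).dtype)  # torch.float16, because the activation was cast to match
```
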
101 changes: 91 additions & 10 deletions simpletuner/simpletuner_sdk/server/services/system_status_service.py
@@ -77,16 +77,19 @@ def _get_gpu_utilisation(self) -> List[Dict[str, Any]]:
devices = (inventory or {}).get("devices") or []
results: List[Dict[str, Any]] = []
mac_utilisation: Optional[List[Optional[float]]] = None
-        nvidia_fallback: Optional[List[Optional[float]]] = None
+        mac_memory: Optional[List[Optional[float]]] = None
+        nvidia_fallback: Optional[List[Dict[str, Optional[float]]]] = None
rocm_fallback: Optional[List[Optional[float]]] = None

if backend == "mps":
mac_utilisation = self._get_macos_gpu_utilisation()
mac_memory = self._get_mps_memory_percent()

for position, device in enumerate(devices):
index = device.get("index")
name = device.get("name") or f"GPU {index if index is not None else '?'}"
utilisation: Optional[float] = None
memory_percent: Optional[float] = None

if backend == "cuda" and index is not None and torch is not None and hasattr(torch.cuda, "utilization"):
try:
@@ -95,6 +98,7 @@ def _get_gpu_utilisation(self) -> List[Dict[str, Any]]:
except Exception:
logger.debug("Failed to read CUDA utilisation for device %s", index, exc_info=True)
utilisation = None
memory_percent = self._get_cuda_memory_percent(index)
if utilisation is None and backend == "mps" and mac_utilisation:
target_idx: Optional[int] = None
if isinstance(index, int) and 0 <= index < len(mac_utilisation):
Expand All @@ -103,16 +107,29 @@ def _get_gpu_utilisation(self) -> List[Dict[str, Any]]:
target_idx = position
if target_idx is not None:
utilisation = mac_utilisation[target_idx]
if mac_memory:
mem_idx: Optional[int] = None
if isinstance(index, int) and 0 <= index < len(mac_memory):
mem_idx = index
elif 0 <= position < len(mac_memory):
mem_idx = position
if mem_idx is not None:
memory_percent = mac_memory[mem_idx]
if utilisation is None and backend == "cuda":
if nvidia_fallback is None:
-                    nvidia_fallback = self._get_nvidia_gpu_utilisation()
+                    nvidia_fallback = self._get_nvidia_gpu_stats()
if nvidia_fallback:
-                    target_idx = None
+                    target_idx: Optional[int] = None
if isinstance(index, int) and 0 <= index < len(nvidia_fallback):
target_idx = index
elif 0 <= position < len(nvidia_fallback):
target_idx = position
if target_idx is not None:
-                        utilisation = nvidia_fallback[target_idx]
+                        fallback_entry = nvidia_fallback[target_idx]
+                        if utilisation is None:
+                            utilisation = fallback_entry.get("utilization_percent")
+                        if memory_percent is None:
+                            memory_percent = fallback_entry.get("memory_percent")
if utilisation is None and backend == "rocm":
if rocm_fallback is None:
@@ -132,6 +149,7 @@ def _get_gpu_utilisation(self) -> List[Dict[str, Any]]:
"name": name,
"backend": backend,
"utilization_percent": round(utilisation, 1) if utilisation is not None else None,
"memory_percent": round(memory_percent, 1) if memory_percent is not None else None,
}
)

@@ -442,7 +460,7 @@ def _get_nvidia_gpu_utilisation(self) -> Optional[List[Optional[float]]]:
completed = subprocess.run(
[
"nvidia-smi",
"--query-gpu=utilization.gpu",
"--query-gpu=utilization.gpu,memory.used,memory.total",
"--format=csv,noheader,nounits",
],
check=True,
@@ -458,19 +476,82 @@
if not lines:
return None

-        utilisation_values: List[Optional[float]] = []
+        stats: List[Dict[str, Optional[float]]] = []
for line in lines:
text = line.strip()
if not text:
-                utilisation_values.append(None)
+                stats.append({"utilization_percent": None, "memory_percent": None})
continue
parts = [part.strip() for part in text.split(",")]
if len(parts) < 3:
logger.debug("Discarding unexpected nvidia-smi output: %s", text)
stats.append({"utilization_percent": None, "memory_percent": None})
continue
util_raw, mem_used_raw, mem_total_raw = parts[:3]
util_val: Optional[float]
mem_percent: Optional[float]
try:
-                utilisation_values.append(round(float(text), 1))
+                util_val = round(float(util_raw), 1)
            except ValueError:
-                logger.debug("Discarding unexpected nvidia-smi output: %s", text)
-                utilisation_values.append(None)
+                util_val = None
try:
mem_used = float(mem_used_raw)
mem_total = float(mem_total_raw)
if mem_total > 0:
mem_percent = round((mem_used / mem_total) * 100.0, 1)
else:
mem_percent = None
except ValueError:
mem_percent = None
stats.append({"utilization_percent": util_val, "memory_percent": mem_percent})

-        return utilisation_values or None
+        return stats or None

def _get_cuda_memory_percent(self, index: int) -> Optional[float]:
if torch is None or not torch.cuda.is_available(): # type: ignore[attr-defined]
return None
if not hasattr(torch.cuda, "mem_get_info"):
return None
try:
try:
free_bytes, total_bytes = torch.cuda.mem_get_info(index) # type: ignore[misc]
except TypeError:
with torch.cuda.device(index):
free_bytes, total_bytes = torch.cuda.mem_get_info() # type: ignore[call-arg]
except Exception:
logger.debug("Failed to read CUDA memory info for device %s", index, exc_info=True)
return None
if not total_bytes:
return None
used = total_bytes - free_bytes
if used < 0:
used = 0
try:
percent = (used / total_bytes) * 100.0
except Exception:
return None
return round(float(percent), 1)

def _get_mps_memory_percent(self) -> Optional[List[Optional[float]]]:
if torch is None:
return None
backend = getattr(torch.backends, "mps", None)
if backend is None or not backend.is_available():
return None
driver_alloc = getattr(torch.mps, "driver_allocated_memory", None)
driver_total = getattr(torch.mps, "driver_total_memory", None)
if not callable(driver_alloc) or not callable(driver_total):
return None
try:
allocated = float(driver_alloc())
total = float(driver_total())
except Exception:
logger.debug("Unable to query MPS memory statistics", exc_info=True)
return None
if total <= 0:
return None
percent = round((allocated / total) * 100.0, 1)
return [percent]


__all__ = ["SystemStatusService"]
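
With the widened `--query-gpu=utilization.gpu,memory.used,memory.total` query, every nvidia-smi line carries three comma-separated values, and the service folds them into one dict per GPU. A standalone sketch of that parsing against hard-coded sample output (the sample numbers are made up):

```python
# Sketch of the nvidia-smi CSV parsing used by _get_nvidia_gpu_stats; the
# SAMPLE string is fabricated output, field order matching
# "--query-gpu=utilization.gpu,memory.used,memory.total --format=csv,noheader,nounits".
from typing import Dict, List, Optional

SAMPLE = """\
87, 20480, 24576
3, 1024, 24576
"""


def parse_gpu_stats(raw: str) -> List[Dict[str, Optional[float]]]:
    stats: List[Dict[str, Optional[float]]] = []
    for line in raw.splitlines():
        text = line.strip()
        if not text:
            continue
        parts = [part.strip() for part in text.split(",")]
        if len(parts) < 3:
            # malformed row: keep the slot so indices still line up with GPU order
            stats.append({"utilization_percent": None, "memory_percent": None})
            continue
        util_raw, mem_used_raw, mem_total_raw = parts[:3]
        try:
            util = round(float(util_raw), 1)
        except ValueError:
            util = None
        try:
            used, total = float(mem_used_raw), float(mem_total_raw)
            mem = round(used / total * 100.0, 1) if total > 0 else None
        except ValueError:
            mem = None
        stats.append({"utilization_percent": util, "memory_percent": mem})
    return stats


print(parse_gpu_stats(SAMPLE))
# [{'utilization_percent': 87.0, 'memory_percent': 83.3},
#  {'utilization_percent': 3.0, 'memory_percent': 4.2}]
```
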
31 changes: 29 additions & 2 deletions simpletuner/templates/trainer_htmx.html
@@ -2049,11 +2049,16 @@
typeof gpu?.utilization_percent === 'number' && Number.isFinite(gpu.utilization_percent)
? Number(gpu.utilization_percent)
: null;
const memoryPercent =
typeof gpu?.memory_percent === 'number' && Number.isFinite(gpu.memory_percent)
? Number(gpu.memory_percent)
: null;
return {
index: idx,
name,
backend: gpu?.backend || data.backend || null,
utilization_percent: utilization,
memory_percent: memoryPercent,
};
})
: [];
@@ -2129,16 +2134,38 @@
const gpus = Array.isArray(this.systemStatus.gpus) ? this.systemStatus.gpus : [];
gpus.forEach((gpu, index) => {
const util = gpu.utilization_percent;
const mem = gpu.memory_percent;
const trimmedName = typeof gpu.name === 'string' ? gpu.name.trim() : '';
const displayLabel =
trimmedName ||
(typeof gpu.index === 'number' ? `GPU${gpu.index}` : `GPU${index}`);
const valueParts = [];
if (util !== null) {
valueParts.push(`${Math.round(util)}%`);
}
if (mem !== null) {
valueParts.push(`${Math.round(mem)}% mem`);
}
const formattedDisplay = valueParts.length > 0 ? valueParts.join(' · ') : '—';
const tooltipParts = [];
if (trimmedName) {
tooltipParts.push(trimmedName);
} else {
tooltipParts.push(displayLabel);
}
if (util !== null) {
tooltipParts.push(`Util: ${util.toFixed(1)}%`);
}
if (mem !== null) {
tooltipParts.push(`Mem: ${mem.toFixed(1)}%`);
}
items.push({
key: `gpu-${gpu.index ?? index}`,
label: displayLabel,
value: util,
-                            formatted: util !== null ? `${Math.round(util)}%` : '—',
-                            tooltip: trimmedName || displayLabel,
+                            memory: mem,
+                            formatted: formattedDisplay,
+                            tooltip: tooltipParts.join(' • '),
});
});
const offloadInfo = this.systemStatus.deepspeed_offload;
3 changes: 2 additions & 1 deletion tests/test_transformer_integration.py
@@ -73,6 +73,7 @@ def test_all_test_files_discovered(self):
"hidream",
"auraflow",
"chroma",
"chroma_controlnet",
"cosmos",
"sd3",
"pixart",
@@ -95,7 +96,7 @@ def test_all_test_files_discovered(self):
self.assertEqual(
len(self.test_files),
len(expected_transformers),
f"Expected {len(expected_transformers)} test files, found {len(self.test_files)}",
f"Expected {len(expected_transformers)} test files, found {len(self.test_files)}: {self.test_files}",
)

def test_base_test_class_inheritance(self):