Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions simpletuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@
UnsupportedFieldAttributeWarning = None

if UnsupportedFieldAttributeWarning is not None:
warnings.filterwarnings(
"ignore",
message=r".*'repr' attribute.*",
category=UnsupportedFieldAttributeWarning,
)
warnings.filterwarnings(
"ignore",
message=r".*'frozen' attribute.*",
category=UnsupportedFieldAttributeWarning,
)
warnings.filterwarnings("ignore", category=UnsupportedFieldAttributeWarning)

# Filter out websockets deprecation warning about ws_handler second argument
Expand All @@ -34,15 +44,37 @@
def _suppress_swigvarlink(message, *args, **kwargs):
text = str(message)
category = kwargs.get("category", DeprecationWarning)
if not kwargs and args:
category = args[0]

# Suppress all UnsupportedFieldAttributeWarning instances
if UnsupportedFieldAttributeWarning is not None:
try:
if issubclass(category, UnsupportedFieldAttributeWarning):
return None
except TypeError:
if category is UnsupportedFieldAttributeWarning:
return None
# Also check if the message is an instance of UnsupportedFieldAttributeWarning
if isinstance(message, UnsupportedFieldAttributeWarning):
return None
# Check the message type directly
if type(message).__name__ == "UnsupportedFieldAttributeWarning":
return None

# Check message text for pydantic field warnings
if "attribute" in text and "Field()" in text:
return None

if "swigvarlink" in text and category is DeprecationWarning:
return None
if "MPS autocast" in text:
return None
if "UnsupportedFieldAttributeWarning" in text:
return None

return _original_warn(message, *args, **kwargs)


warnings.warn = _suppress_swigvarlink


__version__ = "3.0.2"
2 changes: 1 addition & 1 deletion simpletuner/helpers/configuration/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def parse_cmdline_args(input_args=None, exit_on_error: bool = False):
logger.error(f"Could not load controlnet_custom_config: {e}")
raise
if args.webhook_config is not None:
print(f"DEBUG: webhook_config at start = {args.webhook_config} (type: {type(args.webhook_config)})")
logger.debug("webhook_config at start = %s (type: %s)", args.webhook_config, type(args.webhook_config))
# Handle different types of webhook_config
# First, check if it's an AST object using isinstance (the proper way)
if isinstance(args.webhook_config, (ast.AST, ast.Name, ast.Call, ast.Dict, ast.List, ast.Constant)):
Expand Down
17 changes: 13 additions & 4 deletions simpletuner/helpers/data_backend/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ def read_image(self, filepath: str, delete_problematic_images: bool = False) ->
try:
image = file_loader(filepath)
return image
except FileNotFoundError:
log_level = logging.WARNING if should_log() else logging.DEBUG
logger.log(log_level, f"Image not found: {filepath}")
return None
except Exception as e:
logger.error(f"Encountered error opening image {filepath}: {e}", exc_info=True)
if delete_problematic_images:
Expand All @@ -225,7 +229,8 @@ def read_image_batch(self, filepaths: List[str], delete_problematic_images: bool
try:
image_data = self.read_image(filepath, delete_problematic_images)
if image_data is None:
logger.warning(f"Unable to load image '{filepath}', skipping.")
log_level = logging.WARNING if should_log() else logging.DEBUG
logger.log(log_level, f"Unable to load image '{filepath}', skipping.")
continue
output_images.append(image_data)
available_keys.append(filepath)
Expand All @@ -237,9 +242,13 @@ def read_image_batch(self, filepaths: List[str], delete_problematic_images: bool
except Exception as del_e:
logger.error(f"Failed to delete problematic image {filepath}: {del_e}")
else:
logger.warning(
f"A problematic image {filepath} is detected, but we are not allowed to remove it, because --delete_problematic_images is not provided."
f" Please correct this manually. Error: {e}"
log_level = logging.WARNING if should_log() else logging.DEBUG
logger.log(
log_level,
(
f"A problematic image {filepath} is detected, but we are not allowed to remove it, because "
f"--delete_problematic_images is not provided. Please correct this manually. Error: {e}"
),
)
return available_keys, output_images

Expand Down
8 changes: 7 additions & 1 deletion simpletuner/helpers/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,13 +436,19 @@ def get_logger(
Retrieve a WebhookLogger with consistent configuration across the project.
"""
logger = logging.getLogger(name)
propagate_value = propagate if propagate is not None else False
if isinstance(logger, WebhookLogger):
logger.configure(
env_var=env_var,
default_level=default_level,
disable_webhook=disable_webhook,
propagate=propagate,
propagate=propagate_value,
)
else:
if propagate_value is not None:
logger.propagate = propagate_value
if not logger.handlers:
logger.addHandler(logging.NullHandler())
return logger # type: ignore[return-value]


Expand Down
17 changes: 13 additions & 4 deletions simpletuner/helpers/training/optimizer_param.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import logging
import os
from functools import lru_cache

import accelerate
import torch
from accelerate.logging import get_logger

logger = get_logger(__name__)
logger = logging.getLogger(__name__)
from simpletuner.helpers.training.multi_process import should_log

if should_log():
Expand Down Expand Up @@ -52,7 +52,12 @@
is_bitsandbytes_available = True
except:
if torch.cuda.is_available():
print("Could not load bitsandbytes library. BnB-specific optimisers and other functionality will be unavailable.")
# Avoid requiring Accelerate logging state during import
level = logging.WARNING if should_log() else logging.DEBUG
logging.getLogger(__name__).log(
level,
"Could not load bitsandbytes library. BnB-specific optimisers and other functionality will be unavailable.",
)

# Some optimizers are not available in multibackend bitsandbytes as of January 2025.
is_ademamix_available = False
Expand All @@ -67,7 +72,11 @@
is_prodigy_available = True
except:
if torch.cuda.is_available():
print("Could not load prodigyplus library. Prodigy will not be available.")
log_level = logging.WARNING if should_log() else logging.DEBUG
logger.log(
log_level,
"Could not load prodigyplus library. Prodigy will not be available.",
)


optimizer_choices = {
Expand Down
12 changes: 8 additions & 4 deletions simpletuner/simpletuner_sdk/process_keeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,15 @@ def __init__(self, config):
send_event("state", state_event)
raise
except Exception as e:
logger.error(f"Function error: {{e}}")
import traceback
traceback_str = traceback.format_exc()
logger.error(traceback_str)
error_message = str(e)
log_level = logging.ERROR
if "Training configuration could not be parsed" in error_message:
log_level = logging.INFO
logger.log(log_level, f"Function error: {{error_message}}")
import traceback
traceback_str = traceback.format_exc() if log_level == logging.ERROR else ""
if log_level == logging.ERROR:
logger.error(traceback_str)
send_event("error", {{"message": error_message, "traceback": traceback_str}})
send_event("state", {{"status": "failed", "message": error_message}})

Expand Down
26 changes: 13 additions & 13 deletions simpletuner/simpletuner_sdk/server/routes/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,22 +233,22 @@ def _resolve_datasets_dir_and_validate_path(
"""
# Load resolved defaults from WebUIStateStore (includes fallbacks)
webui_state = WebUIStateStore()
defaults_raw = webui_state.load_defaults()
defaults_bundle = webui_state.resolve_defaults(defaults_raw)
defaults_bundle = webui_state.get_defaults_bundle()
resolved = defaults_bundle["resolved"]

# Also check onboarding values as they might have been set but not yet applied to defaults
# Onboarding values take precedence when present. The onboarding step is named
# "default_datasets_dir" in the flow; fall back to defaults otherwise.
onboarding = webui_state.load_onboarding()
datasets_dir = resolved.get("datasets_dir")

# If datasets_dir is still the fallback, check if there's an onboarding value
# Use correct step ID: "default_datasets_dir" (not "datasets_dir")
if datasets_dir == defaults_bundle["fallbacks"].get("datasets_dir"):
datasets_step = onboarding.steps.get("default_datasets_dir")
if datasets_step and datasets_step.value:
datasets_dir = datasets_step.value

allow_outside = resolved.get("allow_dataset_paths_outside_dir", False)
datasets_dir = None
onboarding_step = onboarding.steps.get("default_datasets_dir")
if onboarding_step and onboarding_step.value:
datasets_dir = onboarding_step.value
else:
datasets_dir = resolved.get("datasets_dir")
if not datasets_dir:
datasets_dir = defaults_bundle["fallbacks"].get("datasets_dir")

allow_outside = bool(resolved.get("allow_dataset_paths_outside_dir", False))

# Use provided path or fall back to resolved datasets_dir (which includes fallback)
if path is None:
Expand Down
17 changes: 7 additions & 10 deletions tests/test_transformer_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,17 @@ def test_typo_prevention_utilities(self):
"""Test that typo prevention utilities work correctly."""
typo_utils = TypoTestUtils()

# Create a mock model for testing
mock_model = Mock()
mock_model.forward = Mock(return_value="success")
class _SampleModel:
@staticmethod
def forward(*, input):
return f"processed-{input}"

model = _SampleModel()

# Test parameter name typo detection
valid_params = {"input": "test"}
typo_params = {"inpt": "input"} # typo

try:
typo_utils.test_parameter_name_typos(mock_model, "forward", valid_params, typo_params)
# Should pass - the utility should detect the typo
except Exception as e:
# This is expected if the mock doesn't handle the test properly
self.skipTest(f"Typo test requires specific mock setup: {e}")
typo_utils.test_parameter_name_typos(model, "forward", valid_params, typo_params)

def test_performance_utilities(self):
"""Test that performance utilities work correctly."""
Expand Down
22 changes: 22 additions & 0 deletions tests/test_transformers/test_qwen_image_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,28 @@ def test_edge_cases(self):
class TestQwenImageTransformer2DModel(TransformerBaseTest):
"""Test QwenImageTransformer2DModel class."""

@classmethod
def setUpClass(cls):
super().setUpClass()
try:
QwenImageTransformer2DModel(
patch_size=1,
in_channels=4,
out_channels=4,
num_layers=1,
attention_head_dim=8,
num_attention_heads=2,
joint_attention_dim=16,
)
except RuntimeError as exc:
if "register_for_config" in str(exc):
raise unittest.SkipTest(
"QwenImageTransformer2DModel is unavailable on this diffusers build; skipping Qwen transformer tests."
) from exc
raise
except Exception as exc: # pragma: no cover - safety net for optional deps
raise unittest.SkipTest(f"Qwen transformer tests require additional dependencies: {exc}") from exc

def setUp(self):
super().setUp()
self.patch_size = 2
Expand Down