Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include simpletuner *.html *.css *.js *.ico *.json *.md
2 changes: 1 addition & 1 deletion simpletuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,4 @@ def _suppress_swigvarlink(message, *args, **kwargs):
warnings.warn = _suppress_swigvarlink


__version__ = "3.0.4"
__version__ = "3.0.5"
49 changes: 48 additions & 1 deletion simpletuner/helpers/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from typing import Any, Dict, List, Optional

import huggingface_hub
import wandb

import wandb
from simpletuner.helpers import log_format # noqa
from simpletuner.helpers.caching.memory import reclaim_memory
from simpletuner.helpers.configuration.cli_utils import mapping_to_cli_args
Expand Down Expand Up @@ -49,6 +49,7 @@
create_optimizer_with_param_groups,
determine_optimizer_class_with_config,
determine_params_to_optimize,
is_bitsandbytes_available,
is_lr_schedulefree,
is_lr_scheduler_disabled,
)
Expand Down Expand Up @@ -238,6 +239,49 @@ def _update_grad_metrics(
) and not self.config.use_deepspeed_optimizer:
target_logs["grad_absmax"] = self.grad_norm

def _config_uses_bitsandbytes(self) -> bool:
if not getattr(self, "config", None):
return False

def _contains_bnb(value: object) -> bool:
if isinstance(value, str):
return "bnb" in value.lower()
if isinstance(value, dict):
return any(_contains_bnb(item) for item in value.values())
if isinstance(value, (list, tuple, set)):
return any(_contains_bnb(item) for item in value)
return False

for attr_value in vars(self.config).values():
try:
if _contains_bnb(attr_value):
return True
except Exception:
continue
return False

def _enable_dynamo_dynamic_output_capture(self) -> None:
try:
import torch._dynamo as torch_dynamo
except Exception as exc:
logger.warning("Unable to configure Torch Dynamo dynamic output capture: %s", exc)
return

config_obj = getattr(torch_dynamo, "config", None)
if config_obj is None:
logger.debug("Torch Dynamo config unavailable; skipping dynamic output capture configuration.")
return
if not hasattr(config_obj, "capture_dynamic_output_shape_ops"):
logger.debug(
"Torch Dynamo config lacks capture_dynamic_output_shape_ops; skipping dynamic output capture configuration."
)
return
if getattr(config_obj, "capture_dynamic_output_shape_ops", False):
return

config_obj.capture_dynamic_output_shape_ops = True
logger.info("Torch Dynamo capture_dynamic_output_shape_ops enabled for bitsandbytes models.")

def parse_arguments(self, args=None, disable_accelerator: bool = False, exit_on_error: bool = False):
skip_config_fallback = False
args_payload = args
Expand Down Expand Up @@ -402,6 +446,9 @@ def _coerce_flag(value: object) -> bool:

dynamo_plugin = None
if resolved_dynamo_backend and resolved_dynamo_backend != DynamoBackend.NO:
if is_bitsandbytes_available and self._config_uses_bitsandbytes():
self._enable_dynamo_dynamic_output_capture()

plugin_kwargs: Dict[str, object] = {"backend": resolved_dynamo_backend}

mode_value = getattr(self.config, "dynamo_mode", None)
Expand Down
2 changes: 1 addition & 1 deletion simpletuner/helpers/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import diffusers
import numpy as np
import torch
import wandb
from tqdm import tqdm

import wandb
from simpletuner.helpers.models.common import ImageModelFoundation, ModelFoundation, VideoModelFoundation
from simpletuner.helpers.training.wrappers import unwrap_model

Expand Down
20 changes: 8 additions & 12 deletions simpletuner/simpletuner_sdk/server/dependencies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,10 @@ def _load_active_config_cached() -> Dict[str, Any]:
logger.error(f"Invalid config name contains path separator: {active_config}")
return {}

# Load config file - active_config is a string (config name)
config_path = Path(config_store.config_dir) / active_config / "config.json"
if not config_path.exists():
logger.warning(f"Active config file not found: {config_path}")
return {}

try:
with open(config_path, "r", encoding="utf-8") as f:
data = json.load(f)
data, _metadata = config_store.load_config(active_config)

# For backward compatibility, merge nested "config" sections if present.
if isinstance(data, dict):
config_section = data.get("config")
if isinstance(config_section, dict):
Expand All @@ -76,10 +70,12 @@ def _load_active_config_cached() -> Dict[str, Any]:
continue
merged.setdefault(key, value)
return merged

return data
except Exception as e:
logger.error(f"Error loading config: {e}")
return data if isinstance(data, dict) else {}
except FileNotFoundError as exc:
logger.warning(f"Active config '{active_config}' not found: {exc}")
return {}
except Exception as exc:
logger.error(f"Error loading config '{active_config}': {exc}")
return {}


Expand Down
56 changes: 52 additions & 4 deletions simpletuner/simpletuner_sdk/server/services/config_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,9 @@ def _extract_backend_path(config_obj: Any) -> Optional[str]:
# But exclude files that are clearly dataloader configs
if self.config_dir.exists():
for config_file in self.config_dir.glob("*.json"):
if config_file.name.endswith(".metadata.json"):
# Skip metadata sidecar files to avoid duplicate pseudo-config entries
continue
try:
with config_file.open("r", encoding="utf-8") as f:
data = json.load(f)
Expand Down Expand Up @@ -1155,19 +1158,64 @@ def copy_config(self, source: str, target: str) -> ConfigMetadata:
if not source_path.exists():
raise FileNotFoundError(f"Configuration '{source}' not found")

if target_path.exists():
if target_path.exists() or self._is_folder_config(target):
raise FileExistsError(f"Configuration '{target}' already exists")

# Load source and create new metadata
# Load source to obtain config payload and metadata template
config, old_metadata = self.load_config(source)
new_metadata = self._create_metadata(target, f"Copy of {old_metadata.description or source}")
new_metadata.model_family = old_metadata.model_family
new_metadata.model_type = old_metadata.model_type
new_metadata.tags = old_metadata.tags.copy()
new_metadata.parent_template = old_metadata.parent_template

# Save as new config
self.save_config(target, config, new_metadata)
def _translate_relative_paths(payload: Any) -> Any:
if isinstance(payload, dict):
return {key: _translate_relative_paths(value) for key, value in payload.items()}
if isinstance(payload, list):
return [_translate_relative_paths(item) for item in payload]
if isinstance(payload, str):
prefixes = ("", "./", ".\\")
for prefix in prefixes:
candidate = f"{prefix}{source}"
replacement = f"{prefix}{target}"

# Exact match (e.g. 'env' or './env')
if payload == candidate:
return replacement

# Match JSON file names (env.json)
candidate_json = f"{candidate}.json"
if payload == candidate_json:
return f"{replacement}.json"

for separator in ("/", "\\"):
needle = candidate + separator
if payload.startswith(needle):
return replacement + payload[len(candidate) :]

return payload

return payload

if self._is_folder_config(source):
source_dir = source_path.parent
target_dir = self.config_dir / target

if target_dir.exists():
raise FileExistsError(f"Configuration '{target}' already exists")

shutil.copytree(source_dir, target_dir)

config = _translate_relative_paths(config)

# Rewrite config and metadata to update timestamps/name while preserving extras in the folder
self.save_config(target, config, new_metadata, overwrite=True)
else:
config = _translate_relative_paths(config)

# Save as new flat config
self.save_config(target, config, new_metadata)

return new_metadata

Expand Down
59 changes: 54 additions & 5 deletions simpletuner/templates/environments_tab.html
Original file line number Diff line number Diff line change
Expand Up @@ -388,13 +388,16 @@
await trainerStore.updateConfigSelectors();
}
}
window.dispatchEvent(new CustomEvent('configs-updated', {
detail: {
await this.announceConfigChange(
{
source: 'environment-create',
name: payload && payload.name ? payload.name : null,
name: payload && payload.name
? payload.name
: (environment && environment.name ? environment.name : null),
environment: environment,
}
}));
},
{ skipTrainerRefresh: true }
);
},

openCreateDataloaderModal(environmentName) {
Expand Down Expand Up @@ -476,6 +479,31 @@
}
},

async announceConfigChange(detail = {}, { skipTrainerRefresh = false } = {}) {
const eventDetail = {
source: `${this.configType}-manager`,
configType: this.configType,
...detail,
};

if (this.configType === 'model' && !skipTrainerRefresh) {
const trainerStore = Alpine.store('trainer');
if (trainerStore) {
try {
if (typeof trainerStore.loadEnvironmentConfigs === 'function') {
await trainerStore.loadEnvironmentConfigs();
} else if (typeof trainerStore.updateConfigSelectors === 'function') {
await trainerStore.updateConfigSelectors();
}
} catch (error) {
console.error('Failed to refresh trainer configs after update:', error);
}
}
}

window.dispatchEvent(new CustomEvent('configs-updated', { detail: eventDetail }));
},

get allDataloaderConfigs() {
const store = Alpine.store('dataloaderConfigs');
return store ? store.configs : [];
Expand Down Expand Up @@ -593,6 +621,10 @@
if (response.ok) {
await this.loadConfigs();
window.showToast('Deleted configuration: ' + name, 'success');
await this.announceConfigChange({
source: 'environment-delete',
name: sanitized,
});
} else {
const error = await response.json();
// Show more specific error for active config
Expand Down Expand Up @@ -649,6 +681,10 @@
if (response.ok) {
await this.loadConfigs();
window.showToast('Created duplicate: ' + targetName, 'success');
await this.announceConfigChange({
source: 'environment-duplicate',
name: targetName,
});
}
} catch (error) {
console.error('Failed to duplicate config:', error);
Expand Down Expand Up @@ -677,6 +713,11 @@
if (response.ok) {
await this.loadConfigs();
window.showToast('Renamed to ' + targetName, 'success');
await this.announceConfigChange({
source: 'environment-rename',
name: targetName,
previousName: sourceName,
});
}
} catch (error) {
console.error('Failed to rename config:', error);
Expand Down Expand Up @@ -729,6 +770,10 @@
if (response.ok) {
await this.loadConfigs();
window.showToast('Created configuration from template: ' + normalizedName, 'success');
await this.announceConfigChange({
source: 'environment-from-template',
name: normalizedName,
});
}
} catch (error) {
console.error('Failed to create from template:', error);
Expand Down Expand Up @@ -766,6 +811,10 @@
if (response.ok) {
await this.loadConfigs();
window.showToast('Imported configuration: ' + name, 'success');
await this.announceConfigChange({
source: 'environment-import',
name: name,
});
}
} catch (error) {
console.error('Failed to import config:', error);
Expand Down
30 changes: 30 additions & 0 deletions tests/test_config_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_config_store():
config_names = [c.get("name") for c in configs]
assert "default" in config_names
assert "test_config" in config_names
assert not any(isinstance(name, str) and name.endswith(".metadata") for name in config_names)
print(f" Found {len(configs)} configs: {', '.join(config_names)}")

# Test 3: Load configuration
Expand Down Expand Up @@ -72,6 +73,33 @@ def test_config_store():
assert copy_config["--model_type"] == config["--model_type"]
print(f" Created copy: {copy_metadata.name}")

# Test 5b: Copy folder-based configuration retains directory layout
folder_name = "folder_env"
folder_dir = Path(tmpdir) / folder_name
folder_dir.mkdir(parents=True, exist_ok=True)
# Seed additional assets that should be preserved during copy
(folder_dir / "multidatabackend.json").write_text("[]", encoding="utf-8")
folder_config = config.copy()
folder_config["--model_family"] = "wan"
folder_config["--model_type"] = "lora"
folder_config["--data_backend_config"] = f"{folder_name}/multidatabackend.json"
folder_config["data_backend_config"] = f"{folder_name}/multidatabackend.json"
store.save_config(folder_name, folder_config, overwrite=True)

folder_copy_metadata = store.copy_config(folder_name, "folder_copy")
assert folder_copy_metadata.name == "folder_copy"

copied_dir = Path(tmpdir) / "folder_copy"
assert copied_dir.is_dir()
assert (copied_dir / "config.json").exists()
assert (copied_dir / "multidatabackend.json").exists()
assert not (Path(tmpdir) / "folder_copy.metadata.json").exists()

copied_config, _ = store.load_config("folder_copy")
expected_path = "folder_copy/multidatabackend.json"
assert copied_config.get("data_backend_config") == expected_path
assert copied_config.get("--data_backend_config") == expected_path

# Test 6: Rename configuration
print("✓ Renaming configuration...")
rename_metadata = store.rename_config("test_copy", "test_renamed")
Expand All @@ -80,6 +108,7 @@ def test_config_store():
config_names = [c.get("name") for c in configs]
assert "test_renamed" in config_names
assert "test_copy" not in config_names
assert not any(isinstance(name, str) and name.endswith(".metadata") for name in config_names)
print(f" Renamed to: {rename_metadata.name}")

# Test 7: Set active configuration
Expand All @@ -104,6 +133,7 @@ def test_config_store():
configs = store.list_configs()
config_names = [c.get("name") for c in configs]
assert "test_renamed" not in config_names
assert not any(isinstance(name, str) and name.endswith(".metadata") for name in config_names)
print(f" Deleted test_renamed")

# Test 10: Invalid config validation
Expand Down
Loading