Skip to content
Open
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ docs = [
"pillow>=10.0.0",
"cairosvg>=2.7.1"
]
hf = ["huggingface-hub>=0.20.0"]
zarr_conversion = [
"fire>=0.5.0",
"numcodecs>=0.16.3",
Expand Down
5 changes: 5 additions & 0 deletions src/electrai/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from electrai.callbacks.hf_upload import HuggingFaceCallback

__all__ = ["HuggingFaceCallback"]
172 changes: 172 additions & 0 deletions src/electrai/callbacks/hf_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from __future__ import annotations

import json
import logging
import shutil
from pathlib import Path
from typing import TYPE_CHECKING

from lightning.pytorch.callbacks import Callback

if TYPE_CHECKING:
from types import SimpleNamespace

logger = logging.getLogger(__name__)

MANIFEST_FILENAME = "hf_upload_manifest.json"


class HuggingFaceCallback(Callback):
    """Tracks saved checkpoints for deferred upload to HuggingFace Hub.

    On clusters without internet (e.g. Princeton Della), checkpoints are
    queued in a JSON manifest and uploaded later via ``electrai hf-push``.
    When ``hf.upload_immediate`` is True, uploads are attempted inline
    (failures are logged but never crash training).
    """

    def __init__(self, cfg: SimpleNamespace) -> None:
        """Read the ``hf`` section of *cfg* and resume any existing manifest.

        Parameters
        ----------
        cfg
            Run configuration. ``cfg.hf`` is a mapping with a required
            ``repo_id`` plus optional ``upload_every_n_epochs`` (default 5)
            and ``upload_immediate`` (default False). ``cfg.ckpt_path``
            (default ``./checkpoints``) is where checkpoints and the
            manifest live.
        """
        super().__init__()
        hf = cfg.hf
        self.repo_id: str = hf["repo_id"]
        self.every_n_epochs: int = hf.get("upload_every_n_epochs", 5)
        self.upload_immediate: bool = hf.get("upload_immediate", False)
        self.ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))
        self.manifest_path = self.ckpt_path / MANIFEST_FILENAME
        # Each entry: {"path", "path_in_repo", "epoch", "repo_id", "uploaded"}.
        self._manifest: list[dict] = []
        self._load_existing_manifest()

    def _load_existing_manifest(self) -> None:
        """Resume the upload queue from a previous run.

        A corrupt or partially written manifest (e.g. after a crash mid-save)
        must never abort training, so parse/read errors are logged and the
        queue starts fresh instead of raising.
        """
        if not self.manifest_path.exists():
            return
        try:
            with self.manifest_path.open(encoding="utf-8") as f:
                self._manifest = json.load(f)
        except (json.JSONDecodeError, OSError):
            logger.warning(
                "Could not read existing manifest %s; starting a new one.",
                self.manifest_path,
                exc_info=True,
            )
            self._manifest = []

    def _save_manifest(self) -> None:
        """Write the manifest to disk, creating the checkpoint dir if needed."""
        self.ckpt_path.mkdir(parents=True, exist_ok=True)
        with self.manifest_path.open("w", encoding="utf-8") as f:
            json.dump(self._manifest, f, indent=2)

    def _queue_checkpoint(
        self, ckpt_file: Path, epoch: int | None, *, path_in_repo: str | None = None
    ) -> None:
        """Append *ckpt_file* to the manifest and persist it immediately.

        Parameters
        ----------
        ckpt_file
            Local checkpoint file to upload later.
        epoch
            Epoch the checkpoint belongs to, or None when unknown
            (e.g. best-model checkpoints queued at train end).
        path_in_repo
            Destination filename in the HF repo; defaults to the local name.
        """
        entry = {
            "path": str(ckpt_file),
            "path_in_repo": path_in_repo or ckpt_file.name,
            "epoch": epoch,
            "repo_id": self.repo_id,
            "uploaded": False,
        }
        self._manifest.append(entry)
        self._save_manifest()
        logger.info("Queued checkpoint for HF upload: %s", ckpt_file.name)

    def on_validation_end(self, trainer, pl_module) -> None:  # noqa: ARG002
        """Every N epochs, snapshot ``last.ckpt`` and queue it for upload.

        Runs on rank 0 only and skips the sanity-check validation pass.
        Assumes the ModelCheckpoint callback writes ``last.ckpt`` into
        ``self.ckpt_path``; if it is absent, nothing is queued.
        """
        if trainer.sanity_checking:
            return
        epoch = trainer.current_epoch
        if (epoch + 1) % self.every_n_epochs != 0:
            return
        if trainer.global_rank != 0:
            return

        last_ckpt = self.ckpt_path / "last.ckpt"
        if not last_ckpt.exists():
            return

        # Copy to a stable filename so later hf-push uploads the correct
        # snapshot even after last.ckpt is overwritten by subsequent epochs.
        stable_name = f"last_epoch{epoch + 1:03d}.ckpt"
        stable_path = self.ckpt_path / stable_name
        shutil.copy2(last_ckpt, stable_path)

        self._queue_checkpoint(stable_path, epoch, path_in_repo=stable_name)

        if self.upload_immediate:
            _upload_single(self._manifest[-1])
            if self._manifest[-1]["uploaded"]:
                # Snapshot is on the Hub; the local copy is redundant.
                stable_path.unlink(missing_ok=True)
            self._save_manifest()

    def on_train_end(self, trainer, pl_module) -> None:  # noqa: ARG002
        """Queue any best-model checkpoints and report what is still pending.

        Runs on rank 0 only. Picks up ``ckpt_*.ckpt`` files (best-model
        checkpoints) that were never queued during training.
        """
        if trainer.global_rank != 0:
            return
        # Queue best checkpoints that haven't been queued yet
        queued_paths = {e["path"] for e in self._manifest}
        had_immediate = False
        for ckpt_file in self.ckpt_path.glob("ckpt_*.ckpt"):
            if str(ckpt_file) not in queued_paths:
                self._queue_checkpoint(ckpt_file, epoch=None)
                if self.upload_immediate:
                    _upload_single(self._manifest[-1])
                    had_immediate = True
        if had_immediate:
            # Persist the "uploaded" flags set by the inline uploads.
            self._save_manifest()

        pending = sum(1 for e in self._manifest if not e["uploaded"])
        if pending:
            logger.info(
                "%d checkpoint(s) pending upload. "
                "Run 'electrai hf-push --ckpt-path %s' from a node with "
                "internet access.",
                pending,
                self.ckpt_path,
            )


def _upload_single(entry: dict) -> None:
"""Attempt to upload a single checkpoint. Logs errors, never raises."""
path = Path(entry["path"])
try:
from huggingface_hub import upload_file

if not path.exists():
logger.warning("Checkpoint file not found, skipping: %s", path)
return
path_in_repo = entry.get("path_in_repo", path.name)
upload_file(
path_or_fileobj=str(path),
path_in_repo=path_in_repo,
repo_id=entry["repo_id"],
)
entry["uploaded"] = True
logger.info("Uploaded %s to %s", path.name, entry["repo_id"])
except Exception:
logger.warning(
"HF upload failed for %s (will retry with hf-push)",
path.name,
exc_info=True,
)


def hf_push(ckpt_path: str, *, clean: bool = False) -> None:
    """Upload pending checkpoints from a manifest file.

    Run this from a login node or machine with internet access.

    Parameters
    ----------
    ckpt_path
        Directory containing the checkpoints and the JSON manifest written
        by ``HuggingFaceCallback``.
    clean
        When True, delete each local checkpoint file after it uploads
        successfully.

    Raises
    ------
    SystemExit
        If no manifest file exists under *ckpt_path*.
    """
    ckpt_dir = Path(ckpt_path)
    manifest_path = ckpt_dir / MANIFEST_FILENAME
    if not manifest_path.exists():
        raise SystemExit(f"No manifest found at {manifest_path}")

    with manifest_path.open(encoding="utf-8") as f:
        manifest = json.load(f)

    # .get tolerates hand-edited or older manifests missing the flag.
    pending = [e for e in manifest if not e.get("uploaded", False)]
    if not pending:
        logger.info("All checkpoints already uploaded.")
        return

    logger.info("Uploading %d pending checkpoint(s)...", len(pending))
    for entry in pending:
        _upload_single(entry)
        if clean and entry.get("uploaded", False):
            Path(entry["path"]).unlink(missing_ok=True)
        # Persist progress after every attempt so an interrupted run does
        # not forget which uploads already succeeded and redo them.
        with manifest_path.open("w", encoding="utf-8") as f:
            json.dump(manifest, f, indent=2)

    still_pending = sum(1 for e in manifest if not e.get("uploaded", False))
    if still_pending:
        logger.warning("%d checkpoint(s) still failed to upload.", still_pending)
    else:
        logger.info("All checkpoints uploaded successfully.")
6 changes: 6 additions & 0 deletions src/electrai/configs/MP/config_resnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ wb_pname: mp-experiment
# checkpoints
ckpt_path: ./checkpoints

# HuggingFace Hub (optional — install with `uv sync --extra hf`)
# hf:
# repo_id: your-username/your-repo # must already exist on HF
# upload_every_n_epochs: 5
# upload_immediate: false # set true on nodes with internet access

# test the model
# save_pred: true
# log_dir: ./logs
Expand Down
6 changes: 6 additions & 0 deletions src/electrai/configs/MP/config_resunet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ wb_pname: mp-experiment
# checkpoints
ckpt_path: ./checkpoints

# HuggingFace Hub (optional — install with `uv sync --extra hf`)
# hf:
# repo_id: your-username/your-repo # must already exist on HF
# upload_every_n_epochs: 5
# upload_immediate: false # set true on nodes with internet access

# test the model
# save_pred: true
# log_dir: ./logs
Expand Down
20 changes: 18 additions & 2 deletions src/electrai/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import argparse
import logging

import torch

Expand All @@ -23,6 +24,7 @@ def main() -> None:
RuntimeError
if no command was input
"""
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description="Electrai Entry Point")
subparsers = parser.add_subparsers(dest="command", required=True)

Expand All @@ -32,14 +34,28 @@ def main() -> None:
test_parser = subparsers.add_parser("test", help="Evaluate the model")
test_parser.add_argument("--config", type=str, required=True)

hf_push_parser = subparsers.add_parser(
"hf-push", help="Upload pending checkpoints to HuggingFace Hub"
)
hf_push_parser.add_argument(
"--ckpt-path", type=str, required=True, help="Path to checkpoint directory"
)
hf_push_parser.add_argument(
"--clean",
action="store_true",
help="Delete local checkpoint files after successful upload (includes best-model checkpoints)",
)

args = parser.parse_args()

if args.command == "train":
train(args)
elif args.command == "test":
test(args)
else:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed because argparse with required=True subparsers already handles unknown commands

raise ValueError(f"Unknown command: {args.command}")
elif args.command == "hf-push":
from electrai.callbacks.hf_upload import hf_push

hf_push(args.ckpt_path, clean=args.clean)


if __name__ == "__main__":
Expand Down
10 changes: 9 additions & 1 deletion src/electrai/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def train(args):

lr_monitor = LearningRateMonitor(logging_interval="epoch")

callbacks = [checkpoint_cb, lr_monitor]

hf_cfg = getattr(cfg, "hf", None)
if hf_cfg and hf_cfg.get("repo_id"):
from electrai.callbacks.hf_upload import HuggingFaceCallback

callbacks.append(HuggingFaceCallback(cfg))

# -----------------------------
# Trainer
# -----------------------------
Expand All @@ -69,7 +77,7 @@ def train(args):
trainer = Trainer(
max_epochs=int(cfg.epochs),
logger=wandb_logger,
callbacks=[checkpoint_cb, lr_monitor],
callbacks=callbacks,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
precision=cfg.precision,
devices="auto",
Expand Down
Loading