Skip to content

Commit 16ae92d

Browse files
committed
feat(trainer): add adopt_snapshot_atomic to SnapshotManager
Atomically adopt a pre-written checkpoint directory as the latest snapshot without calling model.save_pretrained().
1 parent 5b50837 commit 16ae92d

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

grail/trainer/snapshot_manager.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,67 @@ def save_snapshot_atomic(
136136
logger.debug("Cleanup failed: %s", cleanup_exc)
137137
raise
138138

139+
def adopt_snapshot_atomic(
140+
self,
141+
source_dir: Path | str,
142+
metadata: dict[str, Any],
143+
) -> None:
144+
"""Atomically adopt an already-written checkpoint directory as the latest snapshot.
145+
146+
Unlike ``save_snapshot_atomic``, this does NOT call ``model.save_pretrained()``.
147+
It writes only metadata into *source_dir*, then performs the same atomic rename
148+
dance so that ``snapshots/latest/`` always points to a complete snapshot.
149+
150+
Use this when FSDP2's ``save_full_checkpoint()`` has already written model
151+
weights to *source_dir* and we just need the snapshot manager to adopt it.
152+
153+
Args:
154+
source_dir: Directory containing model weights (already written).
155+
metadata: Snapshot metadata (epoch, timestamp, metrics, etc.).
156+
"""
157+
source_dir = Path(source_dir)
158+
target_dir = self.snapshot_dir / "latest"
159+
160+
if not source_dir.exists():
161+
raise FileNotFoundError(f"Source directory does not exist: {source_dir}")
162+
163+
try:
164+
# Write metadata into the source directory
165+
metadata_path = source_dir / "snapshot_metadata.json"
166+
with open(metadata_path, "w", encoding="utf-8") as f:
167+
json.dump(metadata, f, indent=2)
168+
169+
# Fsync directory
170+
try:
171+
fd = os.open(source_dir, os.O_RDONLY)
172+
os.fsync(fd)
173+
os.close(fd)
174+
except (OSError, AttributeError):
175+
pass
176+
177+
# Atomic rename chain (same protocol as save_snapshot_atomic)
178+
backup_dir = self.snapshot_dir / f"latest.backup.{uuid.uuid4().hex[:8]}"
179+
180+
if target_dir.exists():
181+
target_dir.rename(backup_dir)
182+
183+
source_dir.rename(target_dir)
184+
185+
if backup_dir.exists():
186+
try:
187+
shutil.rmtree(backup_dir)
188+
except Exception as cleanup_exc:
189+
logger.warning("Failed to cleanup backup snapshot: %s", cleanup_exc)
190+
191+
# Set SNAPSHOT_READY marker
192+
self._snapshot_ready_marker.touch()
193+
194+
logger.info("Snapshot adopted atomically from %s to %s", source_dir, target_dir)
195+
196+
except Exception as exc:
197+
logger.error("Failed to adopt snapshot: %s", exc)
198+
raise
199+
139200
def check_snapshot_ready(self) -> bool:
140201
"""Check if new snapshot is available for upload.
141202

0 commit comments

Comments
 (0)