Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,25 @@ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/test_environment.py

### Post-training

Coming soon!
The commands below create the `cosmos-transfer1` conda environment and install the dependencies for post-training. These are the same dependencies required for inference, plus one additional package, `apex`, for training with bfloat16.
```bash
# Create the cosmos-transfer1 conda environment.
conda env create --file cosmos-transfer1.yaml
# Activate the cosmos-transfer1 conda environment.
conda activate cosmos-transfer1
# Install the dependencies.
pip install -r requirements.txt
# Patch Transformer engine linking issues in conda environments.
ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/
ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10
# Install Transformer engine.
pip install transformer-engine[pytorch]==1.12.0
# Install Apex for full training with bfloat16.
git clone https://github.com/NVIDIA/apex
CUDA_HOME=$CONDA_PREFIX pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex
```

You can test the environment setup for post-training with
```bash
CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/test_environment.py --training
```
3 changes: 0 additions & 3 deletions cosmos_transfer1/POST_TRAINING.md

This file was deleted.

127 changes: 127 additions & 0 deletions cosmos_transfer1/checkpointer/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from abc import ABC, abstractmethod
from typing import Optional

import torch

from cosmos_transfer1.utils import callback
from cosmos_transfer1.utils.config import CheckpointConfig, JobConfig
from cosmos_transfer1.utils.easy_io import easy_io
from cosmos_transfer1.utils.model import Model


class AbstractCheckpointer(ABC):
    """Abstract base for checkpointers that save to and load from a local directory.

    Subclasses implement ``save`` and ``load``. This base class stores the
    checkpoint configuration, exposes the save/load directory, and maintains
    the ``latest_checkpoint.txt`` pointer file inside that directory.
    """

    def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
        """Store the configuration and derive the local checkpoint directory.

        Args:
            config_checkpoint (CheckpointConfig): Checkpointer settings; the
                resume/load options are copied onto this instance.
            config_job (JobConfig): Job settings; ``path_local`` is the parent
                of the ``checkpoints`` directory.
            callbacks (callback.CallBackGroup): Callback group kept on the
                instance for subclasses to use.
        """
        self.config_checkpoint = config_checkpoint
        self.callbacks = callbacks

        # All checkpoints live under "<job local path>/checkpoints".
        self._local_dirname = os.path.join(config_job.path_local, "checkpoints")

        # Mirror the checkpoint options as instance attributes.
        self.strict_resume = config_checkpoint.strict_resume
        # A falsy load path (e.g. an empty string) is normalised to None.
        self.load_path = config_checkpoint.load_path or None
        self.load_training_state = config_checkpoint.load_training_state
        self.only_load_scheduler_state = config_checkpoint.only_load_scheduler_state
        # No background save thread exists yet; finalize() joins it when one is set.
        self.save_thread = None
        self.verbose = config_checkpoint.verbose
        self.keys_not_to_resume = config_checkpoint.keys_not_to_resume
        self.broadcast_via_filesystem = config_checkpoint.broadcast_via_filesystem

    @abstractmethod
    def save(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int,
    ) -> None:
        """Save a checkpoint; concrete subclasses must override this method."""

    @abstractmethod
    def load(
        self,
        model: Model,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
        grad_scaler: Optional[torch.amp.GradScaler] = None,
    ) -> int:
        """Load a checkpoint; concrete subclasses must override this method."""

    @property
    def save_bucket(self):
        """Bucket name used for saving checkpoints; always ``None`` in this base class."""
        return None

    @property
    def load_bucket(self):
        """Bucket name used for loading checkpoints; always ``None`` in this base class."""
        return None

    @property
    def save_dirname(self):
        """Directory that checkpoints are written to (the local checkpoint directory)."""
        return self._local_dirname

    @property
    def load_dirname(self):
        """Directory that checkpoints are read from (the local checkpoint directory)."""
        return self._local_dirname

    def finalize(self) -> None:
        """Wait for the pending save thread, if one has been set."""
        if self.save_thread:
            self.save_thread.join()

    def _read_latest_checkpoint_file(self) -> str | None:
        """Return the checkpoint file name recorded in ``latest_checkpoint.txt``.

        Returns:
            checkpoint_file (str | None): The recorded file name with
                surrounding whitespace stripped, or None when the pointer file
                does not exist in the load directory.
        """
        pointer_path = os.path.join(self.load_dirname, "latest_checkpoint.txt")
        if not easy_io.exists(pointer_path):
            return None
        return easy_io.load(pointer_path).strip()

    def _write_latest_checkpoint_file(self, checkpoint_file: str) -> None:
        """Record ``checkpoint_file`` as the latest checkpoint.

        Args:
            checkpoint_file (str): File name of the latest saved checkpoint;
                it is written followed by a newline.
        """
        pointer_path = os.path.join(self.save_dirname, "latest_checkpoint.txt")
        easy_io.dump(f"{checkpoint_file}\n", pointer_path)

    def _check_checkpoint_exists(self, checkpoint_path: str) -> None:
        """Verify that ``checkpoint_path`` exists.

        Args:
            checkpoint_path (str): Full path to the checkpoint.

        Raises:
            FileNotFoundError: If nothing exists at ``checkpoint_path``.
        """
        if easy_io.exists(checkpoint_path):
            return
        raise FileNotFoundError(f"File not found: {checkpoint_path}")
Loading